add vit tf doctest with @add_code_sample_docstrings (#16636)

* add vit tf doctest with @add_code_sample_docstrings

* add labels string back in

Co-authored-by: Johannes Kolbe <johannes.kolbe@tech.better.team>
This commit is contained in:
Johannes Kolbe 2022-04-08 13:31:38 +02:00 committed by GitHub
parent 4ef0abb738
commit 9db2eebbe2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 48 deletions

View File

@ -33,14 +33,23 @@ from ...modeling_tf_utils import (
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_vit import ViTConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "ViTConfig"
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
_FEAT_EXTRACTOR_FOR_DOC = "ViTFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
# Inspired by
@ -645,7 +654,14 @@ class TFViTModel(TFViTPreTrainedModel):
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
@ -656,26 +672,6 @@ class TFViTModel(TFViTPreTrainedModel):
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
r"""
Returns:
Examples:
```python
>>> from transformers import ViTFeatureExtractor, TFViTModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
outputs = self.vit(
pixel_values=pixel_values,
@ -744,7 +740,13 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
@ -761,30 +763,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples:
```python
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
>>> import tensorflow as tf
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
>>> model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
```"""
"""
outputs = self.vit(
pixel_values=pixel_values,

View File

@ -36,6 +36,7 @@ src/transformers/models/van/modeling_van.py
src/transformers/models/vilt/modeling_vilt.py
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
src/transformers/models/vit/modeling_vit.py
src/transformers/models/vit/modeling_tf_vit.py
src/transformers/models/vit_mae/modeling_vit_mae.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/tokenization_wav2vec2.py