Mirror of https://github.com/huggingface/transformers.git
add vit tf doctest with @add_code_sample_docstrings (#16636)
* add vit tf doctest with @add_code_sample_docstrings
* add labels string back in

Co-authored-by: Johannes Kolbe <johannes.kolbe@tech.better.team>
commit 9db2eebbe2
parent 4ef0abb738
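The decorator builds a ready-to-run, testable usage sample from the module-level constants added in this diff (checkpoint names, expected outputs), replacing the hand-written examples removed below. Roughly, the sample generated for TFViTForImageClassification behaves like the following sketch (an approximation based on the removed example and _IMAGE_CLASS_EXPECTED_OUTPUT, not the decorator's verbatim output):

>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
>>> import tensorflow as tf
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
>>> model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> logits = model(**inputs).logits

>>> # the doctest compares the decoded label with _IMAGE_CLASS_EXPECTED_OUTPUT
>>> predicted_class_idx = int(tf.math.argmax(logits, axis=-1)[0])
>>> print(model.config.id2label[predicted_class_idx])
Egyptian cat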
@@ -33,14 +33,23 @@ from ...modeling_tf_utils import (
     unpack_inputs,
 )
 from ...tf_utils import shape_list
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_vit import ViTConfig


 logger = logging.get_logger(__name__)

+# General docstring
 _CONFIG_FOR_DOC = "ViTConfig"
-_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
+_FEAT_EXTRACTOR_FOR_DOC = "ViTFeatureExtractor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
+_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"


 # Inspired by
@@ -645,7 +654,14 @@ class TFViTModel(TFViTPreTrainedModel):

     @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutputWithPooling,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -656,26 +672,6 @@ class TFViTModel(TFViTPreTrainedModel):
         return_dict: Optional[bool] = None,
         training: bool = False,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
-        r"""
-        Returns:
-
-        Examples:
-
-        ```python
-        >>> from transformers import ViTFeatureExtractor, TFViTModel
-        >>> from PIL import Image
-        >>> import requests
-
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
-        >>> model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
-
-        >>> inputs = feature_extractor(images=image, return_tensors="tf")
-        >>> outputs = model(**inputs)
-        >>> last_hidden_states = outputs.last_hidden_state
-        ```"""

         outputs = self.vit(
             pixel_values=pixel_values,
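The hand-written example removed above is now produced by @add_code_sample_docstrings. For the base TFViTModel the generated sample additionally checks the output shape against _EXPECTED_OUTPUT_SHAPE; a rough approximation of that check (the template wording in transformers may differ):

>>> from transformers import ViTFeatureExtractor, TFViTModel
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state

>>> # the doctest compares this against _EXPECTED_OUTPUT_SHAPE
>>> list(last_hidden_states.shape)
[1, 197, 768]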
@@ -744,7 +740,13 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):

     @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    @add_code_sample_docstrings(
+        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=TFSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -761,30 +763,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-
-        Returns:
-
-        Examples:
-
-        ```python
-        >>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
-        >>> import tensorflow as tf
-        >>> from PIL import Image
-        >>> import requests
-
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
-        >>> model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
-
-        >>> inputs = feature_extractor(images=image, return_tensors="tf")
-        >>> outputs = model(**inputs)
-        >>> logits = outputs.logits
-        >>> # model predicts one of the 1000 ImageNet classes
-        >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
-        >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
-        ```"""
+        """

         outputs = self.vit(
             pixel_values=pixel_values,
@@ -36,6 +36,7 @@ src/transformers/models/van/modeling_van.py
 src/transformers/models/vilt/modeling_vilt.py
 src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
 src/transformers/models/vit/modeling_vit.py
+src/transformers/models/vit/modeling_tf_vit.py
 src/transformers/models/vit_mae/modeling_vit_mae.py
 src/transformers/models/wav2vec2/modeling_wav2vec2.py
 src/transformers/models/wav2vec2/tokenization_wav2vec2.py
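Listing the module in the documentation-test file means its (now decorator-generated) examples are exercised as doctests. A minimal sketch of one way to run them locally, using only standard pytest options (the repository's own CI invocation may differ):

# run the doctests collected from the TF ViT module's docstrings
import pytest

pytest.main(
    [
        "--doctest-modules",
        "src/transformers/models/vit/modeling_tf_vit.py",
        "-sv",
        "--doctest-continue-on-failure",
    ]
)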