[ViTDet] Fix doc tests (#25880)

Fix docstrings
This commit is contained in:
NielsRogge 2023-08-30 22:49:03 +02:00 committed by GitHub
parent 1c6f072db0
commit 716bb2e391
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -27,7 +27,6 @@ from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput, BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
@ -42,10 +41,6 @@ logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "VitDetConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/vit-det-base"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
VITDET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/vit-det-base",
@ -737,13 +732,7 @@ class VitDetModel(VitDetPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(VITDET_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
@ -752,6 +741,27 @@ class VitDetModel(VitDetPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
"""
Returns:
Examples:
```python
>>> from transformers import VitDetConfig, VitDetModel
>>> import torch
>>> config = VitDetConfig()
>>> model = VitDetModel(config)
>>> pixel_values = torch.randn(1, 3, 224, 224)
>>> with torch.no_grad():
... outputs = model(pixel_values)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 768, 14, 14]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -825,19 +835,20 @@ class VitDetBackbone(VitDetPreTrainedModel, BackboneMixin):
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> from transformers import VitDetConfig, VitDetBackbone
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> config = VitDetConfig()
>>> model = VitDetBackbone(config)
>>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
>>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
>>> pixel_values = torch.randn(1, 3, 224, 224)
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> with torch.no_grad():
... outputs = model(pixel_values)
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 768, 14, 14]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (