[ViTMAE] Various fixes (#15221)

* Add MAE to AutoFeatureExtractor

* Add link to notebook

* Fix relative paths
NielsRogge 2022-01-19 15:27:57 +01:00 committed by GitHub
parent 6d92c429c7
commit 842298f84f
4 changed files with 22 additions and 16 deletions

View File

@@ -65,21 +65,23 @@ Tips:
Following the original Vision Transformer, some follow-up works have been made:
- DeiT (Data-efficient Image Transformers) by Facebook AI. DeiT models are distilled vision transformers. Refer to
[DeiT's documentation page](deit). The authors of DeiT also released more efficiently trained ViT models, which
you can directly plug into [`ViTModel`] or [`ViTForImageClassification`]. There
are 4 variants available (in 3 different sizes): *facebook/deit-tiny-patch16-224*, *facebook/deit-small-patch16-224*,
*facebook/deit-base-patch16-224* and *facebook/deit-base-patch16-384*. Note that one should use
[`DeiTFeatureExtractor`] in order to prepare images for the model.
- [DeiT](deit) (Data-efficient Image Transformers) by Facebook AI. DeiT models are distilled vision transformers.
The authors of DeiT also released more efficiently trained ViT models, which you can directly plug into [`ViTModel`] or
[`ViTForImageClassification`]. There are 4 variants available (in 3 different sizes): *facebook/deit-tiny-patch16-224*,
*facebook/deit-small-patch16-224*, *facebook/deit-base-patch16-224* and *facebook/deit-base-patch16-384*. Note that one should
use [`DeiTFeatureExtractor`] in order to prepare images for the model.
- BEiT (BERT pre-training of Image Transformers) by Microsoft Research. BEiT models outperform supervised pre-trained
- [BEiT](beit) (BERT pre-training of Image Transformers) by Microsoft Research. BEiT models outperform supervised pre-trained
vision transformers using a self-supervised method inspired by BERT (masked image modeling) and based on a VQ-VAE.
Refer to [BEiT's documentation page](beit).
- DINO (a method for self-supervised training of Vision Transformers) by Facebook AI. Vision Transformers trained using
the DINO method show very interesting properties not seen with convolutional models. They are capable of segmenting
objects, without having ever been trained to do so. DINO checkpoints can be found on the [hub](https://huggingface.co/models?other=dino).
- [MAE](vit_mae) (Masked Autoencoders) by Facebook AI. By pre-training Vision Transformers to reconstruct pixel values for a high portion
(75%) of masked patches (using an asymmetric encoder-decoder architecture), the authors show that this simple method outperforms
supervised pre-training after fine-tuning.
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
found [here](https://github.com/google-research/vision_transformer).
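
The DeiT tip above says these checkpoints can be loaded straight into [`ViTForImageClassification`], as long as [`DeiTFeatureExtractor`] prepares the images. A minimal sketch of that pattern, assuming the *facebook/deit-base-patch16-224* checkpoint ships its ImageNet classification head (otherwise the head is freshly initialized and the predicted label is meaningless):

```python
import requests
from PIL import Image

from transformers import DeiTFeatureExtractor, ViTForImageClassification

# DeiT checkpoints are efficiently trained ViTs, so they load into the plain ViT classes
feature_extractor = DeiTFeatureExtractor.from_pretrained("facebook/deit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("facebook/deit-base-patch16-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = feature_extractor(images=image, return_tensors="pt")
logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```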

View File

@@ -32,6 +32,7 @@ Tips:
- MAE (masked auto encoding) is a method for self-supervised pre-training of Vision Transformers (ViTs). The pre-training objective is relatively simple:
by masking a large portion (75%) of the image patches, the model must reconstruct raw pixel values. One can use [`ViTMAEForPreTraining`] for this purpose.
- A notebook that illustrates how to visualize reconstructed pixel values with [`ViTMAEForPreTraining`] can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/ViTMAE/ViT_MAE_visualization_demo.ipynb).
- After pre-training, one "throws away" the decoder used to reconstruct pixels, and one uses the encoder for fine-tuning/linear probing. This means that after
fine-tuning, one can directly plug in the weights into a [`ViTForImageClassification`].
- One can use [`ViTFeatureExtractor`] to prepare images for the model. See the code examples for more info.
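
The tips above describe the hand-off after pre-training: drop the decoder and fine-tune the encoder as a regular ViT. A minimal sketch of that step, assuming the encoder weights of *facebook/vit-mae-base* line up with [`ViTForImageClassification`] (the decoder weights are ignored and the classification head is newly initialized, so `from_pretrained` warnings are expected):

```python
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Reuse the MAE-pre-trained encoder for fine-tuning / linear probing;
# the pixel-reconstruction decoder is simply not loaded.
model = ViTForImageClassification.from_pretrained(
    "facebook/vit-mae-base",
    num_labels=10,  # hypothetical downstream dataset with 10 classes
)
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")

# `model` can now be trained like any other image classifier,
# e.g. with the Trainer API or a plain PyTorch loop.
```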

View File

@@ -44,6 +44,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("clip", "CLIPFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
]
)
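
With `"vit_mae"` registered here, [`AutoFeatureExtractor`] can resolve ViTMAE checkpoints to the existing [`ViTFeatureExtractor`] (ViTMAE has no feature extractor of its own). A quick check of the intended behavior:

```python
from transformers import AutoFeatureExtractor, ViTFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
# the auto class should hand back the ViT feature extractor registered for "vit_mae"
assert isinstance(feature_extractor, ViTFeatureExtractor)
```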

View File

@@ -603,8 +603,8 @@ VIT_MAE_START_DOCSTRING = r"""
VIT_MAE_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`ViTFeatureExtractor`]. See
[`ViTFeatureExtractor.__call__`] for details.
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
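
The `head_mask` argument documented above is not exercised in the examples below; a small sketch of how a mask of the documented `(num_layers, num_heads)` shape could be passed (which heads to silence is arbitrary here):

```python
import requests
import torch
from PIL import Image

from transformers import AutoFeatureExtractor, ViTMAEModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
inputs = feature_extractor(images=image, return_tensors="pt")

# one row per layer, one column per head; 1.0 keeps a head, 0.0 nullifies it
head_mask = torch.ones(model.config.num_hidden_layers, model.config.num_attention_heads)
head_mask[0, :4] = 0.0  # e.g. silence the first four heads of the first layer
outputs = model(**inputs, head_mask=head_mask)
```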
@@ -667,14 +667,14 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
Examples:
```python
>>> from transformers import ViTFeatureExtractor, ViTMAEModel
>>> from transformers import AutoFeatureExtractor, ViTMAEModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
>>> model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
@@ -909,19 +909,21 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
Examples:
```python
>>> from transformers import ViTFeatureExtractor, ViTMAEModel
>>> from transformers import AutoFeatureExtractor, ViTMAEForPreTraining
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base")
>>> model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
>>> model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> loss = outputs.loss
>>> mask = outputs.mask
>>> ids_restore = outputs.ids_restore
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
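
The updated [`ViTMAEForPreTraining`] example ends by pulling out `loss`, `mask` and `ids_restore` (the decoder uses `ids_restore` internally to put the shuffled patches back in their original order). A short sketch of how those outputs can be inspected, continuing from that example; it assumes the model exposes the `unpatchify` helper from the original MAE code and that the checkpoint reconstructs raw pixel values, as in the visualization notebook linked above:

```python
# continuing from the ViTMAEForPreTraining example: `model`, `inputs`, `outputs` already exist

# roughly 75% of the patches are masked by default (config.mask_ratio)
masking_ratio = outputs.mask.sum() / outputs.mask.numel()
print(f"masked patches: {masking_ratio:.2%}")

# fold the per-patch pixel predictions back into an image-shaped tensor
reconstruction = model.unpatchify(outputs.logits)  # (batch_size, 3, height, width)

# broadcast the per-patch mask to pixel space: 1.0 where a patch was masked, 0.0 where it was visible
patch_dim = model.config.patch_size**2 * 3  # 3 RGB channels per pixel
pixel_mask = model.unpatchify(outputs.mask.unsqueeze(-1).repeat(1, 1, patch_dim))

# paste the predictions into the masked regions of the original image for visualization
visualization = inputs["pixel_values"] * (1 - pixel_mask) + reconstruction * pixel_mask
```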