add sdpa to ViT [follow up of #29325] (#30555)

remove blank line (+1 squashed commit)
Squashed commits:
[24ccd2061] [run-slow]vit_msn,vision_encoder_decoder (+24 squashed commits)
Squashed commits:
[08bd27e7a] [run-slow]vit_msn,vision_encoder_decoder
[ec96a8db3] [run-slow]vit_msn
[ead817eca] fix vit msn multi gpu
[d12cdc8fd] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[3fdbfa88f] doc
[a3ff33e4a] finish implementation
[e20b7b7fb] Update test_modeling_common.py
[e290c5810] Update test_modeling_flax_common.py
[d3af86f46] comment
[ff7dd32d8] more comments
[59b137889] suggestion
[7e2ba6d67] attn_implementation as attribute of the class
[fe66ab71f] minor
[38642b568] Apply suggestions from code review

Accept comments

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[22cde7d52] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[48e137cc6] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[99f4c679f] Update tests/test_modeling_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[96cf20a6d] Update src/transformers/models/vit_msn/modeling_vit_msn.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[c59377d23] Update src/transformers/models/vit_mae/modeling_vit_mae.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[b70a47259] Update tests/models/vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
[00c84d216] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[61f00ebb0] all tests are passing locally
[e9e0b82b7] vision encoder/decoder
[4d5076b56] test-vision (+20 squashed commits)
Squashed commits:
[d1add8db9] yolo
[9fde65716] fix flax
[986566c28] minor
[ca2f21d1f] vit
[3333efd7a] easy models change
[ebfc21402] [run-slow]audio_spectrogram_transformer,deit,vision_encoder_decoder,vision_text_dual_encoder,vit,vit_hybrid,vit_mae,vit_msn,videomae,yolos
[b8b8603ed] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[48ecc7e26] all tests are passing locally
[bff7fc366] minor
[62f88306f] fix yolo and text_encoder tests
[121507555] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1064cae0a] [run-slow]vision_encoder_decoder,vision_text_dual_encoder,yolos
[b7f52ff3a] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[cffaa10dd] fix-copies
[ef6c511c4] test vit hybrid
[7d4ba8644] vit hybrid
[66f919033] [run-slow]audio_spectrogram_transformer,deit,vit,vit_hybrid,vit_mae,vit_msn,videomae
[1fcc0a031] fixes
[cfde6eb21] fixup
[e77df1ed3] all except yolo end encoder decoder (+17 squashed commits)
Squashed commits:
[602913e22] vit + vit_mae are working
[547f6c4cc] RUN_SLOW=1 pytest tests/models/audio_spectrogram_transformer/ tests/models/deit/ tests/models/videomae/  passes
[61a97dfa9] it s the complete opposite...
[aefab37d4] fix more tests
[71802a1b9] fix all torch tests
[40b12eb58] encoder - decoder tests
[941552b69] slow decorator where appropriate
[14d055d80] has_attentions to yolo and msn
[3381fa19f] add correct name
[e261316a7] repo consistency
[31c6d0c08] fixup
[9d214276c] minor fix
[11ed2e1b7] chore
[eca6644c4] add sdpa to vit-based models
[cffbf390b] make fix-copies result
[6468319b0] fix style
[d324cd02a] add sdpa for vit

Co-authored-by: Liubov Yaronskaya <luba.yaronskaya@gmail.com>
hyenal 2024-05-16 10:56:11 +01:00 committed by GitHub
parent 9fd606dbdb
commit 1c21f48a50
34 changed files with 709 additions and 26 deletions


@ -43,6 +43,34 @@ the authors compute the stats for a downstream dataset.
- Note that the AST needs a low learning rate (the authors use a 10 times smaller learning rate compared to their CNN model proposed in the
[PSLA paper](https://arxiv.org/abs/2102.01243)) and converges quickly, so please search for a suitable learning rate and learning rate scheduler for your task.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
import torch
from transformers import ASTForAudioClassification
model = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MIT/ast-finetuned-audioset-10-10-0.4593` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 27 | 6 | 4.5 |
| 2 | 12 | 6 | 2 |
| 4 | 21 | 8 | 2.62 |
| 8 | 40 | 14 | 2.86 |
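Putting the two notes above together (half precision plus SDPA), here is a minimal end-to-end sketch; the CUDA device and the random log-mel spectrogram shaped like the default AST inputs are assumptions for illustration, not part of this diff:
```
import torch
from transformers import ASTForAudioClassification

model = ASTForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to("cuda")
model.eval()

# Random stand-in for extracted features: (batch, max_length, num_mel_bins)
input_values = torch.randn(1, 1024, 128, dtype=torch.float16, device="cuda")
with torch.no_grad():
    logits = model(input_values=input_values).logits
print(logits.shape)  # (1, num_labels)
```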
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with the Audio Spectrogram Transformer.


@ -68,6 +68,34 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The Tenso
*facebook/deit-base-patch16-384*. Note that one should use [`DeiTImageProcessor`] in order to
prepare images for the model.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
import torch
from transformers import DeiTForImageClassification
model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/deit-base-distilled-patch16-224` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 8 | 6 | 1.33 |
| 2 | 9 | 6 | 1.5 |
| 4 | 9 | 6 | 1.5 |
| 8 | 8 | 6 | 1.33 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with DeiT.


@ -33,6 +33,34 @@ alt="drawing" width="600"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/MCG-NJU/VideoMAE).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
import torch
from transformers import VideoMAEForVideoClassification
model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `MCG-NJU/videomae-base-finetuned-kinetics` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 37 | 10 | 3.7 |
| 2 | 24 | 18 | 1.33 |
| 4 | 43 | 32 | 1.34 |
| 8 | 84 | 60 | 1.4 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with VideoMAE. If


@ -88,6 +88,34 @@ who already converted the weights from JAX to PyTorch. Credits go to him!
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
improvement of 2% over training from scratch, but still 4% behind supervised pre-training.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
import torch
from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-base-patch16-224` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
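The table above comes from a local benchmark. A hedged sketch of how such an eager-vs-SDPA timing comparison could be reproduced follows; the harness, warmup counts, and CUDA device are assumptions, not the script used for these numbers:
```
import time
import torch
from transformers import ViTForImageClassification

def benchmark(attn_implementation, batch_size=8, steps=50):
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",
        attn_implementation=attn_implementation,
        torch_dtype=torch.float16,
    ).to("cuda").eval()
    pixel_values = torch.randn(batch_size, 3, 224, 224, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        for _ in range(10):  # warmup
            model(pixel_values=pixel_values)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(steps):
            model(pixel_values=pixel_values)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / steps * 1000  # ms per forward pass

print("eager:", benchmark("eager"), "ms")
print("sdpa :", benchmark("sdpa"), "ms")
```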
## Resources
Demo notebooks regarding inference as well as fine-tuning ViT on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/VisionTransformer).


@ -39,6 +39,34 @@ substantially fewer computational resources to train.*
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code (written in JAX) can be
found [here](https://github.com/google-research/vision_transformer).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
import torch
from transformers import ViTHybridForImageClassification
model = ViTHybridForImageClassification.from_pretrained("google/vit-hybrid-base-bit-384", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vit-hybrid-base-bit-384` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 29 | 18 | 1.61 |
| 2 | 26 | 18 | 1.44 |
| 4 | 25 | 18 | 1.39 |
| 8 | 34 | 24 | 1.42 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT Hybrid.


@ -52,6 +52,34 @@ consists of Transformer blocks) takes as input. Each mask token is a shared, lea
sin/cos position embeddings are added both to the input of the encoder and the decoder.
- For a visual understanding of how MAEs work you can check out this [post](https://keras.io/examples/vision/masked_image_modeling/).
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
import torch
from transformers import ViTMAEModel
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-mae-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 11 | 6 | 1.83 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViTMAE.


@ -49,6 +49,34 @@ use the [`ViTMSNForImageClassification`] class which is initialized from [`ViTMS
- MSN is particularly useful in the low-shot and extreme low-shot regimes. Notably, it achieves 75.7% top-1 accuracy with only 1% of ImageNet-1K
labels when fine-tuned.
### Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
import torch
from transformers import ViTMSNForImageClassification
model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `facebook/vit-msn-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 7 | 6 | 1.17 |
| 2 | 8 | 6 | 1.33 |
| 4 | 8 | 6 | 1.33 |
| 8 | 8 | 6 | 1.33 |
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ViT MSN.


@ -32,6 +32,34 @@ alt="drawing" width="600"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/hustvl/YOLOS).
## Using Scaled Dot Product Attention (SDPA)
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
import torch
from transformers import AutoModelForObjectDetection
model = AutoModelForObjectDetection.from_pretrained("hustvl/yolos-base", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `hustvl/yolos-base` model, we saw the following speedups during inference.
| Batch size | Average inference time (ms), eager mode | Average inference time (ms), sdpa model | Speed up, Sdpa / Eager (x) |
|--------------|-------------------------------------------|-------------------------------------------|------------------------------|
| 1 | 106 | 76 | 1.39 |
| 2 | 154 | 90 | 1.71 |
| 4 | 222 | 116 | 1.91 |
| 8 | 368 | 168 | 2.19 |
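For a complete detection round trip, a hedged sketch pairing the SDPA-enabled checkpoint with its image processor; the image URL, CUDA device, and threshold are illustrative choices:
```
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("hustvl/yolos-base")
model = AutoModelForObjectDetection.from_pretrained(
    "hustvl/yolos-base", attn_implementation="sdpa", torch_dtype=torch.float16
).to("cuda")

inputs = processor(images=image, return_tensors="pt").to("cuda", torch.float16)
with torch.no_grad():
    outputs = model(**inputs)

# Convert raw logits/boxes into labelled boxes in image coordinates.
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes)[0]
print(results["labels"], results["scores"])
```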
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with YOLOS.


@ -192,10 +192,12 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
For now, Transformers supports SDPA inference and training for the following architectures:
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
@ -216,12 +218,18 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
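A minimal sketch of opting in for one of the vision models listed above (the checkpoint is chosen for illustration); the resolved backend can be read back from the model config:
```
import torch
from transformers import VideoMAEModel

model = VideoMAEModel.from_pretrained(
    "MCG-NJU/videomae-base",
    attn_implementation="sdpa",  # or "eager" to force the original attention
    torch_dtype=torch.float16,
)
print(model.config._attn_implementation)  # "sdpa"
```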
<Tip>


@ -169,6 +169,38 @@ class ASTSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->AST
class ASTSdpaSelfAttention(ASTSelfAttention):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->AST
class ASTSelfOutput(nn.Module):
"""
@ -228,6 +260,13 @@ class ASTAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->AST
class ASTSdpaAttention(ASTAttention):
def __init__(self, config: ASTConfig) -> None:
super().__init__(config)
self.attention = ASTSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->AST
class ASTIntermediate(nn.Module):
def __init__(self, config: ASTConfig) -> None:
@ -261,7 +300,13 @@ class ASTOutput(nn.Module):
return hidden_states
AST_ATTENTION_CLASSES = {
"eager": ASTAttention,
"sdpa": ASTSdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST
class ASTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -269,7 +314,7 @@ class ASTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = AST_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ASTIntermediate(config)
self.output = ASTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -366,6 +411,7 @@ class ASTPreTrainedModel(PreTrainedModel):
base_model_prefix = "audio_spectrogram_transformer"
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
# Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:


@ -190,6 +190,38 @@ class DeiTSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->DeiT
class DeiTSdpaSelfAttention(DeiTSelfAttention):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
class DeiTSelfOutput(nn.Module):
"""
@ -249,6 +281,13 @@ class DeiTAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->DeiT
class DeiTSdpaAttention(DeiTAttention):
def __init__(self, config: DeiTConfig) -> None:
super().__init__(config)
self.attention = DeiTSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
class DeiTIntermediate(nn.Module):
def __init__(self, config: DeiTConfig) -> None:
@ -282,7 +321,13 @@ class DeiTOutput(nn.Module):
return hidden_states
DEIT_ATTENTION_CLASSES = {
"eager": DeiTAttention,
"sdpa": DeiTSdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
class DeiTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -290,7 +335,7 @@ class DeiTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = DEIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = DeiTIntermediate(config)
self.output = DeiTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -388,6 +433,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["DeiTLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

src/transformers/models/videomae/modeling_videomae.py (Normal file → Executable file)

@ -134,7 +134,6 @@ class VideoMAEEmbeddings(nn.Module):
# add position embeddings
embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()
# only keep visible patches # only keep visible patches
# ~bool_masked_pos means visible
if bool_masked_pos is not None:
@ -268,6 +267,40 @@ class VideoMAESelfAttention(nn.Module):
return outputs
class VideoMAESdpaSelfAttention(VideoMAESelfAttention):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)
key_layer = self.transpose_for_scores(keys)
value_layer = self.transpose_for_scores(values)
query_layer = self.transpose_for_scores(queries)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
class VideoMAESelfOutput(nn.Module):
"""
@ -327,6 +360,13 @@ class VideoMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE
class VideoMAESdpaAttention(VideoMAEAttention):
def __init__(self, config: VideoMAEConfig) -> None:
super().__init__(config)
self.attention = VideoMAESdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
class VideoMAEIntermediate(nn.Module):
def __init__(self, config: VideoMAEConfig) -> None:
@ -360,7 +400,10 @@ class VideoMAEOutput(nn.Module):
return hidden_states
VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
class VideoMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -368,7 +411,7 @@ class VideoMAELayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = VideoMAEIntermediate(config)
self.output = VideoMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -465,6 +508,7 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
base_model_prefix = "videomae"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""


@ -336,8 +336,20 @@ class VisionEncoderDecoderModel(PreTrainedModel):
del tf_model
gc.collect()
attn_implementation = kwargs.get("attn_implementation", None)
kwargs_encoder_decoder = {}
if attn_implementation:
kwargs_encoder_decoder = {
"encoder_attn_implementation": attn_implementation,
"decoder_attn_implementation": attn_implementation,
}
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_dir,
decoder_dir,
encoder_from_tf=True,
decoder_from_tf=True,
**kwargs_encoder_decoder,
)
# This is only for copying some specific attributes of this particular model.
model.config = config
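For context, `from_encoder_decoder_pretrained` forwards kwargs prefixed with `encoder_`/`decoder_` to the corresponding `from_pretrained` calls, which is what the conversion path above relies on. A hedged sketch with illustrative checkpoints:
```
from transformers import VisionEncoderDecoderModel

# Checkpoints chosen only for illustration; any ViT-style encoder and causal LM decoder work the same way.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",
    "openai-community/gpt2",
    encoder_attn_implementation="sdpa",
    decoder_attn_implementation="eager",
)
print(model.encoder.config._attn_implementation)  # "sdpa"
```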


@ -236,6 +236,37 @@ class ViTSelfAttention(nn.Module):
return outputs
class ViTSdpaSelfAttention(ViTSelfAttention):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
class ViTSelfOutput(nn.Module):
"""
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
@ -293,6 +324,12 @@ class ViTAttention(nn.Module):
return outputs
class ViTSdpaAttention(ViTAttention):
def __init__(self, config: ViTConfig) -> None:
super().__init__(config)
self.attention = ViTSdpaSelfAttention(config)
class ViTIntermediate(nn.Module):
def __init__(self, config: ViTConfig) -> None:
super().__init__()
@ -324,6 +361,12 @@ class ViTOutput(nn.Module):
return hidden_states
VIT_ATTENTION_CLASSES = {
"eager": ViTAttention,
"sdpa": ViTSdpaAttention,
}
class ViTLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -331,7 +374,7 @@ class ViTLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTIntermediate(config)
self.output = ViTOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -428,6 +471,7 @@ class ViTPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""


@ -248,6 +248,38 @@ class ViTHybridSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTHybrid
class ViTHybridSdpaSelfAttention(ViTHybridSelfAttention):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTHybrid
class ViTHybridSelfOutput(nn.Module):
"""
@ -307,6 +339,13 @@ class ViTHybridAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTHybrid
class ViTHybridSdpaAttention(ViTHybridAttention):
def __init__(self, config: ViTHybridConfig) -> None:
super().__init__(config)
self.attention = ViTHybridSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTHybrid
class ViTHybridIntermediate(nn.Module):
def __init__(self, config: ViTHybridConfig) -> None:
@ -340,6 +379,12 @@ class ViTHybridOutput(nn.Module):
return hidden_states
VIT_HYBRID_ATTENTION_CLASSES = {
"eager": ViTHybridAttention,
"sdpa": ViTHybridSdpaAttention,
}
class ViTHybridLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -347,7 +392,7 @@ class ViTHybridLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VIT_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTHybridIntermediate(config)
self.output = ViTHybridOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -447,6 +492,7 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTHybridEmbeddings", "ViTHybridLayer"]
_supports_sdpa = True
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""


@ -241,8 +241,8 @@ class ViTMAEEmbeddings(nn.Module):
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
@ -370,6 +370,38 @@ class ViTMAESelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention ViT->ViTMAE
class ViTMAESdpaSelfAttention(ViTMAESelfAttention):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
"""
@ -429,6 +461,13 @@ class ViTMAEAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMAE
class ViTMAESdpaAttention(ViTMAEAttention):
def __init__(self, config: ViTMAEConfig) -> None:
super().__init__(config)
self.attention = ViTMAESdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
def __init__(self, config: ViTMAEConfig) -> None:
@ -462,7 +501,13 @@ class ViTMAEOutput(nn.Module):
return hidden_states
VITMAE_ATTENTION_CLASSES = {
"eager": ViTMAEAttention,
"sdpa": ViTMAESdpaAttention,
}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
class ViTMAELayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -470,7 +515,7 @@ class ViTMAELayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VITMAE_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTMAEIntermediate(config)
self.output = ViTMAEOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -567,6 +612,7 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -764,7 +810,8 @@ class ViTMAEDecoder(nn.Module):
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
# unshuffle
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed


@ -222,6 +222,38 @@ class ViTMSNSelfAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTMSN
class ViTMSNSdpaSelfAttention(ViTMSNSelfAttention):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)
return context_layer, None
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN
class ViTMSNSelfOutput(nn.Module):
"""
@ -281,6 +313,13 @@ class ViTMSNAttention(nn.Module):
return outputs
# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTMSN
class ViTMSNSdpaAttention(ViTMSNAttention):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__(config)
self.attention = ViTMSNSdpaSelfAttention(config)
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN
class ViTMSNIntermediate(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
@ -314,7 +353,10 @@ class ViTMSNOutput(nn.Module):
return hidden_states
VITMSN_ATTENTION_CLASSES = {"eager": ViTMSNAttention, "sdpa": ViTMSNSdpaAttention}
# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN
class ViTMSNLayer(nn.Module):
"""This corresponds to the Block class in the timm implementation."""
@ -322,7 +364,7 @@ class ViTMSNLayer(nn.Module):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTMSNIntermediate(config)
self.output = ViTMSNOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@ -419,7 +461,8 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit" base_model_prefix = "vit"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["ViTMSNAttention"] _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"]
_supports_sdpa = True
# todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
# when creating pre-training scripts. # when creating pre-training scripts.
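With `_supports_sdpa = True` on the pretrained base class, the implementation can be requested at load time. A minimal usage sketch; the checkpoint id below is an assumption, and any ViT-MSN checkpoint (or a randomly initialized config) behaves the same way:

    import torch
    from transformers import ViTMSNModel

    # "sdpa" routes every ViTMSNLayer through ViTMSNSdpaAttention; "eager" keeps the original path.
    model = ViTMSNModel.from_pretrained(
        "facebook/vit-msn-small",  # assumed checkpoint id; substitute any ViT-MSN checkpoint
        attn_implementation="sdpa",
    ).eval()

    pixel_values = torch.rand(1, 3, 224, 224)  # stand-in for a processed image
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    print(outputs.last_hidden_state.shape)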


@@ -307,6 +307,38 @@ class YolosSelfAttention(nn.Module):
         return outputs


+# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Yolos
+class YolosSdpaSelfAttention(YolosSelfAttention):
+    def __init__(self, config: YolosConfig) -> None:
+        super().__init__(config)
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=None,
+        )
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None
 # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Yolos
 class YolosSelfOutput(nn.Module):
     """
@@ -366,6 +398,13 @@ class YolosAttention(nn.Module):
         return outputs


+# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Yolos
+class YolosSdpaAttention(YolosAttention):
+    def __init__(self, config: YolosConfig) -> None:
+        super().__init__(config)
+        self.attention = YolosSdpaSelfAttention(config)
+
+
 # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Yolos
 class YolosIntermediate(nn.Module):
     def __init__(self, config: YolosConfig) -> None:
@@ -399,7 +438,10 @@ class YolosOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos
+YOLOS_ATTENTION_CLASSES = {"eager": YolosAttention, "sdpa": YolosSdpaAttention}
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS
 class YolosLayer(nn.Module):
     """This corresponds to the Block class in the timm implementation."""
@@ -407,7 +449,7 @@ class YolosLayer(nn.Module):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = YolosAttention(config)
+        self.attention = YOLOS_ATTENTION_CLASSES[config._attn_implementation](config)
         self.intermediate = YolosIntermediate(config)
         self.output = YolosOutput(config)
         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -531,6 +573,7 @@ class YolosPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = []
+    _supports_sdpa = True

     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""


@@ -63,6 +63,7 @@ class ASTModelTester:
         scope=None,
         frequency_stride=2,
         time_stride=2,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -83,6 +84,7 @@ class ASTModelTester:
         self.scope = scope
         self.frequency_stride = frequency_stride
         self.time_stride = time_stride
+        self.attn_implementation = attn_implementation

         # in AST, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
         frequency_out_dimension = (self.num_mel_bins - self.patch_size) // self.frequency_stride + 1
@@ -117,6 +119,7 @@ class ASTModelTester:
             initializer_range=self.initializer_range,
             frequency_stride=self.frequency_stride,
             time_stride=self.time_stride,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, input_values, labels):


@@ -80,6 +80,8 @@ class DeiTModelTester:
         num_labels=3,
         scope=None,
         encoder_stride=2,
+        mask_ratio=0.5,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -99,10 +101,14 @@ class DeiTModelTester:
         self.initializer_range = initializer_range
         self.scope = scope
         self.encoder_stride = encoder_stride
+        self.attn_implementation = attn_implementation

         # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
         num_patches = (image_size // patch_size) ** 2
         self.seq_length = num_patches + 2
+        self.mask_ratio = mask_ratio
+        self.num_masks = int(mask_ratio * self.seq_length)
+        self.mask_length = num_patches

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -130,6 +136,7 @@ class DeiTModelTester:
             is_decoder=False,
             initializer_range=self.initializer_range,
             encoder_stride=self.encoder_stride,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -71,6 +71,7 @@ class TFDeiTModelTester:
         num_labels=3,
         scope=None,
         encoder_stride=2,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -90,6 +91,7 @@ class TFDeiTModelTester:
         self.initializer_range = initializer_range
         self.scope = scope
         self.encoder_stride = encoder_stride
+        self.attn_implementation = attn_implementation

         # in DeiT, the seq length equals the number of patches + 2 (we add 2 for the [CLS] and distilation tokens)
         num_patches = (image_size // patch_size) ** 2
@@ -121,6 +123,7 @@ class TFDeiTModelTester:
             is_decoder=False,
             initializer_range=self.initializer_range,
             encoder_stride=self.encoder_stride,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -70,6 +70,7 @@ class VideoMAEModelTester:
         initializer_range=0.02,
         mask_ratio=0.9,
         scope=None,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -91,6 +92,7 @@ class VideoMAEModelTester:
         self.initializer_range = initializer_range
         self.mask_ratio = mask_ratio
         self.scope = scope
+        self.attn_implementation = attn_implementation

         # in VideoMAE, the number of tokens equals num_frames/tubelet_size * num_patches per frame
         self.num_patches_per_frame = (image_size // patch_size) ** 2
@@ -132,6 +134,7 @@ class VideoMAEModelTester:
             decoder_intermediate_size=self.intermediate_size,
             decoder_num_attention_heads=self.num_attention_heads,
             decoder_num_hidden_layers=self.num_hidden_layers,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):
@@ -197,7 +200,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
             # hence we define a single mask, which we then repeat for each example in the batch
             mask = torch.ones((self.model_tester.num_masks,))
             mask = torch.cat([mask, torch.zeros(self.model_tester.seq_length - mask.size(0))])
-            bool_masked_pos = mask.expand(self.model_tester.batch_size, -1).bool()
+            batch_size = inputs_dict["pixel_values"].shape[0]
+            bool_masked_pos = mask.expand(batch_size, -1).bool()
             inputs_dict["bool_masked_pos"] = bool_masked_pos.to(torch_device)

         if return_labels:


@@ -492,7 +492,9 @@ class TFVisionEncoderDecoderMixin:
         with tempfile.TemporaryDirectory() as tmpdirname:
             tf_model.save_pretrained(tmpdirname, safe_serialization=False)
-            pt_model = VisionEncoderDecoderModel.from_pretrained(tmpdirname, from_tf=True)
+            pt_model = VisionEncoderDecoderModel.from_pretrained(
+                tmpdirname, from_tf=True, attn_implementation=tf_model.config._attn_implementation
+            )

         self.check_pt_tf_equivalence(tf_model, pt_model, tf_inputs_dict)


@@ -49,6 +49,7 @@ class FlaxViTModelTester(unittest.TestCase):
         attention_probs_dropout_prob=0.1,
         type_sequence_label_size=10,
         initializer_range=0.02,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -66,6 +67,7 @@ class FlaxViTModelTester(unittest.TestCase):
         self.attention_probs_dropout_prob = attention_probs_dropout_prob
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
+        self.attn_implementation = attn_implementation

         # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
         num_patches = (image_size // patch_size) ** 2
@@ -87,6 +89,7 @@ class FlaxViTModelTester(unittest.TestCase):
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            attn_implementation=self.attn_implementation,
         )

         return config, pixel_values


@@ -63,6 +63,7 @@ class TFViTModelTester:
         initializer_range=0.02,
         num_labels=3,
         scope=None,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -81,6 +82,7 @@ class TFViTModelTester:
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.scope = scope
+        self.attn_implementation = attn_implementation

         # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
         num_patches = (image_size // patch_size) ** 2
@@ -111,6 +113,7 @@ class TFViTModelTester:
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -68,6 +68,8 @@ class ViTModelTester:
         initializer_range=0.02,
         scope=None,
         encoder_stride=2,
+        mask_ratio=0.5,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -87,10 +89,14 @@ class ViTModelTester:
         self.initializer_range = initializer_range
         self.scope = scope
         self.encoder_stride = encoder_stride
+        self.attn_implementation = attn_implementation

         # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
         num_patches = (image_size // patch_size) ** 2
         self.seq_length = num_patches + 1
+        self.mask_ratio = mask_ratio
+        self.num_masks = int(mask_ratio * self.seq_length)
+        self.mask_length = num_patches

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -118,6 +124,7 @@ class ViTModelTester:
             is_decoder=False,
             initializer_range=self.initializer_range,
             encoder_stride=self.encoder_stride,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -58,6 +58,7 @@ class ViTHybridModelTester:
         initializer_range=0.02,
         backbone_featmap_shape=[1, 16, 4, 4],
         scope=None,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -77,6 +78,7 @@ class ViTHybridModelTester:
         self.initializer_range = initializer_range
         self.scope = scope
         self.backbone_featmap_shape = backbone_featmap_shape
+        self.attn_implementation = attn_implementation

         # in ViT hybrid, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
         # the number of patches is based on the feature map of the backbone, which by default uses an output stride
@@ -122,6 +124,7 @@ class ViTHybridModelTester:
             backbone_featmap_shape=self.backbone_featmap_shape,
             backbone_config=backbone_config,
             backbone=None,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -72,6 +72,7 @@ class TFViTMAEModelTester:
         num_labels=3,
         mask_ratio=0.6,
         scope=None,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -91,6 +92,7 @@ class TFViTMAEModelTester:
         self.initializer_range = initializer_range
         self.mask_ratio = mask_ratio
         self.scope = scope
+        self.attn_implementation = attn_implementation

         # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
         # (we add 1 for the [CLS] token)
@@ -127,6 +129,7 @@ class TFViTMAEModelTester:
             is_decoder=False,
             initializer_range=self.initializer_range,
             mask_ratio=self.mask_ratio,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -63,8 +63,9 @@ class ViTMAEModelTester:
         type_sequence_label_size=10,
         initializer_range=0.02,
         num_labels=3,
-        mask_ratio=0.6,
         scope=None,
+        mask_ratio=0.5,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -84,11 +85,15 @@ class ViTMAEModelTester:
         self.initializer_range = initializer_range
         self.mask_ratio = mask_ratio
         self.scope = scope
+        self.attn_implementation = attn_implementation

         # in ViTMAE, the expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
         # (we add 1 for the [CLS] token)
         num_patches = (image_size // patch_size) ** 2
         self.seq_length = int(math.ceil((1 - mask_ratio) * (num_patches + 1)))
+        self.mask_ratio = mask_ratio
+        self.num_masks = int(mask_ratio * self.seq_length)
+        self.mask_length = num_patches

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -120,6 +125,7 @@ class ViTMAEModelTester:
             decoder_intermediate_size=self.intermediate_size,
             decoder_num_attention_heads=self.num_attention_heads,
             decoder_num_hidden_layers=self.num_hidden_layers,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -59,6 +59,7 @@ class ViTMSNModelTester:
         type_sequence_label_size=10,
         initializer_range=0.02,
         scope=None,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -77,6 +78,7 @@ class ViTMSNModelTester:
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.scope = scope
+        self.attn_implementation = attn_implementation

         # in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
         num_patches = (image_size // patch_size) ** 2
@@ -106,6 +108,7 @@ class ViTMSNModelTester:
             hidden_dropout_prob=self.hidden_dropout_prob,
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             initializer_range=self.initializer_range,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -62,6 +62,7 @@ class YolosModelTester:
         scope=None,
         n_targets=8,
         num_detection_tokens=10,
+        attn_implementation="eager",
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -83,6 +84,7 @@ class YolosModelTester:
         self.scope = scope
         self.n_targets = n_targets
         self.num_detection_tokens = num_detection_tokens
+        self.attn_implementation = attn_implementation

         # we set the expected sequence length (which is used in several tests)
         # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + num_detection_tokens
         num_patches = (image_size[1] // patch_size) * (image_size[0] // patch_size)
@@ -123,6 +125,7 @@ class YolosModelTester:
             initializer_range=self.initializer_range,
             num_detection_tokens=self.num_detection_tokens,
             num_labels=self.num_labels,
+            attn_implementation=self.attn_implementation,
         )

     def create_and_check_model(self, config, pixel_values, labels):


@@ -2788,7 +2788,9 @@ class ModelTesterMixin:
                 with tempfile.TemporaryDirectory() as tmpdirname:
                     fx_model.save_pretrained(tmpdirname)
-                    pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
+                    pt_model_loaded = model_class.from_pretrained(
+                        tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
+                    )

                 # send pytorch model to the correct device
                 pt_model_loaded.to(torch_device)
@@ -3724,6 +3726,11 @@ class ModelTesterMixin:
         for model_class in self.all_model_classes:
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config)
+            # FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
+            # These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
+            # This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
+            # However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
+            deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters

             is_encoder_decoder = model.config.is_encoder_decoder
@@ -3861,6 +3868,27 @@ class ModelTesterMixin:
                                 and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
                             ):
                                 processed_inputs["output_attentions"] = output_attentions
+                            if not deactivate_mask and (
+                                "bool_masked_pos" in inspect.signature(model_eager.forward).parameters
+                            ):
+                                dummy_mask = torch.ones((self.model_tester.num_masks,))
+
+                                # In case of additional token (like class) we define a custom `mask_length`
+                                if hasattr(self.model_tester, "mask_length"):
+                                    mask_length = self.model_tester.mask_length - dummy_mask.size(0)
+                                else:
+                                    mask_length = self.model_tester.seq_length - dummy_mask.size(0)
+                                dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
+                                dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
+                                processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
+
+                            if "noise" in inspect.signature(model_eager.forward).parameters:
+                                np.random.seed(2)
+                                num_patches = int(
+                                    (self.model_tester.image_size // self.model_tester.patch_size) ** 2
+                                )
+                                noise = np.random.uniform(size=(batch_size, num_patches))
+                                processed_inputs["noise"] = torch.from_numpy(noise)

                             # TODO: test gradients as well (& for FA2 as well!)
                             with torch.no_grad():
View File

@ -371,7 +371,9 @@ class FlaxModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname) fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) pt_model_loaded = pt_model_class.from_pretrained(
tmpdirname, from_flax=True, attn_implementation=fx_model.config._attn_implementation
)
# send pytorch model to the correct device # send pytorch model to the correct device
pt_model_loaded.to(torch_device) pt_model_loaded.to(torch_device)


@@ -84,7 +84,7 @@ def check_sdpa_support_list():
             archs_supporting_sdpa.append(model_name)

     for arch in archs_supporting_sdpa:
-        if arch not in doctext:
+        if arch not in doctext and arch not in doctext.replace("-", "_"):
             raise ValueError(
                 f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
             )
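The extra `doctext.replace("-", "_")` clause is needed because model folders use underscores (for example `vit_msn`, `vit_mae`) while the documentation page may spell the same architecture with hyphens. A toy reproduction of that normalization; the architecture list and doc text here are made up for illustration:

    # Doc pages may spell an architecture with hyphens while the code uses underscores.
    archs_supporting_sdpa = ["vit", "vit_msn", "yolos"]        # illustrative subset, not the real list
    doctext = "supported architectures: vit, vit-msn, yolos"   # hypothetical documentation text

    for arch in archs_supporting_sdpa:
        if arch not in doctext and arch not in doctext.replace("-", "_"):
            raise ValueError(f"{arch} should be listed in the SDPA documentation but is not.")
    print("all architectures found")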