mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add FA2 and sdpa
support for SigLIP (#31499)
* Rebase to main * Fix attention implementation autoset for tex and vision configs * Fixup * Minor fixes * Fix copies * Fix attention_mask for FA2 * Add eqvivalence tests for siglip * Remove right padding test * Uncomment flaky * Fix import * Add to docs * Fix test message * Add sdpa * Add sdpa equivalence test * Add siglip sdpa to docs * Fix typing for attention output * Add sdpa tests * Fix signature of FA2 * Autoset attn_implementation in config * Rename bsz -> batch_size * Move back autoset attn method * Mark as flaky * Correct attention mask padding * [run-slow] siglip * Add FA2 and sdpa docs * Style fix * Remove flaky for FA2 test * Change attention implementation set * Change attn_implementaiton propogation * Fix typos * Add modality to assert message * Add more sdpa backends in test * [run slow] siglip * Add math sdpa backend for all options * [run slow] siglip
This commit is contained in:
parent
076e66e479
commit
a177821b24
@ -107,6 +107,88 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
|
||||
|
||||
## Combining SigLIP and Flash Attention 2
|
||||
|
||||
First, make sure to install the latest version of Flash Attention 2.
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``)
|
||||
|
||||
To load and run a model using Flash Attention 2, refer to the snippet below:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
>>> from transformers import SiglipProcessor, SiglipModel
|
||||
>>> device = "cuda" # the device to load the model onto
|
||||
|
||||
>>> model = SiglipModel.from_pretrained(
|
||||
... "google/siglip-so400m-patch14-384",
|
||||
... attn_implementation="flash_attention_2",
|
||||
... torch_dtype=torch.float16,
|
||||
... device_map=device,
|
||||
... )
|
||||
>>> processor = SiglipProcessor.from_pretrained("google/siglip-so400m-patch14-384")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> candidate_labels = ["2 cats", "2 dogs"]
|
||||
# follows the pipeline prompt template to get same results
|
||||
>>> candidate_labels = [f'This is a photo of {label}.' for label in candidate_labels]
|
||||
# important: we pass `padding=max_length` since the model was trained with this
|
||||
>>> inputs = processor(text=candidate_labels, images=image, padding="max_length", return_tensors="pt")
|
||||
>>> inputs.to(device)
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... with torch.autocast(device):
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> logits_per_image = outputs.logits_per_image
|
||||
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
|
||||
>>> print(f"{probs[0][0]:.1%} that image 0 is '{candidate_labels[0]}'")
|
||||
51.3% that image 0 is 'This is a photo of 2 cats.'
|
||||
```
|
||||
|
||||
|
||||
## Using Scaled Dot Product Attention (SDPA)
|
||||
|
||||
PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
|
||||
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
|
||||
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
|
||||
page for more information.
|
||||
|
||||
You may set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. Make sure you have `torch>=2.1.1`.
|
||||
|
||||
```python
|
||||
>>> from transformers import SiglipModel
|
||||
|
||||
>>> model = SiglipModel.from_pretrained(
|
||||
... "google/siglip-so400m-patch14-384",
|
||||
... attn_implementation="sdpa",
|
||||
... torch_dtype=torch.float16,
|
||||
... device_map=device,
|
||||
... )
|
||||
```
|
||||
|
||||
For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).
|
||||
|
||||
|
||||
## Expected speedups
|
||||
|
||||
Below is an expected speedup diagram that compares inference time between the native implementation in transformers using `google/siglip-so400m-patch14-384` checkpoint in `float16` precision and the Flash Attention 2 / SDPA version of the model using different batch sizes.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://i.imgur.com/cWm4rsn.png">
|
||||
</div>
|
||||
|
||||
|
||||
## SiglipConfig
|
||||
|
||||
[[autodoc]] SiglipConfig
|
||||
|
@ -70,6 +70,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
|
||||
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
|
||||
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
|
||||
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
|
||||
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
|
||||
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
|
||||
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
|
||||
@ -231,6 +232,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
|
||||
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
|
||||
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
|
||||
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
|
||||
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
|
||||
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
|
||||
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
|
||||
|
@ -221,7 +221,7 @@ class Idefics2VisionAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
@ -21,6 +21,7 @@ from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
@ -34,12 +35,19 @@ from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
@ -47,6 +55,19 @@ _CONFIG_FOR_DOC = "SiglipConfig"
|
||||
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
|
||||
|
||||
def _trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
@ -373,7 +394,7 @@ class SiglipAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
@ -421,6 +442,266 @@ class SiglipAttention(nn.Module):
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class SiglipFlashAttention2(SiglipAttention):
|
||||
"""
|
||||
SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
is_causal = False
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
output_attentions = False
|
||||
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32.
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
||||
def _flash_attention_forward(
|
||||
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
first unpad the input, then computes the attention scores and pad the final attention scores.
|
||||
|
||||
Args:
|
||||
query_states (`torch.Tensor`):
|
||||
Input query states to be passed to Flash Attention API
|
||||
key_states (`torch.Tensor`):
|
||||
Input key states to be passed to Flash Attention API
|
||||
value_states (`torch.Tensor`):
|
||||
Input value states to be passed to Flash Attention API
|
||||
attention_mask (`torch.Tensor`):
|
||||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
||||
position of padding tokens and 1 for the position of non-padding tokens.
|
||||
dropout (`float`):
|
||||
Attention dropout
|
||||
softmax_scale (`float`, *optional*):
|
||||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
||||
"""
|
||||
|
||||
if not self._flash_attn_uses_top_left_mask:
|
||||
causal = self.is_causal
|
||||
else:
|
||||
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
||||
causal = self.is_causal and query_length != 1
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
||||
query_states, key_states, value_states, attention_mask, query_length
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
||||
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
||||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
||||
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
||||
|
||||
key_layer = index_first_axis(
|
||||
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
value_layer = index_first_axis(
|
||||
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
||||
)
|
||||
if query_length == kv_seq_len:
|
||||
query_layer = index_first_axis(
|
||||
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
||||
)
|
||||
cu_seqlens_q = cu_seqlens_k
|
||||
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
||||
indices_q = indices_k
|
||||
elif query_length == 1:
|
||||
max_seqlen_in_batch_q = 1
|
||||
cu_seqlens_q = torch.arange(
|
||||
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
||||
) # There is a memcpy here, that is very bad.
|
||||
indices_q = cu_seqlens_q[:-1]
|
||||
query_layer = query_layer.squeeze(1)
|
||||
else:
|
||||
# The -q_len: slice assumes left padding.
|
||||
attention_mask = attention_mask[:, -query_length:]
|
||||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
||||
|
||||
return (
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
indices_q,
|
||||
(cu_seqlens_q, cu_seqlens_k),
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
||||
)
|
||||
|
||||
|
||||
class SiglipSdpaAttention(SiglipAttention):
|
||||
"""
|
||||
Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
is_causal = False
|
||||
|
||||
# Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and attention_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if self.is_causal and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=self.dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
SIGLIP_ATTENTION_CLASSES = {
|
||||
"eager": SiglipAttention,
|
||||
"flash_attention_2": SiglipFlashAttention2,
|
||||
"sdpa": SiglipSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
|
||||
class SiglipMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
@ -437,12 +718,11 @@ class SiglipMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
|
||||
class SiglipEncoderLayer(nn.Module):
|
||||
def __init__(self, config: SiglipConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = SiglipAttention(config)
|
||||
self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
@ -503,6 +783,8 @@ class SiglipPreTrainedModel(PreTrainedModel):
|
||||
"SiglipEncoderLayer",
|
||||
"SiglipMultiheadAttentionPoolingHead",
|
||||
]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -754,6 +1036,7 @@ class SiglipTextTransformer(nn.Module):
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.head = nn.Linear(embed_dim, embed_dim)
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
|
||||
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
|
||||
@ -786,7 +1069,7 @@ class SiglipTextTransformer(nn.Module):
|
||||
|
||||
# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
if attention_mask is not None and not self._use_flash_attention_2:
|
||||
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
|
||||
|
||||
@ -1041,8 +1324,13 @@ class SiglipModel(SiglipPreTrainedModel):
|
||||
text_config = config.text_config
|
||||
vision_config = config.vision_config
|
||||
|
||||
self.text_model = SiglipTextTransformer(text_config)
|
||||
self.vision_model = SiglipVisionTransformer(vision_config)
|
||||
# First, initialize the text and vision models with proper attention implementation
|
||||
text_model = SiglipTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
|
||||
vision_model = SiglipVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
|
||||
|
||||
# Second, get the text and vision submodules (for backward compatibility)
|
||||
self.text_model = text_model.text_model
|
||||
self.vision_model = vision_model.vision_model
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.randn(1))
|
||||
self.logit_bias = nn.Parameter(torch.randn(1))
|
||||
@ -1270,7 +1558,13 @@ class SiglipForImageClassification(SiglipPreTrainedModel):
|
||||
super().__init__(config)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.vision_model = SiglipVisionTransformer(config.vision_config)
|
||||
|
||||
# Create the vision model with proper attention
|
||||
# and take only vision_model submodule (for backward compatibility)
|
||||
vision_model = SiglipVisionModel._from_config(
|
||||
config.vision_config, attn_implementation=config._attn_implementation
|
||||
)
|
||||
self.vision_model = vision_model.vision_model
|
||||
|
||||
# Classifier head
|
||||
self.classifier = (
|
||||
|
@ -18,18 +18,30 @@ import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
|
||||
from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.utils import (
|
||||
is_torch_available,
|
||||
is_torch_bf16_available_on_device,
|
||||
is_torch_fp16_available_on_device,
|
||||
is_torch_sdpa_available,
|
||||
is_vision_available,
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
@ -37,6 +49,7 @@ from ...test_modeling_common import (
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
is_flaky,
|
||||
random_attention_mask,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
@ -48,6 +61,8 @@ if is_torch_available():
|
||||
|
||||
from transformers import SiglipForImageClassification, SiglipModel, SiglipTextModel, SiglipVisionModel
|
||||
|
||||
if is_torch_sdpa_available():
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
@ -55,6 +70,155 @@ if is_vision_available():
|
||||
from transformers import SiglipProcessor
|
||||
|
||||
|
||||
class SiglipModelTesterMixin(ModelTesterMixin):
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self,
|
||||
torch_dtype: str,
|
||||
use_attention_mask_options: Tuple[bool, ...] = (True, False),
|
||||
logit_keys: Tuple[str, ...] = ("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
|
||||
):
|
||||
if not self.all_model_classes[0]._supports_sdpa:
|
||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||
|
||||
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
|
||||
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
|
||||
|
||||
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
|
||||
self.skipTest(
|
||||
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
|
||||
)
|
||||
|
||||
# Convert to torch dtype
|
||||
dtypes = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtypes[torch_dtype]
|
||||
|
||||
atols = {
|
||||
torch.float32: 1e-5,
|
||||
torch.bfloat16: 3e-2,
|
||||
torch.float16: 5e-3,
|
||||
}
|
||||
rtols = {
|
||||
torch.float32: 1e-4,
|
||||
torch.bfloat16: 3e-2,
|
||||
torch.float16: 5e-3,
|
||||
}
|
||||
|
||||
atol = atols[torch_dtype]
|
||||
rtol = rtols[torch_dtype]
|
||||
|
||||
def get_mean_reldiff(msg, current_case, x, ref, atol, rtol):
|
||||
return f"{msg} {current_case}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
# Load the model with SDPA
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device)
|
||||
|
||||
# Load model with eager attention
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="eager",
|
||||
)
|
||||
model_eager = model_eager.eval().to(torch_device)
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and model_sdpa.config.model_type != "falcon":
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time,
|
||||
# but it would be nicer to have an efficient way to use parameterized.expand
|
||||
cases = [
|
||||
(use_mask, output_attentions, sdpa_backend, batch_size)
|
||||
for use_mask in use_attention_mask_options
|
||||
for output_attentions in [True, False]
|
||||
for sdpa_backend in [
|
||||
SDPBackend.MATH,
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH],
|
||||
[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH],
|
||||
]
|
||||
for batch_size in [1, 5]
|
||||
]
|
||||
fail_cases = []
|
||||
|
||||
for use_mask, output_attentions, sdpa_backend, batch_size in cases:
|
||||
processed_inputs = inputs_dict.copy()
|
||||
|
||||
# convert to torch_dtype
|
||||
if "pixel_values" in processed_inputs:
|
||||
processed_inputs["pixel_values"] = processed_inputs["pixel_values"].to(torch_dtype)
|
||||
|
||||
# slice for different batch sizes
|
||||
for key in ["pixel_values", "input_ids", "attention_mask"]:
|
||||
if key in processed_inputs:
|
||||
processed_inputs[key] = processed_inputs[key][:batch_size]
|
||||
|
||||
# set attention mask with left padding
|
||||
if not use_mask:
|
||||
processed_inputs.pop("attention_mask", None)
|
||||
else:
|
||||
dummy_attention_mask = processed_inputs["attention_mask"]
|
||||
dummy_attention_mask[:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
processed_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
processed_inputs["output_attentions"] = output_attentions
|
||||
processed_inputs["output_hidden_states"] = True
|
||||
|
||||
current_case = (
|
||||
f"padding_side=left, use_mask={use_mask}, batch_size={batch_size}, sdpa_backend={sdpa_backend}"
|
||||
)
|
||||
|
||||
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
|
||||
|
||||
with torch.no_grad():
|
||||
try:
|
||||
with sdpa_kernel(sdpa_backend):
|
||||
outputs_eager = model_eager(**prepared_inputs)
|
||||
outputs_sdpa = model_sdpa(**prepared_inputs)
|
||||
except Exception as e:
|
||||
fail_cases.append(f"{current_case}: {e}")
|
||||
continue
|
||||
|
||||
for key in logit_keys:
|
||||
eager_logits = outputs_eager[key]
|
||||
sdpa_logits = outputs_sdpa[key]
|
||||
|
||||
if use_mask:
|
||||
eager_logits = eager_logits[:, 1:]
|
||||
sdpa_logits = sdpa_logits[:, 1:]
|
||||
|
||||
is_close = torch.allclose(eager_logits, sdpa_logits, atol=atol, rtol=rtol)
|
||||
if not is_close:
|
||||
fail_cases.append(get_mean_reldiff(key, current_case, sdpa_logits, eager_logits, atol, rtol))
|
||||
|
||||
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
|
||||
|
||||
|
||||
class SiglipVisionModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@ -135,7 +299,7 @@ class SiglipVisionModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class SiglipVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class SiglipVisionModelTest(SiglipModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as SIGLIP does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
@ -225,6 +389,17 @@ class SiglipVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = SiglipVisionModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
super().test_eager_matches_sdpa_inference(
|
||||
torch_dtype=torch_dtype,
|
||||
logit_keys=("pooler_output", "last_hidden_state"),
|
||||
use_attention_mask_options=(False,),
|
||||
)
|
||||
|
||||
|
||||
class SiglipTextModelTester:
|
||||
def __init__(
|
||||
@ -314,7 +489,7 @@ class SiglipTextModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class SiglipTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class SiglipTextModelTest(SiglipModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SiglipTextModel,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
@ -376,6 +551,17 @@ class SiglipTextModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = SiglipTextModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
super().test_eager_matches_sdpa_inference(
|
||||
torch_dtype=torch_dtype,
|
||||
logit_keys=("pooler_output", "last_hidden_state"),
|
||||
use_attention_mask_options=(False, True),
|
||||
)
|
||||
|
||||
|
||||
class SiglipModelTester:
|
||||
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||
@ -429,7 +615,7 @@ class SiglipModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SiglipModel,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"feature-extraction": SiglipModel} if is_torch_available() else {}
|
||||
fx_compatible = False
|
||||
@ -571,6 +757,100 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
model = SiglipModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
|
||||
dummy_input_ids = inputs_dict["input_ids"]
|
||||
|
||||
outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
|
||||
outputs_fa = model_fa(
|
||||
pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
|
||||
f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
|
||||
f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
|
||||
)
|
||||
|
||||
# Test with attention mask
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
|
||||
if dummy_attention_mask is not None:
|
||||
dummy_attention_mask[:, 1:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
|
||||
outputs = model(
|
||||
pixel_values=dummy_pixel_values,
|
||||
input_ids=dummy_input_ids,
|
||||
attention_mask=dummy_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
outputs_fa = model_fa(
|
||||
pixel_values=dummy_pixel_values,
|
||||
input_ids=dummy_input_ids,
|
||||
attention_mask=dummy_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
|
||||
f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
|
||||
f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
|
||||
)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(
|
||||
pixel_values=dummy_pixel_values,
|
||||
input_ids=dummy_input_ids,
|
||||
attention_mask=dummy_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
self.skipTest("SigLIP does not support right padding")
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
super().test_eager_matches_sdpa_inference(
|
||||
torch_dtype=torch_dtype,
|
||||
logit_keys=("logits_per_image", "logits_per_text", "image_embeds", "text_embeds"),
|
||||
use_attention_mask_options=(False, True),
|
||||
)
|
||||
|
||||
|
||||
class SiglipForImageClassificationModelTester(SiglipModelTester):
|
||||
def __init__(self, parent):
|
||||
@ -594,7 +874,7 @@ class SiglipForImageClassificationModelTester(SiglipModelTester):
|
||||
|
||||
|
||||
@require_torch
|
||||
class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (SiglipForImageClassification,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"image-classification": SiglipForImageClassification} if is_torch_available() else {}
|
||||
fx_compatible = False
|
||||
@ -636,6 +916,15 @@ class SiglipForImageClassificationModelTest(ModelTesterMixin, PipelineTesterMixi
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
super().test_eager_matches_sdpa_inference(
|
||||
torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,)
|
||||
)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
Loading…
Reference in New Issue
Block a user