Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 04:40:06 +06:00)

[tests] expand flex-attn test for vision models (#38434)

* expand the test for VLMs
* typo
* mark models `supports_flex` + expand test for additional kwargs
* flex attn for refactored vision models
* fix copies
* fix
* unskip
* style
* address comments

This commit is contained in: parent de4cf5a38e, commit bf68dd9e6e
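For orientation, the `_supports_flex_attn` flags added across the vision models below are what allow loading those models with `attn_implementation="flex_attention"`. A minimal usage sketch, assuming a recent PyTorch build with FlexAttention support; the checkpoint name is illustrative and not part of this commit:

```python
import torch
from transformers import AutoModel

# Illustrative checkpoint; any vision model whose class sets _supports_flex_attn works.
model = AutoModel.from_pretrained(
    "google/vit-base-patch16-224",
    torch_dtype=torch.float16,
    attn_implementation="flex_attention",
)
print(model.config._attn_implementation)  # "flex_attention"
```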
@@ -12,24 +12,52 @@ from torch import nn

from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling
from ..auto import AutoModel
from .configuration_new_task_model import NewTaskModelConfig


_CONFIG_FOR_DOC = "NewTaskModelConfig"
|
||||
@dataclass
|
||||
class NewTaskModelModelOutputWithPast(BaseModelOutputWithPast):
|
||||
"""
|
||||
Base class for NewTaskModel outputs, with hidden states and attentions.
|
||||
|
||||
Args:
|
||||
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewTaskModelCausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
Base class for NewTaskModelcausal language model (or autoregressive) outputs.
|
||||
Base class for NewTaskModel causal language model (or autoregressive) outputs.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
@ -77,30 +105,10 @@ class NewTaskModelMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
NEW_TASK_MODEL_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`NewTaskModelConfig`] or [`NewTaskModelVisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||
NEW_TASK_MODEL_START_DOCSTRING,
|
||||
)
|
||||
@auto_docstring
|
||||
class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
config_class = NewTaskModelConfig
|
||||
base_model_prefix = "model"
|
||||
base_model_prefix = ""
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["NewTaskModelMultiModalProjector"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
@@ -109,6 +117,8 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
    _supports_static_cache = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        # important: this ported version of NewTaskModelisn't meant for training from scratch - only
@@ -121,102 +131,24 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
NEW_TASK_MODEL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
||||
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`NewTaskModelProcessor`] uses
|
||||
[`SiglipImageProcessor`] for processing images).
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||
information on the default strategy.
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""The NEW_TASK_MODEL model which consists of a vision backbone and a language model.""",
|
||||
NEW_TASK_MODEL_START_DOCSTRING,
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base NewTaskModel model which consists of a vision backbone and a language model withou language modeling head.,
|
||||
"""
|
||||
)
|
||||
class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
class NewTaskModelModel(NewTaskModelPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: NewTaskModelConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.multi_modal_projector = NewTaskModelMultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
|
||||
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
|
||||
|
||||
if language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
||||
language_model = AutoModel.from_config(config=config.text_config)
|
||||
self.language_model = language_model
|
||||
|
||||
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
||||
|
||||
self.embedding_dim = self.config.embedding_dim
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
@ -225,18 +157,6 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.language_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
@ -321,8 +241,191 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
image_features = image_features / (self.config.text_config.hidden_size**0.5)
|
||||
return image_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, NewTaskModelModelOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration
|
||||
|
||||
>>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/new_task_model2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/new_task_model2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id with PAD if the image token is OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0) + 1 # NewTaskModel positions are 1-indexed
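As a toy illustration of the cache bookkeeping above (made-up numbers, not taken from the diff): `cache_position` continues from the number of tokens already in the cache, and `position_ids` is shifted by one because this model's positions are 1-indexed.

```python
import torch

past_seen_tokens = 5  # e.g. length of the cached prefix
num_new_tokens = 3
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)
position_ids = cache_position.unsqueeze(0) + 1  # 1-indexed positions
print(cache_position.tolist())  # [5, 6, 7]
print(position_ids.tolist())    # [[6, 7, 8]]
```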
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
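The `masked_scatter` call above is what splices the projected image features into the text embedding sequence at the image-token positions. A self-contained toy version with invented shapes (none of these tensors come from the model):

```python
import torch

image_token_id = 7
input_ids = torch.tensor([[1, 7, 7, 2]])  # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, 8)      # (batch, seq_len, hidden)
image_features = torch.randn(1, 2, 8)     # (batch, num_image_tokens, hidden)

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features.to(inputs_embeds.dtype))
print(inputs_embeds[0, 1, :3])  # filled with image features
print(inputs_embeds[0, 0, :3])  # text position left untouched (zeros)
```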
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
|
||||
)
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return NewTaskModelModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base NewTaskModel model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^language_model.model": "model.language_model",
|
||||
"^vision_tower": "model.vision_tower",
|
||||
"^multi_modal_projector": "model.multi_modal_projector",
|
||||
"^language_model.lm_head": "lm_head",
|
||||
}
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = NewTaskModelModel(config)
|
||||
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
|
||||
self.embedding_dim = self.config.embedding_dim
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
if self.language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available through the conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
@property
|
||||
def vision_tower(self):
|
||||
return self.model.vision_tower
|
||||
|
||||
@property
|
||||
def multi_modal_projector(self):
|
||||
return self.model.multi_modal_projector
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@ -341,19 +444,10 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> Union[Tuple, NewTaskModelCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
@ -400,7 +494,8 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
        # L2 normalization
        embeddings = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return (embeddings,) + vlm_outputs
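A small standalone sketch of the normalization-and-masking pattern in this hunk, with invented shapes; the guard mirrors the newly added `if attention_mask is not None` check:

```python
import torch

proj = torch.randn(2, 6, 128)  # (batch_size, sequence_length, dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

embeddings = proj / proj.norm(dim=-1, keepdim=True)  # L2 normalization
if attention_mask is not None:
    embeddings = embeddings * attention_mask.unsqueeze(-1)  # zero out padded positions
print(embeddings[0, 3].abs().sum())  # tensor(0.) -- padded token is zeroed
```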
|
||||
|
||||
@ -420,7 +515,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -443,13 +538,68 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
||||
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
|
||||
causal_mask = self._update_causal_mask(
|
||||
causal_mask = self.model._update_causal_mask(
|
||||
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
|
||||
)
|
||||
model_inputs["attention_mask"] = causal_mask
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
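A toy walk-through of the mask construction above for `sequence_length=3` and `target_length=5` (sizes chosen for illustration only): future positions and not-yet-filled cache slots end up at the dtype minimum, visible positions at zero.

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 3, 5, 1
cache_position = torch.arange(sequence_length)  # prefill: positions 0..2

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
print(causal_mask.shape)         # torch.Size([1, 1, 3, 5])
print((causal_mask == 0).sum())  # tensor(6): 1 visible slot in row 0, 2 in row 1, 3 in row 2
```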
|
||||
|
||||
def resize_token_embeddings(
|
||||
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing=True
|
||||
) -> nn.Embedding:
|
||||
|
@ -65,7 +65,8 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
|
||||
# L2 normalization
|
||||
embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
if attention_mask is not None:
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return (embeddings,) + vlm_outputs
|
||||
|
||||
|
@ -384,6 +384,8 @@ class ASTPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
@ -97,6 +97,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = False
|
||||
_supports_static_cache = False
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -830,6 +830,7 @@ class ChameleonPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_param_buffer_assignment = False
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -450,6 +450,8 @@ class CLIPPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -162,7 +162,8 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||
# L2 normalization
|
||||
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
if attention_mask is not None:
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return ColPaliForRetrievalOutput(
|
||||
embeddings=embeddings,
|
||||
|
@ -450,6 +450,8 @@ class DeiTPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["DeiTLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
@ -494,6 +494,8 @@ class Dinov2PreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["Dinov2SwiGLUFFN"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
@ -512,6 +512,8 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
@ -826,6 +826,8 @@ class DPTPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -1142,8 +1142,8 @@ class Emu3PreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_param_buffer_assignment = False
|
||||
_supports_attention_backend = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.get_text_config().initializer_range
|
||||
|
@ -593,6 +593,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -151,6 +151,8 @@ class IJepaPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
@ -95,6 +95,8 @@ class IJepaPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
@ -178,6 +178,8 @@ class InternVLVisionPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["InternVLVisionLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -537,6 +539,7 @@ class InternVLPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -140,6 +140,8 @@ class InternVLVisionPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["InternVLVisionLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -141,6 +141,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -252,6 +252,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -195,6 +195,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -308,6 +308,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -206,6 +206,7 @@ class Mistral3PreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -860,6 +860,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -131,6 +131,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
_supports_static_cache = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -375,6 +375,7 @@ class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -517,6 +517,8 @@ class SiglipPreTrainedModel(PreTrainedModel):
|
||||
]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -749,6 +749,8 @@ class Siglip2PreTrainedModel(PreTrainedModel):
|
||||
]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -491,6 +491,8 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -142,6 +142,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
@ -448,6 +448,8 @@ class ViTPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
@ -633,6 +633,8 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -452,6 +452,8 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"]
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
# todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
|
||||
# when creating pre-training scripts.
|
||||
|
@ -456,6 +456,8 @@ class VivitPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = []
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
|
@ -532,6 +532,8 @@ class YolosPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = []
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
||||
"""Initialize the weights"""
|
||||
|
@ -535,6 +535,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
|
||||
)
|
||||
additional_model_inputs = ["pixel_values"]
|
||||
fx_compatible = True
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
@ -401,10 +401,6 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Emu3 doesn't support Flex attn yet!")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class Emu3IntegrationTest(unittest.TestCase):
|
||||
|
@ -351,12 +351,6 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
|
||||
)
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
def test_automodelforcausallm(self):
|
||||
"""
|
||||
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||
|
@ -236,10 +236,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Vision backbone doesn't support FLEX yet!")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class GotOcr2IntegrationTest(unittest.TestCase):
|
||||
|
@ -569,6 +569,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
all_generative_model_classes = ()
|
||||
greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"text-to-audio": MusicgenForConditionalGeneration} if is_torch_available() else {}
|
||||
# Addition keys that are required for forward. MusicGen isn't encoder-decoder in config so we have to pass decoder ids as additional
|
||||
additional_model_inputs = ["decoder_input_ids"]
|
||||
test_pruning = False # training is not supported yet for MusicGen
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
|
@ -589,6 +589,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
all_generative_model_classes = ()
|
||||
greedy_sample_model_classes = (MusicgenMelodyForConditionalGeneration,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"text-to-audio": MusicgenMelodyForConditionalGeneration} if is_torch_available() else {}
|
||||
# Addition keys that are required for forward. MusicGen isn't encoder-decoder in config so we have to pass decoder ids as additional
|
||||
additional_model_inputs = ["decoder_input_ids"]
|
||||
test_pruning = False # training is not supported yet for MusicGen
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
|
@ -103,7 +103,7 @@ class SiglipVisionModelTester:
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@ -274,7 +274,7 @@ class SiglipTextModelTester:
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
|
@ -180,7 +180,7 @@ class Siglip2VisionModelTester:
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@ -363,7 +363,7 @@ class Siglip2TextModelTester:
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
|
@ -190,7 +190,8 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
||||
# Addition keys that are required for forward, used in tests where we manipulate and create new input dict from scratch
|
||||
additional_model_inputs = ["bool_masked_pos"]
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
|
@ -322,10 +322,6 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("LLaVA vision backbones doesn't support flex attention yet")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
@ -3637,7 +3637,10 @@ class ModelTesterMixin:
|
||||
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
|
||||
|
||||
for key in getattr(self, "additional_model_inputs", []):
|
||||
processed_inputs[key] = inputs_dict[key]
|
||||
# Some models don't have all `additional_model_inputs`, especially when we
|
||||
# craft cases to test model in different settings
|
||||
if key in inputs_dict:
|
||||
processed_inputs[key] = inputs_dict[key]
|
||||
|
||||
for key, value in processed_inputs.items():
|
||||
if torch.is_floating_point(value):
|
||||
@ -4012,19 +4015,21 @@ class ModelTesterMixin:
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
|
||||
|
||||
sub_models_supporting_fa2 = [
|
||||
(module._supports_flash_attn_2 or module._supports_attention_backend)
|
||||
module._supports_flash_attn_2
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_fa2_all_modules = (
|
||||
all(sub_models_supporting_fa2)
|
||||
if len(sub_models_supporting_fa2) > 0
|
||||
else (model._supports_flash_attn_2 or model._supports_attention_backend)
|
||||
else model._supports_flash_attn_2
|
||||
)
|
||||
if not supports_fa2_all_modules:
|
||||
with self.assertRaises(ValueError):
|
||||
model_fa2 = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2"
|
||||
tmpdirname,
|
||||
torch_dtype=torch_dtype,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
else:
|
||||
model_fa2 = model_class.from_pretrained(
|
||||
@ -4572,33 +4577,73 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
def test_flex_attention_with_grads(self):
|
||||
for model_class in self.all_model_classes:
|
||||
# TODO: raushan, fix for composite models after making VLMs support new attn API
|
||||
if not model_class._supports_flex_attn or self._is_composite:
|
||||
self.skipTest(reason="This model does not support flex attention")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config._attn_implementation = "flex_attention"
|
||||
# Flex Attention cannot use dropout
|
||||
if hasattr(config, "attention_dropout"):
|
||||
config.attention_dropout = 0
|
||||
if hasattr(config, "attention_probs_dropout_prob"):
|
||||
config.attention_probs_dropout_prob = 0
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).to(device=torch_device)
|
||||
|
||||
# Flex attention relies on triton on compilation
|
||||
# However, triton cannot handle hidden dimensions of less than 16
|
||||
# --> forcing at least a hidden dim of 16
|
||||
config.hidden_size *= max(
|
||||
16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1
|
||||
# If not all sub-models support flex, skip the test
|
||||
sub_models_supporting_flex = [
|
||||
module._supports_flex_attn
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_flex_all_modules = (all(sub_models_supporting_flex) and len(sub_models_supporting_flex) > 0) or (
|
||||
model._supports_flex_attn and len(sub_models_supporting_flex) == 0
|
||||
)
|
||||
if hasattr(config, "head_dim"):
|
||||
config.head_dim = max(16, config.head_dim)
|
||||
if not supports_flex_all_modules:
|
||||
self.skipTest(reason="This model's submodels does not support flex attention")
|
||||
|
||||
def update_config_for_flex(config):
|
||||
# Flex Attention cannot use dropout
|
||||
if hasattr(config, "attention_dropout"):
|
||||
config.attention_dropout = 0
|
||||
if hasattr(config, "attention_probs_dropout_prob"):
|
||||
config.attention_probs_dropout_prob = 0
|
||||
|
||||
# Flex attention relies on triton on compilation
|
||||
# However, triton cannot handle hidden dimensions of less than 16
|
||||
# --> forcing at least a hidden dim of 16
|
||||
|
||||
# Update the head dim and try to update hidden size as well if present in config
|
||||
# NOTE: some models may have none if the values in sub-config, thus we check for `Noneness`
|
||||
head_dim = None
|
||||
if hasattr(config, "head_dim") and config.head_dim is not None:
|
||||
head_dim = config.head_dim
|
||||
config.head_dim = max(16, config.head_dim)
|
||||
|
||||
if (
|
||||
getattr(config, "hidden_size", None) is not None
|
||||
and getattr(config, "num_attention_heads", None) is not None
|
||||
):
|
||||
head_dim = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads
|
||||
config.hidden_size *= max(16 // head_dim, 1)
|
||||
|
||||
if (
|
||||
getattr(config, "decoder_hidden_size", None) is not None
|
||||
and getattr(config, "decoder_num_attention_heads", None) is not None
|
||||
):
|
||||
decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
|
||||
config.decoder_hidden_size *= max(16 // decoder_head_dim, 1)
|
||||
|
||||
# Set default attention to flex and update config values
|
||||
update_config_for_flex(config)
|
||||
for key in config.sub_configs:
|
||||
sub_config = getattr(config, key)
|
||||
update_config_for_flex(sub_config)
|
||||
|
||||
config._attn_implementation = "flex_attention"
|
||||
model = model_class(config).to(device=torch_device)
|
||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||
|
||||
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
||||
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
|
||||
if config.is_encoder_decoder:
|
||||
for key in getattr(self, "additional_model_inputs", []):
|
||||
# Some models don't have all `additional_model_inputs`, especially when we
|
||||
# craft cases to test model in different settings
|
||||
if key in inputs_dict:
|
||||
dummy_inputs[key] = inputs_dict[key].to(torch_device)
|
||||
|
||||
if config.get_text_config(decoder=True).is_encoder_decoder:
|
||||
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
|
||||
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
|