diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py
index 9d9db428e1f..ff8ff935d4d 100644
--- a/src/transformers/models/t5gemma/modeling_t5gemma.py
+++ b/src/transformers/models/t5gemma/modeling_t5gemma.py
@@ -24,6 +24,8 @@ from typing import Callable, Optional, Union
 import torch
 import torch.nn as nn
 
+from transformers.utils.generic import check_model_inputs
+
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -342,13 +344,7 @@ class T5GemmaCrossAttention(nn.Module):
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -720,7 +716,7 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -793,15 +789,13 @@ class T5GemmaDecoder(T5GemmaEncoder):
 
     def __init__(self, config):
         super().__init__(config)
-
         self.layers = nn.ModuleList(
             [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
 
-        # Initialize weights and apply final processing
         self.post_init()
 
-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -841,7 +835,6 @@
             attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)
 
         if not isinstance(self_attn_mask_mapping := attention_mask, dict):
-            # Prepare mask arguments
             mask_kwargs = {
                 "config": self.config,
                 "input_embeds": inputs_embeds,
@@ -849,15 +842,12 @@
                 "cache_position": cache_position,
                 "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
             }
-            # Create the masks
             self_attn_mask_mapping = {
                 "full_attention": create_causal_mask(**mask_kwargs),
                 "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
             }
 
-        # Attention masks: Cross attention
         if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
-            # Prepare mask arguments
             mask_kwargs = {
                 "config": self.config,
                 "input_embeds": encoder_hidden_states,
@@ -872,15 +862,9 @@
                 ),
             }
 
-        # embed positions
         hidden_states = inputs_embeds
-
-        # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # normalized
-        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
-        # See https://github.com/huggingface/transformers/pull/29402
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer
         hidden_states = self.dropout(hidden_states)
@@ -946,8 +930,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         decoder_inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> Seq2SeqModelOutput:
@@ -962,8 +944,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
                 **kwargs,
             )
 
@@ -978,8 +958,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=attention_mask,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **kwargs,
         )
@@ -1021,8 +999,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutput:
         encoder_outputs = self.encoder(
@@ -1030,8 +1006,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
             attention_mask=attention_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             **kwargs,
         )
         return encoder_outputs
@@ -1074,23 +1048,18 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
     @auto_docstring
     def forward(
         self,
-        # encoder
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        # decoder
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         decoder_position_ids: Optional[torch.LongTensor] = None,
-        # others
         encoder_outputs: Optional[BaseModelOutput] = None,
         past_key_values: Optional[EncoderDecoderCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **loss_kwargs,
@@ -1130,8 +1099,6 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **loss_kwargs,
         )
@@ -1202,21 +1169,16 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
     @auto_docstring
     def forward(
         self,
-        # encoder
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        # decoder
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.Tensor] = None,
         decoder_position_ids: Optional[torch.LongTensor] = None,
-        # others
         encoder_outputs: Optional[BaseModelOutput] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
     ) -> SequenceClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1254,8 +1216,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 decoder_inputs_embeds=decoder_inputs_embeds,
                 use_cache=False,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.decoder_hidden_states
@@ -1266,8 +1226,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.hidden_states
@@ -1350,21 +1308,16 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
     @auto_docstring
     def forward(
         self,
-        # encoder
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        # decoder
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.Tensor] = None,
         decoder_position_ids: Optional[torch.LongTensor] = None,
-        # others
         encoder_outputs: Optional[BaseModelOutput] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
     ) -> TokenClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1402,8 +1355,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 decoder_inputs_embeds=decoder_inputs_embeds,
                 use_cache=False,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.decoder_hidden_states
@@ -1414,8 +1365,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.hidden_states
diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py
index 86ac4ce3ad0..0d9a10f5418 100644
--- a/src/transformers/models/t5gemma/modular_t5gemma.py
+++ b/src/transformers/models/t5gemma/modular_t5gemma.py
@@ -18,6 +18,8 @@ from typing import Any, Callable, Optional, Union
 import torch
 import torch.nn as nn
 
+from transformers.utils.generic import check_model_inputs
+
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...configuration_utils import PretrainedConfig
 from ...generation import GenerationMixin
@@ -290,13 +292,7 @@ class T5GemmaCrossAttention(Gemma2Attention):
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -582,7 +578,7 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -655,15 +651,13 @@ class T5GemmaDecoder(T5GemmaEncoder):
 
     def __init__(self, config):
        super().__init__(config)
-
         self.layers = nn.ModuleList(
             [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
 
-        # Initialize weights and apply final processing
         self.post_init()
 
-    @can_return_tuple
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -703,7 +697,6 @@
             attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)
 
         if not isinstance(self_attn_mask_mapping := attention_mask, dict):
-            # Prepare mask arguments
             mask_kwargs = {
                 "config": self.config,
                 "input_embeds": inputs_embeds,
@@ -711,15 +704,12 @@
                 "cache_position": cache_position,
                 "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
             }
-            # Create the masks
             self_attn_mask_mapping = {
                 "full_attention": create_causal_mask(**mask_kwargs),
                 "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
             }
 
-        # Attention masks: Cross attention
         if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
-            # Prepare mask arguments
             mask_kwargs = {
                 "config": self.config,
                 "input_embeds": encoder_hidden_states,
@@ -734,15 +724,9 @@
                 ),
             }
 
-        # embed positions
         hidden_states = inputs_embeds
-
-        # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
-        # normalized
-        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
-        # See https://github.com/huggingface/transformers/pull/29402
         normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
         hidden_states = hidden_states * normalizer
         hidden_states = self.dropout(hidden_states)
@@ -808,8 +792,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
         inputs_embeds: Optional[torch.Tensor] = None,
         decoder_inputs_embeds: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> Seq2SeqModelOutput:
@@ -824,8 +806,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
                 **kwargs,
             )
 
@@ -840,8 +820,6 @@ class T5GemmaModel(T5GemmaPreTrainedModel):
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=attention_mask,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **kwargs,
         )
@@ -883,8 +861,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutput:
         encoder_outputs = self.encoder(
@@ -892,8 +868,6 @@ class T5GemmaEncoderModel(T5GemmaPreTrainedModel):
             attention_mask=attention_mask,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             **kwargs,
         )
         return encoder_outputs
@@ -936,23 +910,18 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
     @auto_docstring
     def forward(
         self,
-        # encoder
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        # decoder
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         decoder_position_ids: Optional[torch.LongTensor] = None,
-        # others
         encoder_outputs: Optional[BaseModelOutput] = None,
         past_key_values: Optional[EncoderDecoderCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **loss_kwargs,
@@ -992,8 +961,6 @@ class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin):
             inputs_embeds=inputs_embeds,
             decoder_inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             cache_position=cache_position,
             **loss_kwargs,
         )
@@ -1064,21 +1031,16 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
     @auto_docstring
     def forward(
         self,
-        # encoder
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        # decoder
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.Tensor] = None,
         decoder_position_ids: Optional[torch.LongTensor] = None,
-        # others
         encoder_outputs: Optional[BaseModelOutput] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1116,8 +1078,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 decoder_inputs_embeds=decoder_inputs_embeds,
                 use_cache=False,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.decoder_hidden_states
@@ -1128,8 +1088,6 @@ class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.hidden_states
@@ -1212,21 +1170,16 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
     @auto_docstring
     def forward(
         self,
-        # encoder
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        # decoder
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.Tensor] = None,
         decoder_position_ids: Optional[torch.LongTensor] = None,
-        # others
         encoder_outputs: Optional[BaseModelOutput] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
     ) -> TokenClassifierOutput:
         r"""
         decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
@@ -1264,8 +1217,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                 inputs_embeds=inputs_embeds,
                 decoder_inputs_embeds=decoder_inputs_embeds,
                 use_cache=False,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.decoder_hidden_states
@@ -1276,8 +1227,6 @@ class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
             )
             last_hidden_state = outputs.last_hidden_state
             hidden_states = outputs.hidden_states
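
Usage note (illustrative sketch, not part of the patch): after this change, `output_attentions` and `output_hidden_states` are no longer explicit `forward` parameters; they travel through `**kwargs` and are handled by the `@check_model_inputs` decorator, so callers can keep requesting them as keyword arguments. The checkpoint id below is a placeholder, not a real model name.

    # Hedged example; "google/t5gemma-placeholder" is hypothetical, substitute a real T5Gemma checkpoint.
    from transformers import AutoTokenizer, T5GemmaForConditionalGeneration

    tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-placeholder")
    model = T5GemmaForConditionalGeneration.from_pretrained("google/t5gemma-placeholder")

    enc = tokenizer("Translate to German: hello", return_tensors="pt")
    out = model(
        **enc,
        decoder_input_ids=enc["input_ids"],
        output_attentions=True,        # passed via **kwargs, recorded by @check_model_inputs
        output_hidden_states=True,
    )
    # Seq2SeqLMOutput fields are still populated as before.
    print(len(out.decoder_attentions), len(out.encoder_hidden_states))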