diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 6858d5c4b87..ffdc17fc7dc 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -31,7 +31,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -612,9 +612,98 @@ class MistralFlashAttention2(MistralAttention): ) +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +class MistralSdpaAttention(MistralAttention): + """ + Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MistralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
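+            # is_causal lets SDPA apply causal masking internally, so it is only enabled here when no explicit attn_mask is passed (SDPA does not accept both at once).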
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + MISTRAL_ATTENTION_CLASSES = { "eager": MistralAttention, "flash_attention_2": MistralFlashAttention2, + "sdpa": MistralSdpaAttention, } @@ -717,6 +806,7 @@ class MistralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True _supports_cache_class = True def _init_weights(self, module): @@ -822,7 +912,7 @@ class MistralModel(MistralPreTrainedModel): self.layers = nn.ModuleList( [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._attn_implementation = config._attn_implementation self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -893,7 +983,7 @@ class MistralModel(MistralPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -902,9 +992,18 @@ class MistralModel(MistralPreTrainedModel): " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1c2febfbac6..2cd91f10734 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -33,6 +33,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -660,6 +661,101 @@ class MixtralFlashAttention2(MixtralAttention): ) +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mixtral +class MixtralSdpaAttention(MixtralAttention): + """ + Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
+ """ + + # Adapted from MixtralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
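+            # is_causal lets SDPA apply causal masking internally, so it is only enabled here when no explicit attn_mask is passed (SDPA does not accept both at once).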
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MIXTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, + "sdpa": MixtralSdpaAttention, +} + + class MixtralBLockSparseTop2MLP(nn.Module): def __init__(self, config: MixtralConfig): super().__init__() @@ -678,12 +774,6 @@ class MixtralBLockSparseTop2MLP(nn.Module): return current_hidden_states -MISTRAL_ATTENTION_CLASSES = { - "eager": MixtralAttention, - "flash_attention_2": MixtralFlashAttention2, -} - - class MixtralSparseMoeBlock(nn.Module): """ This implementation is @@ -759,7 +849,7 @@ class MixtralDecoderLayer(nn.Module): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -861,6 +951,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True _supports_cache_class = True def _init_weights(self, module): @@ -964,7 +1055,7 @@ class MixtralModel(MixtralPreTrainedModel): self.layers = nn.ModuleList( [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._attn_implementation = config._attn_implementation self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1040,7 +1131,7 @@ class MixtralModel(MixtralPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._use_flash_attention_2 and use_cache: + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -1049,9 +1140,18 @@ class MixtralModel(MixtralPreTrainedModel): " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - if self._use_flash_attention_2: + if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
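+            # Unlike _prepare_4d_causal_attention_mask, the SDPA variant may return None when the incoming mask adds no information, letting the attention layers fall back to SDPA's is_causal path.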
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index ed14bdfeb58..5e91e70ecd5 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -28,6 +28,7 @@ from transformers.testing_utils import ( require_flash_attn, require_torch, require_torch_gpu, + require_torch_sdpa, slow, torch_device, ) @@ -528,6 +529,44 @@ class MistralIntegrationTest(unittest.TestCase): backend_empty_cache(torch_device) gc.collect() + @slow + @require_torch_sdpa + def test_model_7b_long_prompt_sdpa(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", + device_map="auto", + attn_implementation="sdpa", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + + backend_empty_cache(torch_device) + gc.collect() + + EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big""" + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False) + + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + @slow def test_speculative_generation(self): EXPECTED_TEXT_COMPLETION = ( diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index eb75314f1ce..78c3d68d470 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -38,11 +38,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): import torch - from transformers import ( - MixtralForCausalLM, - MixtralForSequenceClassification, - MixtralModel, - ) + from transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel class MixtralModelTester:
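
For reference, a minimal usage sketch of the new code path (not part of the patch): it follows the MistralIntegrationTest added above and selects the SDPA attention class introduced by this diff via attn_implementation="sdpa". The checkpoint name and prompt are taken from that test; the fp16 dtype is an assumption added here for practicality and is not set by the test.

import torch
from transformers import AutoTokenizer, MistralForCausalLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.float16,  # assumption for memory; the integration test loads in the default dtype
    device_map="auto",
    attn_implementation="sdpa",  # selects MistralSdpaAttention from MISTRAL_ATTENTION_CLASSES
)

prompt = "My favourite condiment is "
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

# Greedy decoding, mirroring the integration test above.
generated_ids = model.generate(input_ids, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

The same flag works for Mixtral via MixtralForCausalLM, since the diff registers "sdpa" in MIXTRAL_ATTENTION_CLASSES and sets _supports_sdpa = True on both pretrained model classes.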