From dbc98328da2cabe7938423c51569252f2b49a5b3 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 1 Jul 2025 10:34:53 +0200 Subject: [PATCH] Several fixes for Gemma3n (#39135) * remove the skips * fix the epsilon to a small value (does not make sense otherwise) * safeguard * overload test_eager_matches_sdpa * Update test_modeling_common.py * skip appropriate tests * correct no_split_layer * fix all devices issue * fix backward * fix --- .../models/gemma3n/modeling_gemma3n.py | 82 ++- .../models/gemma3n/modular_gemma3n.py | 82 ++- src/transformers/testing_utils.py | 1 - tests/models/gemma3n/test_modeling_gemma3n.py | 70 +- tests/test_modeling_common.py | 646 +++++++++--------- 5 files changed, 491 insertions(+), 390 deletions(-) diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index d29bacf91e2..3a4995610d4 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1135,9 +1135,17 @@ class Gemma3nTextAltUp(nn.Module): corrected += predictions # add the original input return corrected.contiguous().type_as(activated) + def forward(self, corrected: torch.Tensor) -> torch.Tensor: + """ + This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale` + (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in + `scale_corrected_output` + """ + return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].""" - return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + return self.forward(corrected) class Gemma3nTextRotaryEmbedding(nn.Module): @@ -1290,7 +1298,7 @@ class Gemma3nTextAttention(nn.Module): self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False) first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers - self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 # Find the index of the last sliding or full layer before sharing starts (or None if no sharing) layer_type = config.layer_types[layer_idx] self.kv_shared_layer_index = ( @@ -1319,21 +1327,22 @@ class Gemma3nTextAttention(nn.Module): query_states = query_states.transpose(1, 2) if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: - # HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache. + # Device of past layer may be different from current one + indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device) + # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) if isinstance(past_key_value, HybridCache) and self.is_sliding: max_length = past_key_value.sliding_window - if cache_position.shape[0] > max_length: - # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache, - # slice into the entire cache. 
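# Illustrative sketch, not part of this patch: how the index selection in this hunk behaves for a
# shared sliding-window KV cache, assuming a hypothetical window of 4. A prefill longer than the
# window reads the whole window; shorter prefills and decode steps clamp cache_position instead.
import torch

max_length = 4  # hypothetical sliding window size
for cache_position in (torch.arange(7), torch.arange(3), torch.tensor([9])):
    indices = (
        slice(0, max_length)
        if cache_position.shape[0] > max_length
        else cache_position.clamp(min=0, max=max_length - 1)
    )
    print(cache_position.tolist(), "->", indices)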
- indices = slice(0, max_length) - else: - # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1 - indices = cache_position.clamp(min=0, max=max_length - 1) - else: - indices = cache_position + indices = ( + slice(0, max_length) + if cache_position.shape[0] > max_length + else cache_position.clamp(min=0, max=max_length - 1) + ) - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + # Device of past layer may be different from current one + key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device) + value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to( + query_states.device + ) else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) @@ -1447,10 +1456,9 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer): attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) - first_prediction = corrected_predictions[self.config.altup_active_idx] - first_prediction_clone = first_prediction.clone() + first_prediction = corrected_predictions[self.config.altup_active_idx].clone() if self.config.altup_correct_scale: - first_prediction = self.altup.scale_corrected_output(first_prediction_clone) + first_prediction = self.altup.scale_corrected_output(first_prediction) # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) first_prediction = self.per_layer_input_gate(first_prediction) @@ -1475,7 +1483,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel): config_class = Gemma3nConfig base_model_prefix = "" supports_gradient_checkpointing = True - _no_split_modules = ["Gemma3nDecoderLayer"] + _no_split_modules = ["Gemma3nTextDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_3 = True _supports_flash_attn_2 = True @@ -1656,18 +1664,17 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids) # Expand hidden_states to support per-layer inputs - target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 - epsilon_tensor = torch.tensor(torch.finfo().min) + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 + epsilon_tensor = torch.tensor(1e-5) temp_hidden_states = [hidden_states_0] for i in range(1, self.config.altup_num_inputs): # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...) 
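# Illustrative sketch, not part of this patch: the magnitude matching performed in this hunk, with
# hypothetical shapes. The projected stream is rescaled so its RMS magnitude matches that of the
# first stream, and the new 1e-5 epsilon floor guards against division by zero.
import torch

hidden_states_0 = torch.randn(2, 5, 16)  # [batch, seq, hidden], hypothetical sizes
altup_proj = torch.randn(2, 5, 16)       # stand-in for altup_projections[i - 1](hidden_states_0)
epsilon_tensor = torch.tensor(1e-5)

target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
new_magnitude = torch.mean(altup_proj**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor))
current_hidden_state = altup_proj * target_magnitude / new_magnitude
print(current_hidden_state.shape)        # torch.Size([2, 5, 16])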
- altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0) - current_hidden_state = altup_proj.type(hidden_states_0.dtype) - new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 - current_hidden_state = current_hidden_state * ( - target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) - ) + altup_proj = self.altup_projections[i - 1](hidden_states_0) + current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device))) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude temp_hidden_states.append(current_hidden_state) hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size] @@ -1685,9 +1692,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): layer_outputs = decoder_layer( hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - per_layer_input=per_layer_input, + position_embeddings_global, + position_embeddings_local, + per_layer_input, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -1712,11 +1719,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): for i in range(1, self.config.altup_num_inputs): # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i]) - current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) - new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 - current_hidden_state = current_hidden_state * ( - target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) - ) + current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device))) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude temp_hidden_states.append(current_hidden_state) hidden_states = torch.stack(temp_hidden_states) @@ -1743,7 +1749,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): per_layer_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) - per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype) + per_layer_projection *= self.per_layer_projection_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) per_layer_projection = per_layer_projection.reshape( *inputs_embeds.shape[:-1], self.config.num_hidden_layers, @@ -1758,7 +1766,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. 
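# Illustrative sketch, not part of this patch: trimming padded per-layer inputs back to
# `num_hidden_layers` along the layer axis, with hypothetical sizes.
import torch

num_hidden_layers = 4                       # hypothetical
per_layer_inputs = torch.randn(2, 5, 6, 8)  # [batch, seq, padded_num_layers, hidden_per_layer]
per_layer_inputs = per_layer_inputs[..., :num_hidden_layers, :]
print(per_layer_inputs.shape)               # torch.Size([2, 5, 4, 8])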
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] - return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype) + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.") diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index a18ac8c2ef2..b0a5099ff56 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1685,9 +1685,17 @@ class Gemma3nTextAltUp(nn.Module): corrected += predictions # add the original input return corrected.contiguous().type_as(activated) + def forward(self, corrected: torch.Tensor) -> torch.Tensor: + """ + This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale` + (which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in + `scale_corrected_output` + """ + return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: """Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size].""" - return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected) + return self.forward(corrected) class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding): @@ -1732,7 +1740,7 @@ class Gemma3nTextAttention(Gemma3Attention): self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False) first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers - self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 # Find the index of the last sliding or full layer before sharing starts (or None if no sharing) layer_type = config.layer_types[layer_idx] self.kv_shared_layer_index = ( @@ -1761,21 +1769,22 @@ class Gemma3nTextAttention(Gemma3Attention): query_states = query_states.transpose(1, 2) if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: - # HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache. + # Device of past layer may be different from current one + indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device) + # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) if isinstance(past_key_value, HybridCache) and self.is_sliding: max_length = past_key_value.sliding_window - if cache_position.shape[0] > max_length: - # If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache, - # slice into the entire cache. 
- indices = slice(0, max_length) - else: - # If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1 - indices = cache_position.clamp(min=0, max=max_length - 1) - else: - indices = cache_position + indices = ( + slice(0, max_length) + if cache_position.shape[0] > max_length + else cache_position.clamp(min=0, max=max_length - 1) + ) - key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices] - value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices] + # Device of past layer may be different from current one + key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device) + value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to( + query_states.device + ) else: key_states = self.k_proj(hidden_states).view(hidden_shape) key_states = self.k_norm(key_states) @@ -1880,10 +1889,9 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer): attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) - first_prediction = corrected_predictions[self.config.altup_active_idx] - first_prediction_clone = first_prediction.clone() + first_prediction = corrected_predictions[self.config.altup_active_idx].clone() if self.config.altup_correct_scale: - first_prediction = self.altup.scale_corrected_output(first_prediction_clone) + first_prediction = self.altup.scale_corrected_output(first_prediction) # per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...) first_prediction = self.per_layer_input_gate(first_prediction) @@ -1906,7 +1914,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer): class Gemma3nPreTrainedModel(Gemma2PreTrainedModel): config_class = Gemma3nConfig base_model_prefix = "" - _no_split_modules = ["Gemma3nDecoderLayer"] + _no_split_modules = ["Gemma3nTextDecoderLayer"] def _init_weights(self, module): # important: this ported version of Gemma2 isn't meant for training from scratch - only @@ -1995,7 +2003,9 @@ class Gemma3nTextModel(Gemma3TextModel): per_layer_inputs: Optional[torch.Tensor] = None, ) -> torch.Tensor: per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds) - per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype) + per_layer_projection *= self.per_layer_projection_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) per_layer_projection = per_layer_projection.reshape( *inputs_embeds.shape[:-1], self.config.num_hidden_layers, @@ -2010,7 +2020,9 @@ class Gemma3nTextModel(Gemma3TextModel): # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings. 
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :] - return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype) + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to( + dtype=inputs_embeds.dtype, device=per_layer_projection.device + ) @can_return_tuple @auto_docstring @@ -2091,18 +2103,17 @@ class Gemma3nTextModel(Gemma3TextModel): position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids) # Expand hidden_states to support per-layer inputs - target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 - epsilon_tensor = torch.tensor(torch.finfo().min) + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 + epsilon_tensor = torch.tensor(1e-5) temp_hidden_states = [hidden_states_0] for i in range(1, self.config.altup_num_inputs): # altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...) - altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0) - current_hidden_state = altup_proj.type(hidden_states_0.dtype) - new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 - current_hidden_state = current_hidden_state * ( - target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) - ) + altup_proj = self.altup_projections[i - 1](hidden_states_0) + current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device))) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude temp_hidden_states.append(current_hidden_state) hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size] @@ -2120,9 +2131,9 @@ class Gemma3nTextModel(Gemma3TextModel): layer_outputs = decoder_layer( hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - per_layer_input=per_layer_input, + position_embeddings_global, + position_embeddings_local, + per_layer_input, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -2147,11 +2158,10 @@ class Gemma3nTextModel(Gemma3TextModel): for i in range(1, self.config.altup_num_inputs): # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...) 
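# Illustrative sketch, not part of this patch: a bias-free nn.Linear applied to a [b, t, p] tensor
# matches jax.numpy.einsum("btp,pd->btd", ...) with the transposed weight, which is what the
# surrounding "adapted from" comments refer to. Sizes are hypothetical.
import torch
from torch import nn

proj = nn.Linear(16, 32, bias=False)  # p=16, d=32
x = torch.randn(2, 5, 16)             # [batch, seq, p]
out_linear = proj(x)
out_einsum = torch.einsum("btp,pd->btd", x, proj.weight.T)
print(torch.allclose(out_linear, out_einsum, atol=1e-6))  # True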
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i]) - current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype) - new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5 - current_hidden_state = current_hidden_state * ( - target_magnitude / torch.maximum(new_magnitude, epsilon_tensor) - ) + current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device) + new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) + new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device))) + current_hidden_state = current_hidden_state * target_magnitude / new_magnitude temp_hidden_states.append(current_hidden_state) hidden_states = torch.stack(temp_hidden_states) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 3b9cb2c5201..78349b8b906 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1642,7 +1642,6 @@ def set_model_tester_for_less_flaky_test(test_case): "AriaVisionText2TextModelTester", "GPTNeoModelTester", "DPTModelTester", - "Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester ] if test_case.model_tester.__class__.__name__ in exceptional_classes: target_num_hidden_layers = None diff --git a/tests/models/gemma3n/test_modeling_gemma3n.py b/tests/models/gemma3n/test_modeling_gemma3n.py index 2f546e19e49..060bf15ea1e 100644 --- a/tests/models/gemma3n/test_modeling_gemma3n.py +++ b/tests/models/gemma3n/test_modeling_gemma3n.py @@ -39,13 +39,20 @@ from transformers.testing_utils import ( require_read_token, require_torch, require_torch_gpu, + require_torch_sdpa, slow, torch_device, ) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _test_eager_matches_sdpa_inference, + floats_tensor, + ids_tensor, +) from ..gemma.test_modeling_gemma import GemmaModelTester @@ -256,6 +263,7 @@ class Gemma3nTextModelTester(GemmaModelTester): vocab_size=99, vocab_size_per_layer_input=99, hidden_size=16, + hidden_size_per_layer_input=16, num_hidden_layers=4, # override to correctly test sharing cache pattern num_kv_shared_layers=2, # important to override layer_types=[ @@ -291,6 +299,7 @@ class Gemma3nTextModelTester(GemmaModelTester): self.vocab_size = vocab_size self.vocab_size_per_layer_input = vocab_size_per_layer_input self.hidden_size = hidden_size + self.hidden_size_per_layer_input = hidden_size_per_layer_input self.num_hidden_layers = num_hidden_layers self.num_kv_shared_layers = num_kv_shared_layers self.layer_types = layer_types @@ -317,7 +326,6 @@ class Gemma3nTextModelTester(GemmaModelTester): for_causal_lm_class = Gemma3nForCausalLM -@unittest.skip("Skipped for now!") @require_torch class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else () @@ -365,6 +373,64 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes [expected_shape] * len(iter_hidden_states), ) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + @require_torch_sdpa + def test_eager_matches_sdpa_inference( + self, + name, + 
torch_dtype, + padding_side, + use_attention_mask, + output_attentions, + enable_kernels, + ): + "We need to relax a bit the `atols` for fp32 here due to the altup projections" + atols = { + ("cpu", False, torch.float32): 1e-3, # this was relaxed + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-3, # this was relaxed + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-3, # this was relaxed + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-3, # this was relaxed + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + _test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols + ) + + @pytest.mark.generate + @unittest.skip( + "Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding" + ) + def test_contrastive_generate(self): + pass + + @pytest.mark.generate + @unittest.skip( + "Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding" + ) + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @pytest.mark.generate + @unittest.skip( + "Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding" + ) + def test_contrastive_generate_low_memory(self): + pass + + @pytest.mark.generate + @unittest.skip( + "Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with dola decoding" + ) + def test_dola_decoding_sample(self): + pass + class Gemma3nVision2TextModelTester: text_config = {"activation_sparsity_pattern": None} diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d7a41a6c5d0..dd463932148 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -156,6 +156,334 @@ TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION = [ ] + [("fp32_pad_left_output_attentions", "fp32", "left", True, True, False)] +def _test_eager_matches_sdpa_inference( + self, + name, + torch_dtype, + padding_side, + use_attention_mask, + output_attentions, + enable_kernels, + atols=None, + rtols=None, +): + """ + This test is written as a regular function to be able to overload it easily with different tolerances. + Otherwise, `paramterezie.expand` prevents it as it removes the original function from the namespace. + """ + # TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like + # models have a custom mixin, which we detect to skip this test. 
+ if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__): + self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`") + + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + # convert shorthand name to torch.dtype + if torch_dtype == "fp16": + torch_dtype = torch.float16 + elif torch_dtype == "bf16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "fp32": + torch_dtype = torch.float32 + + if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16: + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16: + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Dictionary of tolerances for eager <> sdpa tests. Key = (device, sdpa_kernels_enabled, dtype) + if atols is None: + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + if rtols is None: + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, # (different from others) + ("cuda", True, torch.float16): 5e-3, + } + + set_model_tester_for_less_flaky_test(self) + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + set_config_for_less_flaky_test(config) + model = model_class(config) + # TODO: standardize the interfaces for musicgen models, see other todo in this test + if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration": + is_encoder_decoder = True + else: + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_from_pretrained_kwargs = { + "pretrained_model_name_or_path": tmpdirname, + "torch_dtype": torch_dtype, + } + + if hasattr(config, "use_mask_token") or "use_mask_token" in inspect.signature(model.__init__).parameters: + model_from_pretrained_kwargs["use_mask_token"] = True + + # TODO: remove this try/except, models should have a shared API + try: + model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="sdpa") + except ValueError: + model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs) + model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) + + model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager") + model_eager = 
model_eager.eval().to(torch_device, dtype=torch_dtype) + + set_model_for_less_flaky_test(model_eager) + set_model_for_less_flaky_test(model_sdpa) + + can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters + if not (self.has_attentions and can_output_attn) and output_attentions: + self.skipTest(reason="Model does not support output_attentions") + + # TODO: if we can also check with `batch_size=1` without being flaky? + for batch_size in [7]: + # musicgen decoder models; TODO: find better abstraction + if ( + model.__class__.__name__.startswith("Musicgen") + and hasattr(self.model_tester, "num_codebooks") + and not hasattr(model_eager, "text_encoder") + ): + input_data_batch_size = batch_size * self.model_tester.num_codebooks + else: + input_data_batch_size = batch_size + + processed_inputs = {} + processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name] + + for key in getattr(self, "additional_model_inputs", []): + # Some models don't have all `additional_model_inputs`, especially when we + # craft cases to test model in different settings + if key in inputs_dict: + processed_inputs[key] = inputs_dict[key] + + for key, value in processed_inputs.items(): + if torch.is_floating_point(value): + value = value.to(torch_dtype) + + # extend value to have at least `input_data_batch_size` elements + if value.shape[0] < input_data_batch_size: + size = (input_data_batch_size - value.shape[0], *value.shape[1:]) + if torch.is_floating_point(value): + extension = torch.rand(size=size, dtype=value.dtype, device=torch_device) + else: + extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device) + value = torch.cat((value, extension), dim=0).to(torch_device) + + processed_inputs[key] = value[:input_data_batch_size] + + if not use_attention_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + if is_encoder_decoder: + seqlen = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]).shape[ + -1 + ] + else: + seqlen = processed_inputs[model.main_input_name].shape[-1] + dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + + # extend dummy_attention_mask to have at least `batch_size` elements + if dummy_attention_mask.shape[0] < batch_size: + size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:]) + extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + + dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :2] = 0 + dummy_attention_mask[-1, 2:] = 1 + elif padding_side == "right": + dummy_attention_mask[-1, -2:] = 0 + dummy_attention_mask[-1, :-2] = 1 + + if is_encoder_decoder: + # musicgen encoder-decoder models; TODO: find better abstraction + if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"): + input_data_batch_size = batch_size * self.model_tester.num_codebooks + else: + input_data_batch_size = batch_size + + decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]) + decoder_input_ids = decoder_input_ids[:input_data_batch_size] + if decoder_input_ids.shape[0] != input_data_batch_size: + extension = torch.ones( + input_data_batch_size - 
decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) + + # TODO: never an `attention_mask` arg here? + processed_inputs.update( + { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + ) + else: + processed_inputs.update( + { + "output_hidden_states": True, + } + ) + + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + processed_inputs["attention_mask"] = dummy_attention_mask + + if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters: + processed_inputs["output_attentions"] = output_attentions + if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters: + dummy_mask = torch.ones((self.model_tester.num_masks,)) + + # In case of additional token (like class) we define a custom `mask_length` + if hasattr(self.model_tester, "mask_length"): + mask_length = self.model_tester.mask_length - dummy_mask.size(0) + else: + mask_length = self.model_tester.seq_length - dummy_mask.size(0) + dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) + + if "noise" in inspect.signature(model_eager.forward).parameters: + np.random.seed(2) + num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2) + noise = np.random.uniform(size=(batch_size, num_patches)) + processed_inputs["noise"] = torch.from_numpy(noise) + + # TODO: test gradients as well (& for FA2 as well!) 
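# Illustrative sketch, not part of this patch, using the torch.nn.attention API (PyTorch >= 2.3),
# which may differ from the `sdpa_kernel` helper used in this file: restricting SDPA to the math
# backend is the effect of `enable_kernels=False` in the block below, so the eager and SDPA paths
# can be compared under comparable numerics.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 2, 4, 8)  # [batch, heads, seq, head_dim], hypothetical sizes
with sdpa_kernel([SDPBackend.MATH]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)                     # torch.Size([1, 2, 4, 8])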
+ with torch.no_grad(): + with sdpa_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + prepared_inputs = { + k: v.to(torch_device) if isinstance(v, torch.Tensor) else v for k, v in prepared_inputs.items() + } + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + + if "logits_per_text" in outputs_eager: + key = "logits_per_text" + elif "vision_hidden_states" in outputs_eager: + key = "vision_hidden_states" + elif "audio_values" in outputs_eager: + key = "audio_values" + elif "decoder_hidden_states" in outputs_eager: + key = "decoder_hidden_states" + elif "logits" in outputs_eager and "Classification" in model_class.__name__: + key = "logits" + elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower(): + outputs_eager = outputs_eager["language_model_outputs"] + outputs_sdpa = outputs_sdpa["language_model_outputs"] + key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states" + else: + key = "hidden_states" + + # TODO: rename logits -> hidden_states + logits_eager = outputs_eager[key] + logits_sdpa = outputs_sdpa[key] + + if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]: + logits_eager = logits_eager[-1] + logits_sdpa = logits_sdpa[-1] + + if key == "logits_per_text": + nan_mask = torch.isnan(logits_eager) + logits_eager[nan_mask] = 0 + logits_sdpa[nan_mask] = 0 + + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + elif torch_device == "hpu": + atol = atols["cuda", enable_kernels, torch_dtype] + rtol = rtols["cuda", enable_kernels, torch_dtype] + elif torch_device == "xpu": + # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH + # which is implemented on PyTorch level using aten operators and is + # device agnostic with respect to implementation of each aten operator. + atol = atols["cuda", False, torch_dtype] + rtol = rtols["cuda", False, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. 
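# Illustrative sketch, not part of this patch: the acceptance criterion applied further below.
# Each batch element is compared with torch.allclose, and the test only fails when fewer than
# 80% of them match within the chosen tolerances. The tensors here are hypothetical.
import numpy as np
import torch

logits_sdpa = torch.randn(7, 5, 16)  # [batch, seq, hidden]
logits_eager = logits_sdpa + 1e-8 * torch.randn_like(logits_sdpa)
results = [
    torch.allclose(s, e, atol=1e-6, rtol=1e-4)
    for s, e in zip(logits_sdpa, logits_eager)
]
print(np.mean(results) >= 0.8)       # True for this tiny perturbation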
+ if use_attention_mask: + _logits_sdpa = torch.zeros_like(input=logits_sdpa) + _logits_eager = torch.zeros_like(input=logits_eager) + + _logits_sdpa[:-1] = logits_sdpa[:-1] + _logits_eager[:-1] = logits_eager[:-1] + + if padding_side == "left": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] + _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] + + elif padding_side == "right": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] + _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] + + logits_sdpa = _logits_sdpa + logits_eager = _logits_eager + + results = [ + torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) + for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) + ] + # If 80% batch elements have matched results, it's fine + if np.mean(results) < 0.8: + mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean() + raise ValueError( + f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = " + f"{rtol}" + ) + + def _config_zero_init(config): configs_no_init = copy.deepcopy(config) for key in configs_no_init.__dict__.keys(): @@ -3405,321 +3733,9 @@ class ModelTesterMixin: def test_eager_matches_sdpa_inference( self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels ): - # TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like - # models have a custom mixin, which we detect to skip this test. - if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__): - self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`") - - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - if not self.all_model_classes[0]._supports_sdpa: - self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") - - # convert shorthand name to torch.dtype - if torch_dtype == "fp16": - torch_dtype = torch.float16 - elif torch_dtype == "bf16": - torch_dtype = torch.bfloat16 - elif torch_dtype == "fp32": - torch_dtype = torch.float32 - - if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16: - self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") - - if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16: - self.skipTest( - f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" - ) - - # Dictionary of tolerances for eager <> sdpa tests. 
Key = (device, sdpa_kernels_enabled, dtype) - atols = { - ("cpu", False, torch.float32): 1e-6, - ("cpu", False, torch.float16): 5e-3, - ("cpu", False, torch.bfloat16): 1e-2, - ("cpu", True, torch.float32): 1e-6, - ("cpu", True, torch.float16): 5e-3, - ("cpu", True, torch.bfloat16): 1e-2, - ("cuda", False, torch.float32): 1e-6, - ("cuda", False, torch.bfloat16): 1e-2, - ("cuda", False, torch.float16): 5e-3, - ("cuda", True, torch.float32): 1e-6, - ("cuda", True, torch.bfloat16): 1e-2, - ("cuda", True, torch.float16): 5e-3, - } - rtols = { - ("cpu", False, torch.float32): 1e-4, - ("cpu", False, torch.float16): 5e-3, - ("cpu", False, torch.bfloat16): 1e-2, - ("cpu", True, torch.float32): 1e-4, - ("cpu", True, torch.float16): 5e-3, - ("cpu", True, torch.bfloat16): 1e-2, - ("cuda", False, torch.float32): 1e-4, - ("cuda", False, torch.bfloat16): 1e-2, - ("cuda", False, torch.float16): 5e-3, - ("cuda", True, torch.float32): 1e-4, - ("cuda", True, torch.bfloat16): 3e-2, # (different from others) - ("cuda", True, torch.float16): 5e-3, - } - - set_model_tester_for_less_flaky_test(self) - - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - set_config_for_less_flaky_test(config) - model = model_class(config) - # TODO: standardize the interfaces for musicgen models, see other todo in this test - if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration": - is_encoder_decoder = True - else: - is_encoder_decoder = model.config.is_encoder_decoder - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model_from_pretrained_kwargs = { - "pretrained_model_name_or_path": tmpdirname, - "torch_dtype": torch_dtype, - } - - if ( - hasattr(config, "use_mask_token") - or "use_mask_token" in inspect.signature(model.__init__).parameters - ): - model_from_pretrained_kwargs["use_mask_token"] = True - - # TODO: remove this try/except, models should have a shared API - try: - model_sdpa = model_class.from_pretrained( - **model_from_pretrained_kwargs, attn_implementation="sdpa" - ) - except ValueError: - model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs) - model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) - - model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager") - model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype) - - set_model_for_less_flaky_test(model_eager) - set_model_for_less_flaky_test(model_sdpa) - - can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters - if not (self.has_attentions and can_output_attn) and output_attentions: - self.skipTest(reason="Model does not support output_attentions") - - # TODO: if we can also check with `batch_size=1` without being flaky? 
- for batch_size in [7]: - # musicgen decoder models; TODO: find better abstraction - if ( - model.__class__.__name__.startswith("Musicgen") - and hasattr(self.model_tester, "num_codebooks") - and not hasattr(model_eager, "text_encoder") - ): - input_data_batch_size = batch_size * self.model_tester.num_codebooks - else: - input_data_batch_size = batch_size - - processed_inputs = {} - processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name] - - for key in getattr(self, "additional_model_inputs", []): - # Some models don't have all `additional_model_inputs`, especially when we - # craft cases to test model in different settings - if key in inputs_dict: - processed_inputs[key] = inputs_dict[key] - - for key, value in processed_inputs.items(): - if torch.is_floating_point(value): - value = value.to(torch_dtype) - - # extend value to have at least `input_data_batch_size` elements - if value.shape[0] < input_data_batch_size: - size = (input_data_batch_size - value.shape[0], *value.shape[1:]) - if torch.is_floating_point(value): - extension = torch.rand(size=size, dtype=value.dtype, device=torch_device) - else: - extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device) - value = torch.cat((value, extension), dim=0).to(torch_device) - - processed_inputs[key] = value[:input_data_batch_size] - - if not use_attention_mask: - dummy_attention_mask = None - else: - dummy_attention_mask = inputs_dict.get("attention_mask", None) - if dummy_attention_mask is None: - if is_encoder_decoder: - seqlen = inputs_dict.get( - "decoder_input_ids", processed_inputs[model.main_input_name] - ).shape[-1] - else: - seqlen = processed_inputs[model.main_input_name].shape[-1] - dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) - - # extend dummy_attention_mask to have at least `batch_size` elements - if dummy_attention_mask.shape[0] < batch_size: - size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:]) - extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device) - dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) - - dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device) - - dummy_attention_mask[:] = 1 - if padding_side == "left": - dummy_attention_mask[-1, :2] = 0 - dummy_attention_mask[-1, 2:] = 1 - elif padding_side == "right": - dummy_attention_mask[-1, -2:] = 0 - dummy_attention_mask[-1, :-2] = 1 - - if is_encoder_decoder: - # musicgen encoder-decoder models; TODO: find better abstraction - if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"): - input_data_batch_size = batch_size * self.model_tester.num_codebooks - else: - input_data_batch_size = batch_size - - decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]) - decoder_input_ids = decoder_input_ids[:input_data_batch_size] - if decoder_input_ids.shape[0] != input_data_batch_size: - extension = torch.ones( - input_data_batch_size - decoder_input_ids.shape[0], - *decoder_input_ids.shape[1:], - dtype=decoder_input_ids.dtype, - device=torch_device, - ) - decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) - decoder_input_ids = decoder_input_ids.to(torch_device) - - # TODO: never an `attention_mask` arg here? 
- processed_inputs.update( - { - "decoder_input_ids": decoder_input_ids, - "decoder_attention_mask": dummy_attention_mask, - "output_hidden_states": True, - } - ) - else: - processed_inputs.update( - { - "output_hidden_states": True, - } - ) - - # Otherwise fails for e.g. WhisperEncoderModel - if "attention_mask" in inspect.signature(model_eager.forward).parameters: - processed_inputs["attention_mask"] = dummy_attention_mask - - if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters: - processed_inputs["output_attentions"] = output_attentions - if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters: - dummy_mask = torch.ones((self.model_tester.num_masks,)) - - # In case of additional token (like class) we define a custom `mask_length` - if hasattr(self.model_tester, "mask_length"): - mask_length = self.model_tester.mask_length - dummy_mask.size(0) - else: - mask_length = self.model_tester.seq_length - dummy_mask.size(0) - dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) - dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() - processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) - - if "noise" in inspect.signature(model_eager.forward).parameters: - np.random.seed(2) - num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2) - noise = np.random.uniform(size=(batch_size, num_patches)) - processed_inputs["noise"] = torch.from_numpy(noise) - - # TODO: test gradients as well (& for FA2 as well!) - with torch.no_grad(): - with sdpa_kernel( - enable_flash=enable_kernels, - enable_math=True, - enable_mem_efficient=enable_kernels, - ): - prepared_inputs = self._prepare_for_class(processed_inputs, model_class) - prepared_inputs = { - k: v.to(torch_device) if isinstance(v, torch.Tensor) else v - for k, v in prepared_inputs.items() - } - outputs_eager = model_eager(**prepared_inputs) - outputs_sdpa = model_sdpa(**prepared_inputs) - - if "logits_per_text" in outputs_eager: - key = "logits_per_text" - elif "vision_hidden_states" in outputs_eager: - key = "vision_hidden_states" - elif "audio_values" in outputs_eager: - key = "audio_values" - elif "decoder_hidden_states" in outputs_eager: - key = "decoder_hidden_states" - elif "logits" in outputs_eager and "Classification" in model_class.__name__: - key = "logits" - elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower(): - outputs_eager = outputs_eager["language_model_outputs"] - outputs_sdpa = outputs_sdpa["language_model_outputs"] - key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states" - else: - key = "hidden_states" - - # TODO: rename logits -> hidden_states - logits_eager = outputs_eager[key] - logits_sdpa = outputs_sdpa[key] - - if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]: - logits_eager = logits_eager[-1] - logits_sdpa = logits_sdpa[-1] - - if key == "logits_per_text": - nan_mask = torch.isnan(logits_eager) - logits_eager[nan_mask] = 0 - logits_sdpa[nan_mask] = 0 - - if torch_device in ["cpu", "cuda"]: - atol = atols[torch_device, enable_kernels, torch_dtype] - rtol = rtols[torch_device, enable_kernels, torch_dtype] - elif torch_device == "hpu": - atol = atols["cuda", enable_kernels, torch_dtype] - rtol = rtols["cuda", enable_kernels, torch_dtype] - elif torch_device == "xpu": - # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH - # which is implemented on PyTorch level using aten 
operators and is - # device agnostic with respect to implementation of each aten operator. - atol = atols["cuda", False, torch_dtype] - rtol = rtols["cuda", False, torch_dtype] - else: - atol = 1e-7 - rtol = 1e-4 - - # Masked tokens output slightly deviates - we don't mind that. - if use_attention_mask: - _logits_sdpa = torch.zeros_like(input=logits_sdpa) - _logits_eager = torch.zeros_like(input=logits_eager) - - _logits_sdpa[:-1] = logits_sdpa[:-1] - _logits_eager[:-1] = logits_eager[:-1] - - if padding_side == "left": - _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] - _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] - - elif padding_side == "right": - _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] - _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] - - logits_sdpa = _logits_sdpa - logits_eager = _logits_eager - - results = [ - torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) - for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) - ] - # If 80% batch elements have matched results, it's fine - if np.mean(results) < 0.8: - mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean() - raise ValueError( - f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = " - f"{rtol}" - ) + _test_eager_matches_sdpa_inference( + self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ) @require_torch_sdpa @require_torch_accelerator
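Several of the changes above (the `_no_split_modules` rename to `Gemma3nTextDecoderLayer`, the explicit `.to(...)` device moves in the shared-KV and AltUp paths, and the new `forward` on `Gemma3nTextAltUp`) exist so that accelerate-style offloading keeps each decoder layer and its parameters on a single device. A minimal sketch of how that code path is exercised, assuming a hypothetical checkpoint id and that `accelerate` is installed:

# Sketch only: the checkpoint id is a placeholder, and device_map="auto" requires accelerate.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3n-checkpoint"  # hypothetical id, replace with a real Gemma3n checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # places whole Gemma3nTextDecoderLayer blocks, never splitting one across devices
)
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0]))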