Several fixes for Gemma3n (#39135)

* remove the skips

* fix the epsilon to a small value (does not make sense otherwise)

* safeguard

* overload test_eager_matches_sdpa

* Update test_modeling_common.py

* skip appropriate tests

* correct no_split_layer

* fix all devices issue

* fix backward

* fix
This commit is contained in:
Cyril Vallez 2025-07-01 10:34:53 +02:00 committed by GitHub
parent d53518c5f2
commit dbc98328da
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 491 additions and 390 deletions

View File

@ -1135,9 +1135,17 @@ class Gemma3nTextAltUp(nn.Module):
corrected += predictions # add the original input
return corrected.contiguous().type_as(activated)
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
"""
This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
`scale_corrected_output`
"""
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
return self.forward(corrected)
class Gemma3nTextRotaryEmbedding(nn.Module):
@ -1290,7 +1298,7 @@ class Gemma3nTextAttention(nn.Module):
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
layer_type = config.layer_types[layer_idx]
self.kv_shared_layer_index = (
@ -1319,21 +1327,22 @@ class Gemma3nTextAttention(nn.Module):
query_states = query_states.transpose(1, 2)
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
# HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache.
# Device of past layer may be different from current one
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
if isinstance(past_key_value, HybridCache) and self.is_sliding:
max_length = past_key_value.sliding_window
if cache_position.shape[0] > max_length:
# If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
# slice into the entire cache.
indices = slice(0, max_length)
else:
# If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
indices = cache_position.clamp(min=0, max=max_length - 1)
else:
indices = cache_position
indices = (
slice(0, max_length)
if cache_position.shape[0] > max_length
else cache_position.clamp(min=0, max=max_length - 1)
)
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
# Device of past layer may be different from current one
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
query_states.device
)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states)
@ -1447,10 +1456,9 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
first_prediction = corrected_predictions[self.config.altup_active_idx]
first_prediction_clone = first_prediction.clone()
first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
if self.config.altup_correct_scale:
first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
first_prediction = self.altup.scale_corrected_output(first_prediction)
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
first_prediction = self.per_layer_input_gate(first_prediction)
@ -1475,7 +1483,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
config_class = Gemma3nConfig
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma3nDecoderLayer"]
_no_split_modules = ["Gemma3nTextDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_3 = True
_supports_flash_attn_2 = True
@ -1656,18 +1664,17 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
# Expand hidden_states to support per-layer inputs
target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(torch.finfo().min)
target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(1e-5)
temp_hidden_states = [hidden_states_0]
for i in range(1, self.config.altup_num_inputs):
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.type(hidden_states_0.dtype)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
current_hidden_state = current_hidden_state * (
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
)
altup_proj = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
@ -1685,9 +1692,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
per_layer_input=per_layer_input,
position_embeddings_global,
position_embeddings_local,
per_layer_input,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
@ -1712,11 +1719,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
for i in range(1, self.config.altup_num_inputs):
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
current_hidden_state = current_hidden_state * (
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
)
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states)
@ -1743,7 +1749,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
per_layer_projection *= self.per_layer_projection_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
@ -1758,7 +1766,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")

View File

@ -1685,9 +1685,17 @@ class Gemma3nTextAltUp(nn.Module):
corrected += predictions # add the original input
return corrected.contiguous().type_as(activated)
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
"""
This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
`scale_corrected_output`
"""
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
return self.forward(corrected)
class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding):
@ -1732,7 +1740,7 @@ class Gemma3nTextAttention(Gemma3Attention):
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
layer_type = config.layer_types[layer_idx]
self.kv_shared_layer_index = (
@ -1761,21 +1769,22 @@ class Gemma3nTextAttention(Gemma3Attention):
query_states = query_states.transpose(1, 2)
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
# HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache.
# Device of past layer may be different from current one
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
if isinstance(past_key_value, HybridCache) and self.is_sliding:
max_length = past_key_value.sliding_window
if cache_position.shape[0] > max_length:
# If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
# slice into the entire cache.
indices = slice(0, max_length)
else:
# If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
indices = cache_position.clamp(min=0, max=max_length - 1)
else:
indices = cache_position
indices = (
slice(0, max_length)
if cache_position.shape[0] > max_length
else cache_position.clamp(min=0, max=max_length - 1)
)
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
# Device of past layer may be different from current one
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
query_states.device
)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
key_states = self.k_norm(key_states)
@ -1880,10 +1889,9 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
first_prediction = corrected_predictions[self.config.altup_active_idx]
first_prediction_clone = first_prediction.clone()
first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
if self.config.altup_correct_scale:
first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
first_prediction = self.altup.scale_corrected_output(first_prediction)
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
first_prediction = self.per_layer_input_gate(first_prediction)
@ -1906,7 +1914,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
config_class = Gemma3nConfig
base_model_prefix = ""
_no_split_modules = ["Gemma3nDecoderLayer"]
_no_split_modules = ["Gemma3nTextDecoderLayer"]
def _init_weights(self, module):
# important: this ported version of Gemma2 isn't meant for training from scratch - only
@ -1995,7 +2003,9 @@ class Gemma3nTextModel(Gemma3TextModel):
per_layer_inputs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
per_layer_projection *= self.per_layer_projection_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
@ -2010,7 +2020,9 @@ class Gemma3nTextModel(Gemma3TextModel):
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
dtype=inputs_embeds.dtype, device=per_layer_projection.device
)
@can_return_tuple
@auto_docstring
@ -2091,18 +2103,17 @@ class Gemma3nTextModel(Gemma3TextModel):
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
# Expand hidden_states to support per-layer inputs
target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(torch.finfo().min)
target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(1e-5)
temp_hidden_states = [hidden_states_0]
for i in range(1, self.config.altup_num_inputs):
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.type(hidden_states_0.dtype)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
current_hidden_state = current_hidden_state * (
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
)
altup_proj = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
@ -2120,9 +2131,9 @@ class Gemma3nTextModel(Gemma3TextModel):
layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
per_layer_input=per_layer_input,
position_embeddings_global,
position_embeddings_local,
per_layer_input,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
@ -2147,11 +2158,10 @@ class Gemma3nTextModel(Gemma3TextModel):
for i in range(1, self.config.altup_num_inputs):
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
current_hidden_state = current_hidden_state * (
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
)
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
temp_hidden_states.append(current_hidden_state)
hidden_states = torch.stack(temp_hidden_states)

View File

@ -1642,7 +1642,6 @@ def set_model_tester_for_less_flaky_test(test_case):
"AriaVisionText2TextModelTester",
"GPTNeoModelTester",
"DPTModelTester",
"Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester
]
if test_case.model_tester.__class__.__name__ in exceptional_classes:
target_num_hidden_layers = None

View File

@ -39,13 +39,20 @@ from transformers.testing_utils import (
require_read_token,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
_test_eager_matches_sdpa_inference,
floats_tensor,
ids_tensor,
)
from ..gemma.test_modeling_gemma import GemmaModelTester
@ -256,6 +263,7 @@ class Gemma3nTextModelTester(GemmaModelTester):
vocab_size=99,
vocab_size_per_layer_input=99,
hidden_size=16,
hidden_size_per_layer_input=16,
num_hidden_layers=4, # override to correctly test sharing cache pattern
num_kv_shared_layers=2, # important to override
layer_types=[
@ -291,6 +299,7 @@ class Gemma3nTextModelTester(GemmaModelTester):
self.vocab_size = vocab_size
self.vocab_size_per_layer_input = vocab_size_per_layer_input
self.hidden_size = hidden_size
self.hidden_size_per_layer_input = hidden_size_per_layer_input
self.num_hidden_layers = num_hidden_layers
self.num_kv_shared_layers = num_kv_shared_layers
self.layer_types = layer_types
@ -317,7 +326,6 @@ class Gemma3nTextModelTester(GemmaModelTester):
for_causal_lm_class = Gemma3nForCausalLM
@unittest.skip("Skipped for now!")
@require_torch
class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else ()
@ -365,6 +373,64 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
[expected_shape] * len(iter_hidden_states),
)
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
def test_eager_matches_sdpa_inference(
self,
name,
torch_dtype,
padding_side,
use_attention_mask,
output_attentions,
enable_kernels,
):
"We need to relax a bit the `atols` for fp32 here due to the altup projections"
atols = {
("cpu", False, torch.float32): 1e-3, # this was relaxed
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-3, # this was relaxed
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-3, # this was relaxed
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-3, # this was relaxed
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
_test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols
)
@pytest.mark.generate
@unittest.skip(
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
)
def test_contrastive_generate(self):
pass
@pytest.mark.generate
@unittest.skip(
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
)
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@pytest.mark.generate
@unittest.skip(
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
)
def test_contrastive_generate_low_memory(self):
pass
@pytest.mark.generate
@unittest.skip(
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with dola decoding"
)
def test_dola_decoding_sample(self):
pass
class Gemma3nVision2TextModelTester:
text_config = {"activation_sparsity_pattern": None}

View File

@ -156,6 +156,334 @@ TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION = [
] + [("fp32_pad_left_output_attentions", "fp32", "left", True, True, False)]
def _test_eager_matches_sdpa_inference(
self,
name,
torch_dtype,
padding_side,
use_attention_mask,
output_attentions,
enable_kernels,
atols=None,
rtols=None,
):
"""
This test is written as a regular function to be able to overload it easily with different tolerances.
Otherwise, `paramterezie.expand` prevents it as it removes the original function from the namespace.
"""
# TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like
# models have a custom mixin, which we detect to skip this test.
if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__):
self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`")
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
# convert shorthand name to torch.dtype
if torch_dtype == "fp16":
torch_dtype = torch.float16
elif torch_dtype == "bf16":
torch_dtype = torch.bfloat16
elif torch_dtype == "fp32":
torch_dtype = torch.float32
if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16:
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16:
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Dictionary of tolerances for eager <> sdpa tests. Key = (device, sdpa_kernels_enabled, dtype)
if atols is None:
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
if rtols is None:
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2, # (different from others)
("cuda", True, torch.float16): 5e-3,
}
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
set_config_for_less_flaky_test(config)
model = model_class(config)
# TODO: standardize the interfaces for musicgen models, see other todo in this test
if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration":
is_encoder_decoder = True
else:
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_from_pretrained_kwargs = {
"pretrained_model_name_or_path": tmpdirname,
"torch_dtype": torch_dtype,
}
if hasattr(config, "use_mask_token") or "use_mask_token" in inspect.signature(model.__init__).parameters:
model_from_pretrained_kwargs["use_mask_token"] = True
# TODO: remove this try/except, models should have a shared API
try:
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="sdpa")
except ValueError:
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
set_model_for_less_flaky_test(model_eager)
set_model_for_less_flaky_test(model_sdpa)
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
self.skipTest(reason="Model does not support output_attentions")
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
# musicgen decoder models; TODO: find better abstraction
if (
model.__class__.__name__.startswith("Musicgen")
and hasattr(self.model_tester, "num_codebooks")
and not hasattr(model_eager, "text_encoder")
):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
processed_inputs = {}
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
for key in getattr(self, "additional_model_inputs", []):
# Some models don't have all `additional_model_inputs`, especially when we
# craft cases to test model in different settings
if key in inputs_dict:
processed_inputs[key] = inputs_dict[key]
for key, value in processed_inputs.items():
if torch.is_floating_point(value):
value = value.to(torch_dtype)
# extend value to have at least `input_data_batch_size` elements
if value.shape[0] < input_data_batch_size:
size = (input_data_batch_size - value.shape[0], *value.shape[1:])
if torch.is_floating_point(value):
extension = torch.rand(size=size, dtype=value.dtype, device=torch_device)
else:
extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device)
value = torch.cat((value, extension), dim=0).to(torch_device)
processed_inputs[key] = value[:input_data_batch_size]
if not use_attention_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]).shape[
-1
]
else:
seqlen = processed_inputs[model.main_input_name].shape[-1]
dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
# extend dummy_attention_mask to have at least `batch_size` elements
if dummy_attention_mask.shape[0] < batch_size:
size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:])
extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
if is_encoder_decoder:
# musicgen encoder-decoder models; TODO: find better abstraction
if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name])
decoder_input_ids = decoder_input_ids[:input_data_batch_size]
if decoder_input_ids.shape[0] != input_data_batch_size:
extension = torch.ones(
input_data_batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
processed_inputs.update(
{
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
)
else:
processed_inputs.update(
{
"output_hidden_states": True,
}
)
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask
if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters:
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
prepared_inputs = {
k: v.to(torch_device) if isinstance(v, torch.Tensor) else v for k, v in prepared_inputs.items()
}
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
if "logits_per_text" in outputs_eager:
key = "logits_per_text"
elif "vision_hidden_states" in outputs_eager:
key = "vision_hidden_states"
elif "audio_values" in outputs_eager:
key = "audio_values"
elif "decoder_hidden_states" in outputs_eager:
key = "decoder_hidden_states"
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
key = "logits"
elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower():
outputs_eager = outputs_eager["language_model_outputs"]
outputs_sdpa = outputs_sdpa["language_model_outputs"]
key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states"
else:
key = "hidden_states"
# TODO: rename logits -> hidden_states
logits_eager = outputs_eager[key]
logits_sdpa = outputs_sdpa[key]
if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]:
logits_eager = logits_eager[-1]
logits_sdpa = logits_sdpa[-1]
if key == "logits_per_text":
nan_mask = torch.isnan(logits_eager)
logits_eager[nan_mask] = 0
logits_sdpa[nan_mask] = 0
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "hpu":
atol = atols["cuda", enable_kernels, torch_dtype]
rtol = rtols["cuda", enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_attention_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
raise ValueError(
f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = "
f"{rtol}"
)
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
@ -3405,321 +3733,9 @@ class ModelTesterMixin:
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
# TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like
# models have a custom mixin, which we detect to skip this test.
if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__):
self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`")
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
# convert shorthand name to torch.dtype
if torch_dtype == "fp16":
torch_dtype = torch.float16
elif torch_dtype == "bf16":
torch_dtype = torch.bfloat16
elif torch_dtype == "fp32":
torch_dtype = torch.float32
if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16:
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16:
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Dictionary of tolerances for eager <> sdpa tests. Key = (device, sdpa_kernels_enabled, dtype)
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2, # (different from others)
("cuda", True, torch.float16): 5e-3,
}
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
set_config_for_less_flaky_test(config)
model = model_class(config)
# TODO: standardize the interfaces for musicgen models, see other todo in this test
if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration":
is_encoder_decoder = True
else:
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_from_pretrained_kwargs = {
"pretrained_model_name_or_path": tmpdirname,
"torch_dtype": torch_dtype,
}
if (
hasattr(config, "use_mask_token")
or "use_mask_token" in inspect.signature(model.__init__).parameters
):
model_from_pretrained_kwargs["use_mask_token"] = True
# TODO: remove this try/except, models should have a shared API
try:
model_sdpa = model_class.from_pretrained(
**model_from_pretrained_kwargs, attn_implementation="sdpa"
)
except ValueError:
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
set_model_for_less_flaky_test(model_eager)
set_model_for_less_flaky_test(model_sdpa)
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
self.skipTest(reason="Model does not support output_attentions")
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
# musicgen decoder models; TODO: find better abstraction
if (
model.__class__.__name__.startswith("Musicgen")
and hasattr(self.model_tester, "num_codebooks")
and not hasattr(model_eager, "text_encoder")
):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
processed_inputs = {}
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
for key in getattr(self, "additional_model_inputs", []):
# Some models don't have all `additional_model_inputs`, especially when we
# craft cases to test model in different settings
if key in inputs_dict:
processed_inputs[key] = inputs_dict[key]
for key, value in processed_inputs.items():
if torch.is_floating_point(value):
value = value.to(torch_dtype)
# extend value to have at least `input_data_batch_size` elements
if value.shape[0] < input_data_batch_size:
size = (input_data_batch_size - value.shape[0], *value.shape[1:])
if torch.is_floating_point(value):
extension = torch.rand(size=size, dtype=value.dtype, device=torch_device)
else:
extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device)
value = torch.cat((value, extension), dim=0).to(torch_device)
processed_inputs[key] = value[:input_data_batch_size]
if not use_attention_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get(
"decoder_input_ids", processed_inputs[model.main_input_name]
).shape[-1]
else:
seqlen = processed_inputs[model.main_input_name].shape[-1]
dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
# extend dummy_attention_mask to have at least `batch_size` elements
if dummy_attention_mask.shape[0] < batch_size:
size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:])
extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
if is_encoder_decoder:
# musicgen encoder-decoder models; TODO: find better abstraction
if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name])
decoder_input_ids = decoder_input_ids[:input_data_batch_size]
if decoder_input_ids.shape[0] != input_data_batch_size:
extension = torch.ones(
input_data_batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
processed_inputs.update(
{
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
)
else:
processed_inputs.update(
{
"output_hidden_states": True,
}
)
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask
if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters:
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
prepared_inputs = {
k: v.to(torch_device) if isinstance(v, torch.Tensor) else v
for k, v in prepared_inputs.items()
}
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
if "logits_per_text" in outputs_eager:
key = "logits_per_text"
elif "vision_hidden_states" in outputs_eager:
key = "vision_hidden_states"
elif "audio_values" in outputs_eager:
key = "audio_values"
elif "decoder_hidden_states" in outputs_eager:
key = "decoder_hidden_states"
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
key = "logits"
elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower():
outputs_eager = outputs_eager["language_model_outputs"]
outputs_sdpa = outputs_sdpa["language_model_outputs"]
key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states"
else:
key = "hidden_states"
# TODO: rename logits -> hidden_states
logits_eager = outputs_eager[key]
logits_sdpa = outputs_sdpa[key]
if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]:
logits_eager = logits_eager[-1]
logits_sdpa = logits_sdpa[-1]
if key == "logits_per_text":
nan_mask = torch.isnan(logits_eager)
logits_eager[nan_mask] = 0
logits_sdpa[nan_mask] = 0
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "hpu":
atol = atols["cuda", enable_kernels, torch_dtype]
rtol = rtols["cuda", enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_attention_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
raise ValueError(
f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = "
f"{rtol}"
)
_test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
)
@require_torch_sdpa
@require_torch_accelerator