[tests] Parameterized test_eager_matches_sdpa_inference (#36650)

This commit is contained in:
Joao Gante 2025-03-14 14:41:27 +00:00 committed by GitHub
parent 9215cc62d4
commit 42ebb6c23e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 285 additions and 1900 deletions

View File

@ -17,11 +17,10 @@
import unittest
from packaging import version
from parameterized import parameterized
from transformers import AlbertConfig, AutoTokenizer, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_sdpa, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@ -289,12 +288,6 @@ class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
self.model_tester = AlbertModelTester(self)
self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@unittest.skip("Albert requires `head_mask` which is currently not done in this test.")
def test_eager_matches_sdpa_inference(self):
pass
def test_config(self):
self.config_tester.run_common_tests()

View File

@ -256,11 +256,6 @@ class AyaVisionModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_inference(self):
pass
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_generate(self):
pass

View File

@ -14,20 +14,15 @@
# limitations under the License.
"""Testing suite for the PyTorch BEiT model."""
import inspect
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from packaging import version
from parameterized import parameterized
from transformers import BeitConfig
from transformers.testing_utils import (
require_torch,
require_torch_multi_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -35,14 +30,12 @@ from transformers.testing_utils import (
from transformers.utils import (
cached_property,
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_vision_available,
)
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@ -119,6 +112,7 @@ class BeitModelTester:
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.mask_length = self.seq_length - 1
self.num_masks = int(mask_ratio * self.seq_length)
self.attn_implementation = attn_implementation
@ -414,193 +408,6 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = BeitModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
# The common test modifies the num_hidden_layers to be 1. However, for Beit we want to
# avoid that because the num_hidden_layers is generally assumed to be 4. Also, the code
# related to attention masks in the original common tests is not required as the Beit
# model does not handle attention masks. Furthermore, some extra code like modifying
# the norm layers eps values for specialized configs and checking for the 'noise'
# has been omitted to simply the test.
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.rms_norm_eps = 1.0
config.layer_norm_eps = 1.0
config.norm_eps = 1.0
config.norm_epsilon = 1.0
config.layer_norm_epsilon = 1.0
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype, use_mask_token=True)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
use_mask_token=True,
)
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
# Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
for x in model_eager.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
for x in model_sdpa.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for output_attentions in [True, False]:
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
mask_length = self.model_tester.seq_length - 1 - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
logits_eager = outputs_eager.hidden_states[-1]
logits_sdpa = outputs_sdpa.hidden_states[-1]
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
# We will verify our results on an image of cute cats
def prepare_img():

View File

@ -76,11 +76,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_inference(self):
pass
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_generate(self):
pass

View File

@ -20,7 +20,6 @@ from typing import ClassVar
import torch
from datasets import load_dataset
from parameterized import parameterized
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -32,7 +31,6 @@ from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, Co
from transformers.models.colpali.processing_colpali import ColPaliProcessor
from transformers.testing_utils import (
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -271,14 +269,6 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(
"Due to custom causal mask, there is a slightly too big difference between eager and sdpa in bfloat16."
)
@unittest.skip(
reason="From PaliGemma: Some undefined behavior encountered with test versions of this model. Skip for now."
)

View File

@ -14,18 +14,12 @@
# limitations under the License.
"""Testing suite for the PyTorch Data2VecVision model."""
import inspect
import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from transformers import Data2VecVisionConfig
from transformers.testing_utils import (
require_torch,
require_torch_multi_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -33,13 +27,11 @@ from transformers.testing_utils import (
from transformers.utils import (
cached_property,
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
@ -111,6 +103,7 @@ class Data2VecVisionModelTester:
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.mask_length = self.seq_length - 1
self.num_masks = int(mask_ratio * self.seq_length)
self.attn_implementation = attn_implementation
@ -319,194 +312,6 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
model = Data2VecVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
# Copied from tests.models.beit.test_modeling_beit.BeitModelTest.test_eager_matches_sdpa_inference with Beit->Data2VecVision
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
# The common test modifies the num_hidden_layers to be 1. However, for Data2VecVision we want to
# avoid that because the num_hidden_layers is generally assumed to be 4. Also, the code
# related to attention masks in the original common tests is not required as the Data2VecVision
# model does not handle attention masks. Furthermore, some extra code like modifying
# the norm layers eps values for specialized configs and checking for the 'noise'
# has been omitted to simply the test.
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.rms_norm_eps = 1.0
config.layer_norm_eps = 1.0
config.norm_eps = 1.0
config.norm_epsilon = 1.0
config.layer_norm_epsilon = 1.0
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype, use_mask_token=True)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
use_mask_token=True,
)
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
# Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
for x in model_eager.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
for x in model_sdpa.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for output_attentions in [True, False]:
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
mask_length = self.model_tester.seq_length - 1 - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
logits_eager = outputs_eager.hidden_states[-1]
logits_sdpa = outputs_sdpa.hidden_states[-1]
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
# We will verify our results on an image of cute cats
def prepare_img():

View File

@ -91,11 +91,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_sdpa_can_dispatch_non_composite_models(self):
pass
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_inference(self):
pass
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_generate(self):
pass

View File

@ -25,7 +25,6 @@ from transformers.testing_utils import (
TestCasePlus,
require_bitsandbytes,
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
@ -34,7 +33,13 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
floats_tensor,
ids_tensor,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
@ -311,16 +316,12 @@ class IdeficsModelTester:
def prepare_pixel_values(self):
return floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@require_torch_sdpa
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_generate(self):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip(reason="Idefics has a hard requirement on SDPA, skipping this test")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@require_torch
@ -349,10 +350,11 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
return inputs_dict
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip("Idefics requires both text and image inputs which is currently not done in this test.")
def test_eager_matches_sdpa_inference(self):
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
def test_model_outputs_equivalence(self):
@ -597,10 +599,11 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
)
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip("Idefics requires both text and image inputs which is currently not done in this test.")
def test_eager_matches_sdpa_inference(self, torch_dtype):
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@pytest.mark.generate

View File

@ -21,7 +21,6 @@ import unittest
import numpy as np
from datasets import Audio, load_dataset
from parameterized import parameterized
from pytest import mark
from transformers import AutoFeatureExtractor, MimiConfig
@ -31,17 +30,12 @@ from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from transformers.utils import (
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
@ -409,291 +403,6 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
config.use_conv_shortcut = False
self.model_tester.create_and_check_model_forward(config, inputs_dict)
# Overwrite to use `audio_values` as the tensors to compare.
# TODO: Try to do this in the parent class.
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if torch_dtype == "float16" and torch_device == "cpu":
self.skipTest("`replication_pad1d` not implemented for 'Half")
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
# FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
# These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
# This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
# However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
has_sdpa = False
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
has_sdpa = True
break
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for output_attentions in [True, False]:
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size
]
if decoder_input_ids.shape[0] != batch_size:
extension = torch.ones(
batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
processed_inputs = {
model.main_input_name: dummy_input,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if not deactivate_mask and (
"bool_masked_pos" in inspect.signature(model_eager.forward).parameters
):
dummy_mask = torch.ones((self.model_tester.num_masks,))
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int(
(self.model_tester.image_size // self.model_tester.patch_size) ** 2
)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
# Ignore copy
logits_eager = outputs_eager.audio_values
# Ignore copy
logits_sdpa = outputs_sdpa.audio_values
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test

View File

@ -44,7 +44,12 @@ from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_modeling_common import (
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
ModelTesterMixin,
floats_tensor,
ids_tensor,
)
from ...test_pipeline_mixin import PipelineTesterMixin
@ -188,11 +193,15 @@ class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
logits_processor_kwargs = {}
return logits_processor_kwargs
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(reason="Moshi has no strict equivalence between two modes, skipping this test.")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
if use_attention_mask or (not use_attention_mask and torch_dtype == "fp32" and not output_attentions):
self.skipTest("Test is failing, fix me :) ")
parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
parent_parameterized_test(self)
# Copied from tests.test_modeling_common.ModelTesterMixin.test_resize_tokens_embeddings
def test_resize_tokens_embeddings(self):
@ -620,11 +629,11 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@unittest.skip(reason="Unimplemented. Relies on `test_eager_matches_sdpa_generate` to check correctness.")
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
pass
@unittest.skip(reason="The Moshi model does not have support dynamic compile yet")

View File

@ -21,7 +21,6 @@ import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from pytest import mark
from transformers import (
@ -43,7 +42,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -452,226 +451,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [7]:
# Ignore copy
batch_size_input_ids = self.model_tester.num_codebooks * batch_size
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
# Ignore copy
dummy_input = dummy_input[:batch_size_input_ids]
# Ignore copy
if dummy_input.shape[0] != batch_size_input_ids:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
# Ignore copy
extension = torch.rand(
batch_size_input_ids - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
# Ignore copy
extension = torch.randint(
high=5,
size=(batch_size_input_ids - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
other_inputs = {
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@unittest.skip(
reason=(
"MusicGen has a custom set of generation tests that rely on `GenerationTesterMixin`, controlled by "
@ -1496,261 +1275,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
if hasattr(self.model_tester, "num_hidden_layers"):
self.model_tester.num_hidden_layers = 1
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.rms_norm_eps = 1.0
config.layer_norm_eps = 1.0
config.norm_eps = 1.0
config.norm_epsilon = 1.0
config.layer_norm_epsilon = 1.0
for attr in ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"]:
if hasattr(config, attr):
getattr(config, attr).rms_norm_eps = 1.0
getattr(config, attr).layer_norm_eps = 1.0
getattr(config, attr).norm_eps = 1.0
getattr(config, attr).norm_epsilon = 1.0
getattr(config, attr).layer_norm_epsilon = 1.0
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
for x in model_eager.modules():
if isinstance(x, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
x.eps = 1.0
for x in model_sdpa.modules():
if isinstance(x, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
x.eps = 1.0
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
# Ignore copy
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
# Ignore copy
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
# Ignore copy
batch_size_input_ids = self.model_tester.num_codebooks * batch_size
# Ignore copy
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size_input_ids
]
# Ignore copy
if decoder_input_ids.shape[0] != batch_size_input_ids:
# Ignore copy
extension = torch.ones(
batch_size_input_ids - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
# Ignore copy
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
# TODO: test gradients as well (& for FA2 as well!)
# Ignore copy
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
def test_requires_grad_with_frozen_encoders(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:

View File

@ -21,7 +21,6 @@ import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from pytest import mark
from transformers import (
@ -41,13 +40,10 @@ from transformers.testing_utils import (
require_torch_gpu,
require_torch_sdpa,
require_torchaudio,
set_config_for_less_flaky_test,
set_model_for_less_flaky_test,
set_model_tester_for_less_flaky_test,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -463,232 +459,6 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
set_config_for_less_flaky_test(config)
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
set_model_for_less_flaky_test(model_eager)
set_model_for_less_flaky_test(model_sdpa)
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [7]:
# Ignore copy
batch_size_input_ids = self.model_tester.num_codebooks * batch_size
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
# Ignore copy
dummy_input = dummy_input[:batch_size_input_ids]
# Ignore copy
if dummy_input.shape[0] != batch_size_input_ids:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
# Ignore copy
extension = torch.rand(
batch_size_input_ids - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
# Ignore copy
extension = torch.randint(
high=5,
size=(batch_size_input_ids - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
other_inputs = {
"output_hidden_states": True,
}
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
other_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
@unittest.skip(
reason=(
"MusicGen has a custom set of generation tests that rely on `GenerationTesterMixin`, controlled by "
@ -1495,240 +1265,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
if not has_sdpa and model_sdpa.config.model_type != "falcon":
raise ValueError("The SDPA model should have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
# Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
set_config_for_less_flaky_test(config)
model = model_class(config)
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
set_model_for_less_flaky_test(model_eager)
set_model_for_less_flaky_test(model_sdpa)
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
# Ignore copy
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
# Ignore copy
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
# Ignore copy
batch_size_input_ids = self.model_tester.num_codebooks * batch_size
# Ignore copy
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size_input_ids
]
# Ignore copy
if decoder_input_ids.shape[0] != batch_size_input_ids:
# Ignore copy
extension = torch.ones(
batch_size_input_ids - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# TODO: never an `attention_mask` arg here?
# Ignore copy
other_inputs = {
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
# TODO: test gradients as well (& for FA2 as well!)
# Ignore copy
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
outputs_eager = model_eager(dummy_input, **other_inputs)
outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
def test_requires_grad_with_frozen_encoders(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:

View File

@ -347,10 +347,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
def test_eager_matches_sdpa_generate(self):
pass
@unittest.skip(reason="RecurrentGemma only supports sdpa")
def test_eager_matches_sdpa_inference(self):
pass
@unittest.skip(reason="RecurrentGemma does not return the cache")
def test_contrastive_generate_low_memory(self):
pass

View File

@ -22,7 +22,7 @@ from huggingface_hub import hf_hub_download
from transformers import VideoMAEConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_sdpa, require_vision, slow, torch_device
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@ -214,11 +214,6 @@ class VideoMAEModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
return inputs_dict
@unittest.skip("`mse_cpu` not implemented for 'BFloat16'")
@require_torch_sdpa
def test_eager_matches_sdpa_inference_1_bfloat16(self):
pass
def test_config(self):
self.config_tester.run_common_tests()

View File

@ -59,6 +59,7 @@ class ViTMSNModelTester:
initializer_range=0.02,
scope=None,
attn_implementation="eager",
mask_ratio=0.5,
):
self.parent = parent
self.batch_size = batch_size
@ -82,6 +83,8 @@ class ViTMSNModelTester:
# in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.num_masks = int(mask_ratio * self.seq_length)
self.mask_length = self.seq_length - 1
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

View File

@ -133,6 +133,23 @@ if is_deepspeed_available():
import deepspeed
# used in other test files e.g. when overwriting the test
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION = [
(
# test name for the test runner
f"{dtype}_pad_{padding_side}{'' if use_attention_mask else '_no_attn_mask'}"
f"{'_output_attn' if output_attentions else ''}{'_sdpa_kernels' if enable_kernels else ''}",
# parameterization
*(dtype, padding_side, use_attention_mask, output_attentions, enable_kernels),
)
for dtype in ("fp16", "fp32", "bf16")
for padding_side in ("left", "right")
for use_attention_mask in (True, False)
for output_attentions in (True, False)
for enable_kernels in (True, False)
]
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
@ -3543,31 +3560,39 @@ class ModelTesterMixin:
):
raise ValueError("The eager model should not have SDPA attention layers")
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
def test_eager_matches_sdpa_inference(
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
):
# TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like
# models have a custom mixin, which we detect to skip this test.
if not any(".ModelTesterMixin" in str(base) for base in self.__class__.__bases__):
self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`")
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
# convert shorthand name to torch.dtype
if torch_dtype == "fp16":
torch_dtype = torch.float16
elif torch_dtype == "bf16":
torch_dtype = torch.bfloat16
elif torch_dtype == "fp32":
torch_dtype = torch.float32
if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16:
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16:
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
# Dictionary of tolerances for eager <> sdpa tests. Key = (device, sdpa_kernels_enabled, dtype)
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
@ -3597,238 +3622,243 @@ class ModelTesterMixin:
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
set_config_for_less_flaky_test(config)
model = model_class(config)
# FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
# These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
# This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
# However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
is_encoder_decoder = model.config.is_encoder_decoder
# TODO: standardize the interfaces for musicgen models, see other todo in this test
if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration":
is_encoder_decoder = True
else:
is_encoder_decoder = model.config.is_encoder_decoder
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_from_pretrained_kwargs = {
"pretrained_model_name_or_path": tmpdirname,
"torch_dtype": torch_dtype,
}
if (
hasattr(config, "use_mask_token")
or "use_mask_token" in inspect.signature(model.__init__).parameters
):
model_from_pretrained_kwargs["use_mask_token"] = True
# TODO: remove this try/except, models should have a shared API
try:
model_sdpa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch_dtype, attn_implementation="sdpa"
**model_from_pretrained_kwargs, attn_implementation="sdpa"
)
except ValueError:
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
set_model_for_less_flaky_test(model_eager)
set_model_for_less_flaky_test(model_sdpa)
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for output_attentions in [True, False]:
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
self.skipTest(reason="Model does not support output_attentions")
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
# musicgen decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks") and not hasattr(model_eager, "text_encoder"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
dummy_input = dummy_input[:batch_size]
if dummy_input.shape[0] != batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
dummy_input = inputs_dict[model.main_input_name]
if not use_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = (
torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
)
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
dummy_input = dummy_input[:input_data_batch_size]
if dummy_input.shape[0] != input_data_batch_size:
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
extension = torch.rand(
input_data_batch_size - dummy_input.shape[0],
*dummy_input.shape[1:],
dtype=torch_dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
else:
extension = torch.randint(
high=5,
size=(input_data_batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
dtype=dummy_input.dtype,
device=torch_device,
)
dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device)
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
if not use_attention_mask:
dummy_attention_mask = None
else:
dummy_attention_mask = inputs_dict.get("attention_mask", None)
if dummy_attention_mask is None:
if is_encoder_decoder:
seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
else:
seqlen = dummy_input.shape[-1]
dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
if is_encoder_decoder:
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
:batch_size
]
if decoder_input_ids.shape[0] != batch_size:
extension = torch.ones(
batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
dummy_attention_mask = dummy_attention_mask[:batch_size]
if dummy_attention_mask.shape[0] != batch_size:
extension = torch.ones(
batch_size - dummy_attention_mask.shape[0],
*dummy_attention_mask.shape[1:],
dtype=dummy_attention_mask.dtype,
device=torch_device,
)
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
dummy_attention_mask = dummy_attention_mask.to(torch_device)
# TODO: never an `attention_mask` arg here?
processed_inputs = {
model.main_input_name: dummy_input,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
dummy_attention_mask[:] = 1
if padding_side == "left":
dummy_attention_mask[-1, :2] = 0
dummy_attention_mask[-1, 2:] = 1
elif padding_side == "right":
dummy_attention_mask[-1, -2:] = 0
dummy_attention_mask[-1, :-2] = 1
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask
if is_encoder_decoder:
# musicgen encoder-decoder models; TODO: find better abstraction
if hasattr(self.model_tester, "num_codebooks"):
input_data_batch_size = batch_size * self.model_tester.num_codebooks
else:
input_data_batch_size = batch_size
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if not deactivate_mask and (
"bool_masked_pos" in inspect.signature(model_eager.forward).parameters
):
dummy_mask = torch.ones((self.model_tester.num_masks,))
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:input_data_batch_size]
if decoder_input_ids.shape[0] != input_data_batch_size:
extension = torch.ones(
input_data_batch_size - decoder_input_ids.shape[0],
*decoder_input_ids.shape[1:],
dtype=decoder_input_ids.dtype,
device=torch_device,
)
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
decoder_input_ids = decoder_input_ids.to(torch_device)
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
# TODO: never an `attention_mask` arg here?
processed_inputs = {
model.main_input_name: dummy_input,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": dummy_attention_mask,
"output_hidden_states": True,
}
else:
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int(
(self.model_tester.image_size // self.model_tester.patch_size) ** 2
)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# Otherwise fails for e.g. WhisperEncoderModel
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
processed_inputs["attention_mask"] = dummy_attention_mask
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
if hasattr(outputs_eager, "vision_hidden_states"):
logits_eager = outputs_eager.vision_hidden_states[-1]
logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
else:
logits_eager = (
outputs_eager.hidden_states[-1]
if not is_encoder_decoder
else outputs_eager.decoder_hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.hidden_states[-1]
if not is_encoder_decoder
else outputs_sdpa.decoder_hidden_states[-1]
)
# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
if "noise" in inspect.signature(model_eager.forward).parameters:
np.random.seed(2)
num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2)
noise = np.random.uniform(size=(batch_size, num_patches))
processed_inputs["noise"] = torch.from_numpy(noise)
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
# TODO: rename logits -> hidden_states
if hasattr(outputs_eager, "vision_hidden_states"):
logits_eager = outputs_eager.vision_hidden_states[-1]
logits_sdpa = outputs_sdpa.vision_hidden_states[-1]
elif hasattr(outputs_eager, "audio_values"):
logits_eager = outputs_eager.audio_values
logits_sdpa = outputs_sdpa.audio_values
else:
logits_eager = (
outputs_eager.decoder_hidden_states[-1]
if hasattr(outputs_eager, "decoder_hidden_states")
else outputs_eager.hidden_states[-1]
)
logits_sdpa = (
outputs_sdpa.decoder_hidden_states[-1]
if hasattr(outputs_sdpa, "decoder_hidden_states")
else outputs_sdpa.hidden_states[-1]
)
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
# Masked tokens output slightly deviates - we don't mind that.
if use_attention_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
raise ValueError(
f"mean relative difference: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = "
f"{rtol}"
)
@require_torch_sdpa
@require_torch_gpu