diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index 76275778c49..a56b5c68085 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -76,7 +76,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
 
     def forward(self, input: Tensor) -> Tensor:
         args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(input.device.type, enabled=False):
             return F.layer_norm(*args)
 
 
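For reference, the layer-norm change above swaps the CUDA-only autocast context manager for the device-agnostic one; a minimal standalone sketch (not the model code itself) of the new form, which derives the device type from the input tensor, might be:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)  # any tensor; the same code applies on "cpu", "cuda" or "xpu"

    # Previously: with torch.cuda.amp.autocast(enabled=False): ...  (CUDA-only, deprecated)
    # Now the device type comes from the tensor, so autocast is disabled on whichever
    # backend the input actually lives on.
    with torch.amp.autocast(x.device.type, enabled=False):
        y = F.layer_norm(x, (8,))
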
diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py
index 7ddc6b74744..4f6cfaff7e9 100644
--- a/tests/models/mimi/test_modeling_mimi.py
+++ b/tests/models/mimi/test_modeling_mimi.py
@@ -41,7 +41,7 @@ from transformers.utils import (
 )
 
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
 
 
 if is_torch_available():
@@ -636,7 +636,7 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
 
                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -653,6 +653,12 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
index 37b5af3ae7e..3ea60d550e0 100644
--- a/tests/models/musicgen/test_modeling_musicgen.py
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -47,7 +47,7 @@ from transformers.utils import cached_property, is_torch_bf16_available_on_devic
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
 from ...test_pipeline_mixin import PipelineTesterMixin
 
 
@@ -607,7 +607,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
 
                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -629,6 +629,12 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
@@ -1343,7 +1349,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                     inputs_dict[name] = inp.to(torch.float16)
 
-            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+            with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                 _ = model(**inputs_dict)
 
     @require_flash_attn
@@ -1669,7 +1675,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 # TODO: test gradients as well (& for FA2 as well!)
                 # Ignore copy
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -1691,6 +1697,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
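The `elif torch_device == "xpu"` branches above all follow the same pattern; a minimal standalone sketch of that tolerance selection, with hypothetical `atols`/`rtols` tables shaped like the ones used in these tests, could look like:

    import torch

    # Hypothetical tolerance tables keyed by (device, kernels enabled, dtype), mirroring the tests.
    atols = {("cuda", False, torch.float32): 1e-6, ("cuda", True, torch.float32): 1e-4,
             ("cpu", False, torch.float32): 1e-6, ("cpu", True, torch.float32): 1e-4}
    rtols = {("cuda", False, torch.float32): 1e-5, ("cuda", True, torch.float32): 1e-4,
             ("cpu", False, torch.float32): 1e-4, ("cpu", True, torch.float32): 1e-4}


    def pick_tolerances(torch_device, enable_kernels, torch_dtype):
        if torch_device in ["cpu", "cuda"]:
            return atols[torch_device, enable_kernels, torch_dtype], rtols[torch_device, enable_kernels, torch_dtype]
        elif torch_device == "xpu":
            # Only the math SDPA backend is available on XPU (as of PyTorch 2.5), so the
            # CUDA math-backend tolerances are a reasonable reference point.
            return atols["cuda", False, torch_dtype], rtols["cuda", False, torch_dtype]
        return 1e-7, 1e-4


    print(pick_tolerances("xpu", True, torch.float32))  # falls back to the CUDA math tolerances
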
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index de7a2745ca0..bc8baa2746a 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -48,7 +48,7 @@ from transformers.utils import cached_property, is_torch_bf16_available_on_devic
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
 from ...test_pipeline_mixin import PipelineTesterMixin
 
 
@@ -615,7 +615,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
 
                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -637,6 +637,12 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
@@ -1333,7 +1339,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                 if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                     inputs_dict[name] = inp.to(torch.float16)
 
-            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+            with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                 _ = model(**inputs_dict)
 
     @require_flash_attn
@@ -1632,7 +1638,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                 # TODO: test gradients as well (& for FA2 as well!)
                 # Ignore copy
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -1654,6 +1660,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f3f326a4ce8..99d0a8058c6 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -187,6 +187,22 @@ def _deepspeed_zero3(ds_config):
         unset_hf_deepspeed_config()
 
 
+def sdpa_kernel(enable_flash, enable_math, enable_mem_efficient):
+    if version.parse(torch.__version__).release < version.parse("2.3").release:
+        return torch.backends.cuda.sdp_kernel(
+            enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
+        )
+
+    backends = []
+    if enable_flash:
+        backends += [torch.nn.attention.SDPBackend.FLASH_ATTENTION]
+    if enable_math:
+        backends += [torch.nn.attention.SDPBackend.MATH]
+    if enable_mem_efficient:
+        backends += [torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]
+    return torch.nn.attention.sdpa_kernel(backends)
+
+
 @require_torch
 class ModelTesterMixin:
     model_tester = None
@@ -4175,7 +4191,7 @@ class ModelTesterMixin:
 
                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -4198,6 +4214,12 @@ class ModelTesterMixin:
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
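
As a usage sketch of the helper added above: on PyTorch >= 2.3 it returns a `torch.nn.attention.sdpa_kernel` context manager restricted to the selected backends, and on older versions it falls back to `torch.backends.cuda.sdp_kernel`. The absolute import path below is an assumption about running from the repo root; inside the test suite it is imported as `from ...test_modeling_common import sdpa_kernel`.

    import torch
    import torch.nn.functional as F

    from tests.test_modeling_common import sdpa_kernel  # assumed path when running from the repo root

    q = k = v = torch.randn(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)

    # Allow only the math backend, as the XPU-related comments in this patch describe.
    with sdpa_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)  # torch.Size([1, 2, 4, 8])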