From 31830474bff00c6cb15d395f800594b9a5a74e3f Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin
Date: Mon, 2 Dec 2024 07:21:04 -0800
Subject: [PATCH] Fix `test_eager_matches_sdpa_inference` for `XPU` backend
 (#34889)

* Use torch.nn.attention.sdpa_kernel instead of deprecated torch.backends.cuda.sdp_kernel

Signed-off-by: Dmitry Rogozhkin

* Fix test_eager_matches_sdpa_inference for XPU backend

As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
which is implemented at the PyTorch level using aten operators and is device
agnostic with respect to the implementation of each aten operator. Thus, we can
reuse the CUDA (or CPU) MATH tolerances for XPU.

Fixes: #34888

Signed-off-by: Dmitry Rogozhkin

* Use torch.amp.autocast instead of deprecated torch.cuda.amp.autocast in nemotron

Signed-off-by: Dmitry Rogozhkin

---------

Signed-off-by: Dmitry Rogozhkin
---
 .../models/nemotron/modeling_nemotron.py      |  2 +-
 tests/models/mimi/test_modeling_mimi.py       | 10 ++++++--
 .../models/musicgen/test_modeling_musicgen.py | 20 ++++++++++++----
 .../test_modeling_musicgen_melody.py          | 20 ++++++++++++----
 tests/test_modeling_common.py                 | 24 ++++++++++++++++++-
 5 files changed, 64 insertions(+), 12 deletions(-)

diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index 76275778c49..a56b5c68085 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -76,7 +76,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
 
     def forward(self, input: Tensor) -> Tensor:
         args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(input.device.type, enabled=False):
             return F.layer_norm(*args)
 
 
diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py
index 7ddc6b74744..4f6cfaff7e9 100644
--- a/tests/models/mimi/test_modeling_mimi.py
+++ b/tests/models/mimi/test_modeling_mimi.py
@@ -41,7 +41,7 @@ from transformers.utils import (
 )
 
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
 
 
 if is_torch_available():
@@ -636,7 +636,7 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
 
                         # TODO: test gradients as well (& for FA2 as well!)
                         with torch.no_grad():
-                            with torch.backends.cuda.sdp_kernel(
+                            with sdpa_kernel(
                                 enable_flash=enable_kernels,
                                 enable_math=True,
                                 enable_mem_efficient=enable_kernels,
@@ -653,6 +653,12 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
                         if torch_device in ["cpu", "cuda"]:
                             atol = atols[torch_device, enable_kernels, torch_dtype]
                             rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                        elif torch_device == "xpu":
+                            # As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
+                            # which is implemented at the PyTorch level using aten operators and is
+                            # device agnostic with respect to the implementation of each aten operator.
+                            atol = atols["cuda", False, torch_dtype]
+                            rtol = rtols["cuda", False, torch_dtype]
                         else:
                             atol = 1e-7
                             rtol = 1e-4
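The nemotron hunk above swaps the deprecated, CUDA-only autocast context manager
for the device-agnostic torch.amp API, so the same code path works for "cuda",
"xpu", and "cpu" tensors. A minimal standalone sketch of the two spellings (the
tensor and the layer_norm call are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)  # on any device; x.device.type selects the matching backend

# Deprecated, CUDA-only spelling that warns on recent PyTorch:
#     with torch.cuda.amp.autocast(enabled=False): ...
# Device-agnostic replacement used by the patch:
with torch.amp.autocast(x.device.type, enabled=False):
    y = F.layer_norm(x, x.shape[-1:])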
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
index 37b5af3ae7e..3ea60d550e0 100644
--- a/tests/models/musicgen/test_modeling_musicgen.py
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -47,7 +47,7 @@ from transformers.utils import cached_property, is_torch_bf16_available_on_devic
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
 from ...test_pipeline_mixin import PipelineTesterMixin
 
 
@@ -607,7 +607,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
 
                         # TODO: test gradients as well (& for FA2 as well!)
                         with torch.no_grad():
-                            with torch.backends.cuda.sdp_kernel(
+                            with sdpa_kernel(
                                 enable_flash=enable_kernels,
                                 enable_math=True,
                                 enable_mem_efficient=enable_kernels,
@@ -629,6 +629,12 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
                         if torch_device in ["cpu", "cuda"]:
                             atol = atols[torch_device, enable_kernels, torch_dtype]
                             rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                        elif torch_device == "xpu":
+                            # As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
+                            # which is implemented at the PyTorch level using aten operators and is
+                            # device agnostic with respect to the implementation of each aten operator.
+                            atol = atols["cuda", False, torch_dtype]
+                            rtol = rtols["cuda", False, torch_dtype]
                         else:
                             atol = 1e-7
                             rtol = 1e-4
@@ -1343,7 +1349,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                 inputs_dict[name] = inp.to(torch.float16)
 
-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+        with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
             _ = model(**inputs_dict)
 
     @require_flash_attn
@@ -1669,7 +1675,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                         # TODO: test gradients as well (& for FA2 as well!)
                         # Ignore copy
                         with torch.no_grad():
-                            with torch.backends.cuda.sdp_kernel(
+                            with sdpa_kernel(
                                 enable_flash=enable_kernels,
                                 enable_math=True,
                                 enable_mem_efficient=enable_kernels,
@@ -1691,6 +1697,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                         if torch_device in ["cpu", "cuda"]:
                             atol = atols[torch_device, enable_kernels, torch_dtype]
                             rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                        elif torch_device == "xpu":
+                            # As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
+                            # which is implemented at the PyTorch level using aten operators and is
+                            # device agnostic with respect to the implementation of each aten operator.
+                            atol = atols["cuda", False, torch_dtype]
+                            rtol = rtols["cuda", False, torch_dtype]
                         else:
                             atol = 1e-7
                             rtol = 1e-4
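The atols/rtols tables consulted in the hunks above are keyed by
(device, enable_kernels, dtype) tuples. A hedged sketch of the lookup logic,
with illustrative placeholder values rather than the ones defined in the test
files, and a hypothetical pick_atol helper standing in for the inlined code:

import torch

atols = {
    ("cuda", False, torch.float32): 1e-6,   # illustrative value only
    ("cuda", False, torch.bfloat16): 1e-2,  # illustrative value only
}

def pick_atol(torch_device, enable_kernels, torch_dtype):
    # Hypothetical helper for illustration; the tests inline this logic.
    if torch_device in ["cpu", "cuda"]:
        return atols[torch_device, enable_kernels, torch_dtype]
    if torch_device == "xpu":
        # XPU executes SDPA only through the MATH backend, so its numerics
        # match CUDA with the fused (flash/mem-efficient) kernels disabled.
        return atols["cuda", False, torch_dtype]
    return 1e-7  # conservative default, mirroring the tests' else branch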
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index de7a2745ca0..bc8baa2746a 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -48,7 +48,7 @@ from transformers.utils import cached_property, is_torch_bf16_available_on_devic
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
 from ...test_pipeline_mixin import PipelineTesterMixin
 
 
@@ -615,7 +615,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
 
                         # TODO: test gradients as well (& for FA2 as well!)
                         with torch.no_grad():
-                            with torch.backends.cuda.sdp_kernel(
+                            with sdpa_kernel(
                                 enable_flash=enable_kernels,
                                 enable_math=True,
                                 enable_mem_efficient=enable_kernels,
@@ -637,6 +637,12 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
                         if torch_device in ["cpu", "cuda"]:
                             atol = atols[torch_device, enable_kernels, torch_dtype]
                             rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                        elif torch_device == "xpu":
+                            # As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
+                            # which is implemented at the PyTorch level using aten operators and is
+                            # device agnostic with respect to the implementation of each aten operator.
+                            atol = atols["cuda", False, torch_dtype]
+                            rtol = rtols["cuda", False, torch_dtype]
                         else:
                             atol = 1e-7
                             rtol = 1e-4
@@ -1333,7 +1339,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
             if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                 inputs_dict[name] = inp.to(torch.float16)
 
-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+        with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
             _ = model(**inputs_dict)
 
     @require_flash_attn
@@ -1632,7 +1638,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                         # TODO: test gradients as well (& for FA2 as well!)
                         # Ignore copy
                         with torch.no_grad():
-                            with torch.backends.cuda.sdp_kernel(
+                            with sdpa_kernel(
                                 enable_flash=enable_kernels,
                                 enable_math=True,
                                 enable_mem_efficient=enable_kernels,
@@ -1654,6 +1660,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                         if torch_device in ["cpu", "cuda"]:
                             atol = atols[torch_device, enable_kernels, torch_dtype]
                             rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                        elif torch_device == "xpu":
+                            # As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
+                            # which is implemented at the PyTorch level using aten operators and is
+                            # device agnostic with respect to the implementation of each aten operator.
+                            atol = atols["cuda", False, torch_dtype]
+                            rtol = rtols["cuda", False, torch_dtype]
                         else:
                             atol = 1e-7
                             rtol = 1e-4
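The XPU fallback added above is justified because, as of PyTorch 2.5, MATH is
the only SDPA backend the XPU device can select. A small sketch of pinning that
backend explicitly with the modern API (requires torch.nn.attention, i.e.
PyTorch >= 2.3):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 4, 16, 8)  # (batch, heads, seq_len, head_dim)
# Restrict SDPA to the math reference implementation, the backend whose
# numerics the XPU tolerances above reuse:
with sdpa_kernel([SDPBackend.MATH]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)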
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index f3f326a4ce8..99d0a8058c6 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -187,6 +187,22 @@ def _deepspeed_zero3(ds_config):
     unset_hf_deepspeed_config()
 
 
+def sdpa_kernel(enable_flash, enable_math, enable_mem_efficient):
+    if version.parse(torch.__version__).release < version.parse("2.3").release:
+        return torch.backends.cuda.sdp_kernel(
+            enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
+        )
+
+    backends = []
+    if enable_flash:
+        backends += [torch.nn.attention.SDPBackend.FLASH_ATTENTION]
+    if enable_math:
+        backends += [torch.nn.attention.SDPBackend.MATH]
+    if enable_mem_efficient:
+        backends += [torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]
+    return torch.nn.attention.sdpa_kernel(backends)
+
+
 @require_torch
 class ModelTesterMixin:
     model_tester = None
@@ -4175,7 +4191,7 @@ class ModelTesterMixin:
 
                         # TODO: test gradients as well (& for FA2 as well!)
                         with torch.no_grad():
-                            with torch.backends.cuda.sdp_kernel(
+                            with sdpa_kernel(
                                 enable_flash=enable_kernels,
                                 enable_math=True,
                                 enable_mem_efficient=enable_kernels,
@@ -4198,6 +4214,12 @@ class ModelTesterMixin:
                         if torch_device in ["cpu", "cuda"]:
                             atol = atols[torch_device, enable_kernels, torch_dtype]
                             rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                        elif torch_device == "xpu":
+                            # As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
+                            # which is implemented at the PyTorch level using aten operators and is
+                            # device agnostic with respect to the implementation of each aten operator.
+                            atol = atols["cuda", False, torch_dtype]
+                            rtol = rtols["cuda", False, torch_dtype]
                         else:
                             atol = 1e-7
                             rtol = 1e-4
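The sdpa_kernel helper added in the last file keeps the boolean-flag interface
of the deprecated torch.backends.cuda.sdp_kernel while delegating to
torch.nn.attention.sdpa_kernel on PyTorch >= 2.3. A hedged usage sketch, where
model and inputs_dict stand in for whatever the surrounding test prepares:

# Math-only run, e.g. to produce the reference outputs that eager attention
# (or an XPU run) is compared against:
with torch.no_grad():
    with sdpa_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        outputs_math = model(**inputs_dict)

# Fused kernels allowed, with math kept enabled as the fallback, matching the
# enable_kernels=True branch of the tests:
with torch.no_grad():
    with sdpa_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
        outputs_fused = model(**inputs_dict)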