Fix test_eager_matches_sdpa_inference for XPU backend (#34889)

* Use torch.nn.attention.sdpa_kernel instead of deprecated torch.backends.cuda.sdp_kernel
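
  For reference, a hedged sketch of the API mapping (not the test helper itself); it
  assumes PyTorch >= 2.3, where torch.nn.attention.sdpa_kernel takes an explicit list
  of SDPBackend values instead of the per-backend boolean flags:

    # Illustrative only: how the deprecated flag-based context manager maps onto the new API.
    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 2, 8, 16)

    # Old (deprecated):
    #   with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    #       out = F.scaled_dot_product_attention(q, k, v)

    # New: the enabled flags become the list of allowed backends.
    with sdpa_kernel([SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)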

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

* Fix test_eager_matches_sdpa_inference for XPU backend

As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
which is implemented at the PyTorch level using aten operators and is device-agnostic
with respect to the implementation of each aten operator. Thus, we can reuse the
CUDA (or CPU) MATH tolerances for XPU.
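
  To make the reasoning concrete, a hedged, stand-alone sketch (not the actual test):
  restrict SDPA to the MATH backend and check that an XPU result matches a CPU
  reference. It assumes a PyTorch >= 2.5 build with XPU support; torch.xpu.is_available()
  guards the device, and torch.testing.assert_close uses its default fp32 tolerances:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))

    with sdpa_kernel([SDPBackend.MATH]):
        ref = F.scaled_dot_product_attention(q, k, v)  # CPU MATH reference
        if torch.xpu.is_available():
            # Same aten-level MATH composition, so the XPU result should agree closely.
            out = F.scaled_dot_product_attention(q.to("xpu"), k.to("xpu"), v.to("xpu"))
            torch.testing.assert_close(out.cpu(), ref)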

Fixes: #34888
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

* Use torch.amp.autocast instead of deprecated torch.cuda.amp.autocast in nemotron
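
  A minimal sketch of that migration (illustrative, not the Nemotron layer itself);
  torch.amp.autocast takes the device type string as its first argument, so the same
  code path works on CUDA, CPU and XPU:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)

    # Old (deprecated, CUDA-specific):
    #   with torch.cuda.amp.autocast(enabled=False):
    #       y = F.layer_norm(x, (8,))

    # New: device-generic autocast.
    with torch.amp.autocast(x.device.type, enabled=False):
        y = F.layer_norm(x, (8,))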

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

---------

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Dmitry Rogozhkin, 2024-12-02 07:21:04 -08:00, committed by GitHub
commit 31830474bf (parent f41d5d8f74)
5 changed files with 64 additions and 12 deletions


@@ -76,7 +76,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):
     def forward(self, input: Tensor) -> Tensor:
         args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(input.device.type, enabled=False):
             return F.layer_norm(*args)


@@ -41,7 +41,7 @@ from transformers.utils import (
 )

 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel


 if is_torch_available():

@@ -636,7 +636,7 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
             # TODO: test gradients as well (& for FA2 as well!)
             with torch.no_grad():
-                with torch.backends.cuda.sdp_kernel(
+                with sdpa_kernel(
                     enable_flash=enable_kernels,
                     enable_math=True,
                     enable_mem_efficient=enable_kernels,

@@ -653,6 +653,12 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
             if torch_device in ["cpu", "cuda"]:
                 atol = atols[torch_device, enable_kernels, torch_dtype]
                 rtol = rtols[torch_device, enable_kernels, torch_dtype]
+            elif torch_device == "xpu":
+                # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                # which is implemented on PyTorch level using aten operators and is
+                # device agnostic with respect to implementation of each aten operator.
+                atol = atols["cuda", False, torch_dtype]
+                rtol = rtols["cuda", False, torch_dtype]
             else:
                 atol = 1e-7
                 rtol = 1e-4


@@ -47,7 +47,7 @@ from transformers.utils import cached_property, is_torch_bf16_available_on_devic
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
 from ...test_pipeline_mixin import PipelineTesterMixin

@@ -607,7 +607,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             # TODO: test gradients as well (& for FA2 as well!)
             with torch.no_grad():
-                with torch.backends.cuda.sdp_kernel(
+                with sdpa_kernel(
                     enable_flash=enable_kernels,
                     enable_math=True,
                     enable_mem_efficient=enable_kernels,

@@ -629,6 +629,12 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             if torch_device in ["cpu", "cuda"]:
                 atol = atols[torch_device, enable_kernels, torch_dtype]
                 rtol = rtols[torch_device, enable_kernels, torch_dtype]
+            elif torch_device == "xpu":
+                # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                # which is implemented on PyTorch level using aten operators and is
+                # device agnostic with respect to implementation of each aten operator.
+                atol = atols["cuda", False, torch_dtype]
+                rtol = rtols["cuda", False, torch_dtype]
             else:
                 atol = 1e-7
                 rtol = 1e-4

@@ -1343,7 +1349,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                     inputs_dict[name] = inp.to(torch.float16)
-            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+            with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                 _ = model(**inputs_dict)

     @require_flash_attn

@@ -1669,7 +1675,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             # TODO: test gradients as well (& for FA2 as well!)
             # Ignore copy
             with torch.no_grad():
-                with torch.backends.cuda.sdp_kernel(
+                with sdpa_kernel(
                     enable_flash=enable_kernels,
                     enable_math=True,
                     enable_mem_efficient=enable_kernels,

@@ -1691,6 +1697,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             if torch_device in ["cpu", "cuda"]:
                 atol = atols[torch_device, enable_kernels, torch_dtype]
                 rtol = rtols[torch_device, enable_kernels, torch_dtype]
+            elif torch_device == "xpu":
+                # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                # which is implemented on PyTorch level using aten operators and is
+                # device agnostic with respect to implementation of each aten operator.
+                atol = atols["cuda", False, torch_dtype]
+                rtol = rtols["cuda", False, torch_dtype]
             else:
                 atol = 1e-7
                 rtol = 1e-4


@@ -48,7 +48,7 @@ from transformers.utils import cached_property, is_torch_bf16_available_on_devic
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
 from ...test_pipeline_mixin import PipelineTesterMixin

@@ -615,7 +615,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
             # TODO: test gradients as well (& for FA2 as well!)
             with torch.no_grad():
-                with torch.backends.cuda.sdp_kernel(
+                with sdpa_kernel(
                     enable_flash=enable_kernels,
                     enable_math=True,
                     enable_mem_efficient=enable_kernels,

@@ -637,6 +637,12 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
             if torch_device in ["cpu", "cuda"]:
                 atol = atols[torch_device, enable_kernels, torch_dtype]
                 rtol = rtols[torch_device, enable_kernels, torch_dtype]
+            elif torch_device == "xpu":
+                # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                # which is implemented on PyTorch level using aten operators and is
+                # device agnostic with respect to implementation of each aten operator.
+                atol = atols["cuda", False, torch_dtype]
+                rtol = rtols["cuda", False, torch_dtype]
             else:
                 atol = 1e-7
                 rtol = 1e-4

@@ -1333,7 +1339,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                 if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                     inputs_dict[name] = inp.to(torch.float16)
-            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+            with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                 _ = model(**inputs_dict)

     @require_flash_attn

@@ -1632,7 +1638,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
             # TODO: test gradients as well (& for FA2 as well!)
             # Ignore copy
             with torch.no_grad():
-                with torch.backends.cuda.sdp_kernel(
+                with sdpa_kernel(
                     enable_flash=enable_kernels,
                     enable_math=True,
                     enable_mem_efficient=enable_kernels,

@@ -1654,6 +1660,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
             if torch_device in ["cpu", "cuda"]:
                 atol = atols[torch_device, enable_kernels, torch_dtype]
                 rtol = rtols[torch_device, enable_kernels, torch_dtype]
+            elif torch_device == "xpu":
+                # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                # which is implemented on PyTorch level using aten operators and is
+                # device agnostic with respect to implementation of each aten operator.
+                atol = atols["cuda", False, torch_dtype]
+                rtol = rtols["cuda", False, torch_dtype]
             else:
                 atol = 1e-7
                 rtol = 1e-4


@@ -187,6 +187,22 @@ def _deepspeed_zero3(ds_config):
         unset_hf_deepspeed_config()


+def sdpa_kernel(enable_flash, enable_math, enable_mem_efficient):
+    if version.parse(torch.__version__).release < version.parse("2.3").release:
+        return torch.backends.cuda.sdp_kernel(
+            enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
+        )
+
+    backends = []
+    if enable_flash:
+        backends += [torch.nn.attention.SDPBackend.FLASH_ATTENTION]
+    if enable_math:
+        backends += [torch.nn.attention.SDPBackend.MATH]
+    if enable_mem_efficient:
+        backends += [torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]
+    return torch.nn.attention.sdpa_kernel(backends)
+
+
 @require_torch
 class ModelTesterMixin:
     model_tester = None

@@ -4175,7 +4191,7 @@ class ModelTesterMixin:
             # TODO: test gradients as well (& for FA2 as well!)
             with torch.no_grad():
-                with torch.backends.cuda.sdp_kernel(
+                with sdpa_kernel(
                     enable_flash=enable_kernels,
                     enable_math=True,
                     enable_mem_efficient=enable_kernels,

@@ -4198,6 +4214,12 @@ class ModelTesterMixin:
             if torch_device in ["cpu", "cuda"]:
                 atol = atols[torch_device, enable_kernels, torch_dtype]
                 rtol = rtols[torch_device, enable_kernels, torch_dtype]
+            elif torch_device == "xpu":
+                # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                # which is implemented on PyTorch level using aten operators and is
+                # device agnostic with respect to implementation of each aten operator.
+                atol = atols["cuda", False, torch_dtype]
+                rtol = rtols["cuda", False, torch_dtype]
             else:
                 atol = 1e-7
                 rtol = 1e-4