Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-23 14:29:01 +06:00)
Fix test_eager_matches_sdpa_inference for XPU backend (#34889)

* Use torch.nn.attention.sdpa_kernel instead of the deprecated torch.backends.cuda.sdp_kernel

* Fix test_eager_matches_sdpa_inference for the XPU backend

  As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
  which is implemented at the PyTorch level using aten operators and is device agnostic
  with respect to the implementation of each aten operator. Thus, we can reuse the CUDA
  (or CPU) MATH tolerances for XPU.

  Fixes: #34888

* Use torch.amp.autocast instead of the deprecated torch.cuda.amp.autocast in nemotron

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
parent f41d5d8f74
commit 31830474bf
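For context before the diff hunks: a minimal, standalone sketch of the API migration this commit performs. This is illustrative only (it assumes PyTorch >= 2.3, where torch.nn.attention is available); the commit itself routes all test call sites through a small compatibility helper shown further below.

# Illustrative only: how the deprecated and new kernel-selection APIs relate.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64)

# Deprecated, CUDA-specific API (emits a deprecation warning on recent PyTorch):
#     with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
#         torch.nn.functional.scaled_dot_product_attention(q, k, v)

# New, device-agnostic API: pass the allowed backends explicitly.
with sdpa_kernel([SDPBackend.MATH]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)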
@@ -76,7 +76,7 @@ class NemotronLayerNorm1P(nn.LayerNorm):

     def forward(self, input: Tensor) -> Tensor:
         args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(input.device.type, enabled=False):
             return F.layer_norm(*args)

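A minimal sketch (not part of the commit) of why this replacement works everywhere: torch.amp.autocast takes the device type as its first argument, so the same line serves "cuda", "xpu", and "cpu" tensors, whereas the deprecated torch.cuda.amp.autocast is CUDA-only.

import torch

# device.type is "cpu" here; on accelerators it would be "cuda" or "xpu".
x = torch.randn(4, 4)
with torch.amp.autocast(x.device.type, enabled=False):
    y = torch.nn.functional.layer_norm(x, x.shape[-1:])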
@@ -41,7 +41,7 @@ from transformers.utils import (
 )

 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel


 if is_torch_available():
@@ -636,7 +636,7 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):

                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -653,6 +653,12 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
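The same tolerance-selection block is added to each affected test below. A hypothetical helper (not in the repo) restating its logic: the atols/rtols tables are keyed by (device, fused-kernels-enabled, dtype), and XPU reuses the ("cuda", False, dtype) entry because only the device-agnostic MATH backend runs there.

def pick_tolerances(atols, rtols, torch_device, enable_kernels, torch_dtype):
    if torch_device in ["cpu", "cuda"]:
        key = (torch_device, enable_kernels, torch_dtype)
    elif torch_device == "xpu":
        # MATH backend is device agnostic, so CUDA's eager/MATH tolerances apply.
        key = ("cuda", False, torch_dtype)
    else:
        return 1e-7, 1e-4
    return atols[key], rtols[key]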
@@ -47,7 +47,7 @@ from transformers.utils import cached_property, is_torch_bf16_available_on_device

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
 from ...test_pipeline_mixin import PipelineTesterMixin


@@ -607,7 +607,7 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,

                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -629,6 +629,12 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
@@ -1343,7 +1349,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                 inputs_dict[name] = inp.to(torch.float16)

-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+        with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
             _ = model(**inputs_dict)

     @require_flash_attn
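The cast to torch.float16 just above this change matters: the flash attention backend ships only fp16/bf16 kernels, so forcing flash-only dispatch with fp32 inputs would raise a "no available kernel" error. A brief sketch of that constraint (illustrative; assumes a CUDA device with flash kernels available):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    # fp32 inputs would fail to dispatch here, since only flash is allowed.
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)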
@@ -1669,7 +1675,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 # TODO: test gradients as well (& for FA2 as well!)
                 # Ignore copy
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -1691,6 +1697,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
@@ -48,7 +48,7 @@ from transformers.utils import cached_property, is_torch_bf16_available_on_device

 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
 from ...test_pipeline_mixin import PipelineTesterMixin


@@ -615,7 +615,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -637,6 +637,12 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
@@ -1333,7 +1339,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
                 inputs_dict[name] = inp.to(torch.float16)

-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+        with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
             _ = model(**inputs_dict)

     @require_flash_attn
@@ -1632,7 +1638,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 # TODO: test gradients as well (& for FA2 as well!)
                 # Ignore copy
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -1654,6 +1660,12 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4
@@ -187,6 +187,22 @@ def _deepspeed_zero3(ds_config):
         unset_hf_deepspeed_config()


+def sdpa_kernel(enable_flash, enable_math, enable_mem_efficient):
+    if version.parse(torch.__version__).release < version.parse("2.3").release:
+        return torch.backends.cuda.sdp_kernel(
+            enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
+        )
+
+    backends = []
+    if enable_flash:
+        backends += [torch.nn.attention.SDPBackend.FLASH_ATTENTION]
+    if enable_math:
+        backends += [torch.nn.attention.SDPBackend.MATH]
+    if enable_mem_efficient:
+        backends += [torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]
+    return torch.nn.attention.sdpa_kernel(backends)
+
+
 @require_torch
 class ModelTesterMixin:
     model_tester = None
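This helper keeps the existing keyword-style call sites working across PyTorch versions: torch.nn.attention.sdpa_kernel was added in PyTorch 2.3, hence the version gate that falls back to the deprecated CUDA-specific API on older releases. A usage sketch mirroring the call sites updated in this commit (model and inputs_dict are placeholders):

# Force eager/MATH attention while running the model, as the tests below do.
with torch.no_grad():
    with sdpa_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        outputs = model(**inputs_dict)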
@@ -4175,7 +4191,7 @@ class ModelTesterMixin:

                 # TODO: test gradients as well (& for FA2 as well!)
                 with torch.no_grad():
-                    with torch.backends.cuda.sdp_kernel(
+                    with sdpa_kernel(
                         enable_flash=enable_kernels,
                         enable_math=True,
                         enable_mem_efficient=enable_kernels,
@@ -4198,6 +4214,12 @@ class ModelTesterMixin:
                 if torch_device in ["cpu", "cuda"]:
                     atol = atols[torch_device, enable_kernels, torch_dtype]
                     rtol = rtols[torch_device, enable_kernels, torch_dtype]
+                elif torch_device == "xpu":
+                    # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
+                    # which is implemented on PyTorch level using aten operators and is
+                    # device agnostic with respect to implementation of each aten operator.
+                    atol = atols["cuda", False, torch_dtype]
+                    rtol = rtols["cuda", False, torch_dtype]
                 else:
                     atol = 1e-7
                     rtol = 1e-4