Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 (#27940)
fix sdpa dispatch
This commit is contained in:
parent 7ea21f1f03
commit 9f18cc6df0
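
For orientation, a minimal sketch of the two code paths this commit touches, seen from the user side. The checkpoint name is a placeholder; attn_implementation is the from_pretrained argument whose value ends up in config._attn_implementation_internal and drives the dispatch in the hunks below:

    import torch
    from transformers import AutoModelForCausalLM

    # Explicit request: after this fix, loading must hard-check that SDPA is
    # actually usable (recent enough torch, model supports it) and raise
    # otherwise, instead of silently dispatching to something else.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder checkpoint
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
    )

    # No explicit request: SDPA is auto-selected when available, with a soft
    # fallback to "eager" otherwise.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
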
@@ -1244,6 +1244,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+        requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
             if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
                 raise ValueError(
@@ -1260,9 +1261,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 raise ValueError(message + ".")

             # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
-            hard_check_only = True
-        else:
-            hard_check_only = False
+            requested_attn_implementation = config._attn_implementation_internal

         if use_flash_attention_2:
             logger.warning_once(
@@ -1275,13 +1274,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 config,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
-                hard_check_only=hard_check_only,
+                hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif cls._supports_sdpa or config._attn_implementation == "sdpa":
+        elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
-        elif not hard_check_only:
+            config = cls._check_and_enable_sdpa(
+                config, hard_check_only=False if requested_attn_implementation is None else True
+            )
+        else:
             config._attn_implementation = "eager"

         return config
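
Read as plain control flow, the new dispatch in the hunks above reduces to the sketch below. This is an illustrative restatement, not the library code: requested stands for config._attn_implementation_internal (None when the user set nothing), and check_flash / check_sdpa are hypothetical stand-ins for the class-level checkers, which raise on hard failures and may fall back otherwise.

    def resolve_attn_implementation(requested, use_flash_attention_2, check_flash, check_sdpa):
        # Branch order mirrors the diff above (illustrative only).
        if use_flash_attention_2:
            # The Flash Attention 2 checker now always runs with hard_check_only=False.
            return check_flash(hard_check_only=False)
        if requested in [None, "sdpa"]:
            # Explicit "sdpa" request: hard check, raise if SDPA is unavailable.
            # Nothing requested: soft check, fall back to "eager" if unavailable.
            return check_sdpa(hard_check_only=requested is not None)
        # Any other explicit value ("eager", ...) resolves to eager attention.
        return "eager"
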
@@ -83,6 +83,7 @@ from transformers.utils import (
     is_flax_available,
     is_tf_available,
     is_torch_fx_available,
+    is_torch_sdpa_available,
 )
 from transformers.utils.generic import ModelOutput

@@ -778,7 +779,7 @@ class ModelTesterMixin:
         configs_no_init.torchscript = True
         for model_class in self.all_model_classes:
             for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and not model_class._supports_sdpa:
+                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                     continue

                 configs_no_init._attn_implementation = attn_implementation
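
The test change is the torch<2.1.1 half of the commit title: the "sdpa" variant is now skipped not only when the model class lacks SDPA support but also when the installed torch is too old for it. A standalone sketch of the same guard follows; is_torch_sdpa_available is the helper imported in the hunk above, while the test class itself is hypothetical.

    import unittest

    from transformers.utils import is_torch_sdpa_available


    class SdpaGuardExample(unittest.TestCase):
        # Hypothetical test illustrating the guard added in the hunk above.
        def test_sdpa_variant(self):
            if not is_torch_sdpa_available():
                # On torch versions too old for Transformers' SDPA integration
                # (the commit title pins this at torch<2.1.1), skip rather than
                # fail at model construction time.
                self.skipTest("SDPA not available in this torch version")
            ...  # build the model with config._attn_implementation = "sdpa" and exercise it
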