Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 (#27940)
fix sdpa dispatch
commit 9f18cc6df0
parent 7ea21f1f03
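For context, the dispatch being fixed here decides which attention backend a model ends up with at load time. A minimal usage sketch (the checkpoint name is a placeholder; it assumes an architecture with SDPA support and a recent enough torch):

import torch
from transformers import AutoModelForCausalLM

# Explicitly request the PyTorch SDPA backend ("eager" and "flash_attention_2"
# are the other accepted values for attn_implementation).
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-sdpa-capable-model",  # placeholder checkpoint name
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)

# Leaving attn_implementation unset lets the dispatch patched below choose
# automatically: SDPA when it is supported and available, otherwise eager.
model_auto = AutoModelForCausalLM.from_pretrained(
    "some-org/some-sdpa-capable-model",
    torch_dtype=torch.float16,
)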
@@ -1244,6 +1244,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+        requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
             if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
                 raise ValueError(
@@ -1260,9 +1261,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 raise ValueError(message + ".")
 
             # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
-            hard_check_only = True
-        else:
-            hard_check_only = False
+            requested_attn_implementation = config._attn_implementation_internal
 
         if use_flash_attention_2:
             logger.warning_once(
@@ -1275,13 +1274,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 config,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
-                hard_check_only=hard_check_only,
+                hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif cls._supports_sdpa or config._attn_implementation == "sdpa":
+        elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
-        elif not hard_check_only:
+            config = cls._check_and_enable_sdpa(
+                config, hard_check_only=False if requested_attn_implementation is None else True
+            )
+        else:
             config._attn_implementation = "eager"
 
         return config
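Taken together, the patched branch above keys the decision off requested_attn_implementation instead of the old hard_check_only flag. A rough standalone sketch of the resulting precedence, with supports_sdpa and sdpa_available standing in for cls._supports_sdpa and the torch version check (an approximation, not the exact library code):

def pick_attn_implementation(requested, use_flash_attention_2, supports_sdpa, sdpa_available):
    # requested is None when the user never set config._attn_implementation.
    if use_flash_attention_2:
        # Flash Attention 2 takes priority; the real code hard-checks its requirements.
        return "flash_attention_2"
    if requested in (None, "sdpa"):
        if supports_sdpa and sdpa_available:
            return "sdpa"
        if requested == "sdpa":
            # Explicit request: fail loudly instead of silently degrading.
            raise ValueError("SDPA requested but not usable for this model/torch version.")
        # Automatic dispatch: quietly fall back.
        return "eager"
    # Anything else (an explicit "eager" request) lands on the eager path.
    return "eager"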
|
@@ -83,6 +83,7 @@ from transformers.utils import (
     is_flax_available,
     is_tf_available,
     is_torch_fx_available,
+    is_torch_sdpa_available,
 )
 from transformers.utils.generic import ModelOutput
 
@@ -778,7 +779,7 @@ class ModelTesterMixin:
         configs_no_init.torchscript = True
         for model_class in self.all_model_classes:
             for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and not model_class._supports_sdpa:
+                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                     continue
 
                 configs_no_init._attn_implementation = attn_implementation
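The test change above only skips the SDPA variant when the installed torch is too old for Transformers' SDPA integration. A sketch of that kind of version gate and how a test might use it (sdpa_usable is a hypothetical stand-in for transformers.utils.is_torch_sdpa_available; the 2.1.1 floor comes from the PR title):

import unittest

import torch
from packaging import version


def sdpa_usable() -> bool:
    # Hypothetical stand-in for is_torch_sdpa_available(): Transformers' SDPA
    # path needs a sufficiently new torch; the PR title points at 2.1.1.
    return version.parse(torch.__version__) >= version.parse("2.1.1")


class SdpaSmokeTest(unittest.TestCase):
    @unittest.skipUnless(sdpa_usable(), "SDPA requires torch >= 2.1.1")
    def test_sdpa_attention_selected(self):
        # On a compatible torch, requesting "sdpa" should be accepted by models
        # that declare _supports_sdpa; older torch versions skip this test.
        self.assertTrue(sdpa_usable())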
|