Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 (#27940)

fix sdpa dispatch
fxmarty 2023-12-11 10:56:38 +01:00 committed by GitHub
parent 7ea21f1f03
commit 9f18cc6df0
2 changed files with 10 additions and 8 deletions
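For context, the dispatch touched by this commit is the logic that resolves the `attn_implementation` setting when a model is instantiated. A minimal usage sketch (the checkpoint name is a placeholder, not taken from this commit): with the change below, an explicit request is respected and hard-checked, while the automatic fallback to "eager" only applies when nothing was requested.

from transformers import AutoModelForCausalLM

# Explicitly request PyTorch's scaled_dot_product_attention backend; the explicit
# request is hard-checked (it errors if SDPA is not actually usable) rather than
# being silently re-dispatched.
model = AutoModelForCausalLM.from_pretrained(
    "org/some-sdpa-supporting-model",  # placeholder checkpoint
    attn_implementation="sdpa",
)
print(model.config._attn_implementation)  # "sdpa"

# An explicit "eager" request lands in the `else` branch of the dispatch below
# and does not go through the SDPA availability check.
model = AutoModelForCausalLM.from_pretrained(
    "org/some-sdpa-supporting-model",  # placeholder checkpoint
    attn_implementation="eager",
)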


@@ -1244,6 +1244,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+        requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
             if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
                 raise ValueError(
@@ -1260,9 +1261,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 raise ValueError(message + ".")

             # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
-            hard_check_only = True
-        else:
-            hard_check_only = False
+            requested_attn_implementation = config._attn_implementation_internal

         if use_flash_attention_2:
             logger.warning_once(
@@ -1275,13 +1274,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 config,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
-                hard_check_only=hard_check_only,
+                hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif cls._supports_sdpa or config._attn_implementation == "sdpa":
+        elif requested_attn_implementation in [None, "sdpa"]:
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
-        elif not hard_check_only:
+            config = cls._check_and_enable_sdpa(
+                config, hard_check_only=False if requested_attn_implementation is None else True
+            )
+        else:
             config._attn_implementation = "eager"

         return config
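For readability, the SDPA/eager part of the dispatch above boils down to roughly the following. This is a hedged, simplified re-statement with stand-in names (`resolve_sdpa_or_eager` and `sdpa_usable` are not real transformers APIs); the flash-attention-2 branch is elided.

from typing import Optional

def resolve_sdpa_or_eager(requested: Optional[str], sdpa_usable: bool) -> str:
    # `requested` is None unless the user explicitly set an attention implementation.
    if requested in [None, "sdpa"]:
        if sdpa_usable:
            return "sdpa"
        if requested == "sdpa":
            # Explicit request -> hard check: fail loudly instead of falling back.
            raise ValueError("SDPA was requested but is not available")
        return "eager"  # automatic dispatch quietly falls back
    return "eager"      # any other explicit request (e.g. "eager") is left untouched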


@@ -83,6 +83,7 @@ from transformers.utils import (
     is_flax_available,
     is_tf_available,
     is_torch_fx_available,
+    is_torch_sdpa_available,
 )

 from transformers.utils.generic import ModelOutput
@@ -778,7 +779,7 @@ class ModelTesterMixin:
         configs_no_init.torchscript = True
         for model_class in self.all_model_classes:
             for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and not model_class._supports_sdpa:
+                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
                     continue

                 configs_no_init._attn_implementation = attn_implementation
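The test-side guard relies on `is_torch_sdpa_available()` returning False on torch builds that are too old for the SDPA integration (per the commit title, torch < 2.1.1). A minimal sketch of the same pattern as a standalone pytest-style test (the test name and body are hypothetical; the real check lives in ModelTesterMixin above):

import pytest

from transformers.utils import is_torch_sdpa_available


@pytest.mark.parametrize("attn_implementation", ["eager", "sdpa"])
def test_torchscript_attn_implementation(attn_implementation):
    # Mirror the `continue` above: skip the SDPA variant when the installed torch
    # cannot provide scaled_dot_product_attention support (in the real mixin, the
    # same guard also covers model classes that do not support SDPA).
    if attn_implementation == "sdpa" and not is_torch_sdpa_available():
        pytest.skip("SDPA not available in this torch build")
    ...  # build the model with config._attn_implementation = attn_implementation and trace it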