Remove deprecated use_flash_attention_2 parameter (#37131)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
Yuanyuan Chen 2025-06-02 17:06:25 +08:00 committed by GitHub
parent 51d732709e
commit fde1120b6c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 7 additions and 30 deletions

View File

@ -2157,8 +2157,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if isinstance(torch_dtype, str): if isinstance(torch_dtype, str):
torch_dtype = getattr(torch, torch_dtype) torch_dtype = getattr(torch, torch_dtype)
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
# override default dtype if needed # override default dtype if needed
dtype_orig = None dtype_orig = None
if torch_dtype is not None: if torch_dtype is not None:
@ -2177,7 +2175,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if not getattr(config, "_attn_implementation_autoset", False): if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation( config = cls._autoset_attn_implementation(
config, config,
use_flash_attention_2=use_flash_attention_2,
check_device_map=False, check_device_map=False,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
) )
@ -2205,7 +2202,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
def _autoset_attn_implementation( def _autoset_attn_implementation(
cls, cls,
config, config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None, torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None, device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True, check_device_map: bool = True,
@ -2213,21 +2209,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
""" """
Automatically checks and dispatches to a default attention implementation. In order of priority: Automatically checks and dispatches to a default attention implementation. In order of priority:
1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) 2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) 3. The default model's implementation otherwise (`LlamaAttention` for example) .
4. The default model's implementation otherwise (`LlamaAttention` for example) .
""" """
# Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user. # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
# The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
# The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
requested_attn_implementation = None requested_attn_implementation = None
if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
raise ValueError(
f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
if isinstance(config._attn_implementation, str) and re.match( if isinstance(config._attn_implementation, str) and re.match(
r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
): ):
@ -2292,12 +2281,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if sub_config is not None: if sub_config is not None:
sub_config._attn_implementation_internal = curr_attn_implementation sub_config._attn_implementation_internal = curr_attn_implementation
if use_flash_attention_2:
logger.warning_once(
'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
)
config._attn_implementation = "flash_attention_2"
if config._attn_implementation == "flash_attention_2": if config._attn_implementation == "flash_attention_2":
cls._check_and_enable_flash_attn_2( cls._check_and_enable_flash_attn_2(
config, config,
@ -2309,10 +2292,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
elif requested_attn_implementation == "flex_attention": elif requested_attn_implementation == "flex_attention":
config = cls._check_and_enable_flex_attn(config, hard_check_only=True) config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. # flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa( config = cls._check_and_enable_sdpa(
config, config,
hard_check_only=False if requested_attn_implementation is None else True, hard_check_only=requested_attn_implementation is not None,
) )
if ( if (
@ -4256,7 +4239,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
variant = kwargs.pop("variant", None) variant = kwargs.pop("variant", None)
adapter_kwargs = kwargs.pop("adapter_kwargs", {}) adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default") adapter_name = kwargs.pop("adapter_name", "default")
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
generation_config = kwargs.pop("generation_config", None) generation_config = kwargs.pop("generation_config", None)
gguf_file = kwargs.pop("gguf_file", None) gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None) tp_plan = kwargs.pop("tp_plan", None)
@ -4618,7 +4600,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if not getattr(config, "_attn_implementation_autoset", False): if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation( config = cls._autoset_attn_implementation(
config, config,
use_flash_attention_2=use_flash_attention_2,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=device_map, device_map=device_map,
) )

View File

@ -615,7 +615,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
def _autoset_attn_implementation( def _autoset_attn_implementation(
cls, cls,
config, config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None, torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None, device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True, check_device_map: bool = True,
@ -638,8 +637,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
config._attn_implementation_internal = None config._attn_implementation_internal = None
return super()._autoset_attn_implementation( return super()._autoset_attn_implementation(
config, config,
use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype,
torch_dtype=torch.float16,
device_map=device_map, device_map=device_map,
check_device_map=check_device_map, check_device_map=check_device_map,
) )

View File

@ -817,7 +817,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
def _autoset_attn_implementation( def _autoset_attn_implementation(
cls, cls,
config, config,
use_flash_attention_2: bool = False,
torch_dtype: Optional[torch.dtype] = None, torch_dtype: Optional[torch.dtype] = None,
device_map: Optional[Union[str, Dict[str, int]]] = None, device_map: Optional[Union[str, Dict[str, int]]] = None,
check_device_map: bool = True, check_device_map: bool = True,
@ -840,8 +839,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
config._attn_implementation_internal = None config._attn_implementation_internal = None
return super()._autoset_attn_implementation( return super()._autoset_attn_implementation(
config, config,
use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype,
torch_dtype=torch.float16,
device_map=device_map, device_map=device_map,
check_device_map=check_device_map, check_device_map=check_device_map,
) )

View File

@ -488,7 +488,7 @@ class DiffLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
model.save_pretrained(tmp_dir) model.save_pretrained(tmp_dir)
new_model = DiffLlamaForCausalLM.from_pretrained( new_model = DiffLlamaForCausalLM.from_pretrained(
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16 tmp_dir, attn_implementation="flash_attention_2", torch_dtype=torch.float16
).to("cuda") ).to("cuda")
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2") self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")