diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9639ff7ce06..c3c4c75750c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2157,8 +2157,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         if isinstance(torch_dtype, str):
             torch_dtype = getattr(torch, torch_dtype)
 
-        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
-
         # override default dtype if needed
         dtype_orig = None
         if torch_dtype is not None:
@@ -2177,7 +2175,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         if not getattr(config, "_attn_implementation_autoset", False):
             config = cls._autoset_attn_implementation(
                 config,
-                use_flash_attention_2=use_flash_attention_2,
                 check_device_map=False,
                 torch_dtype=torch_dtype,
             )
@@ -2205,7 +2202,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -2213,21 +2209,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         """
         Automatically checks and dispatches to a default attention implementation. In order of priority:
             1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
-            2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
-            3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
-            4. The default model's implementation otherwise (`LlamaAttention` for example) .
+            2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
+            3. The default model's implementation otherwise (`LlamaAttention` for example) .
         """
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
         requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
-            if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
-                raise ValueError(
-                    f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
-                    ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
-                )
-
             if isinstance(config._attn_implementation, str) and re.match(
                 r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
             ):
@@ -2292,12 +2281,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             if sub_config is not None:
                 sub_config._attn_implementation_internal = curr_attn_implementation
 
-        if use_flash_attention_2:
-            logger.warning_once(
-                'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
-            )
-            config._attn_implementation = "flash_attention_2"
-
         if config._attn_implementation == "flash_attention_2":
             cls._check_and_enable_flash_attn_2(
                 config,
@@ -2309,10 +2292,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         elif requested_attn_implementation == "flex_attention":
             config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
         elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
-            # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
+            # flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
             config = cls._check_and_enable_sdpa(
                 config,
-                hard_check_only=False if requested_attn_implementation is None else True,
+                hard_check_only=requested_attn_implementation is not None,
             )
 
         if (
@@ -4256,7 +4239,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         variant = kwargs.pop("variant", None)
         adapter_kwargs = kwargs.pop("adapter_kwargs", {})
         adapter_name = kwargs.pop("adapter_name", "default")
-        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
         generation_config = kwargs.pop("generation_config", None)
         gguf_file = kwargs.pop("gguf_file", None)
         tp_plan = kwargs.pop("tp_plan", None)
@@ -4618,7 +4600,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             if not getattr(config, "_attn_implementation_autoset", False):
                 config = cls._autoset_attn_implementation(
                     config,
-                    use_flash_attention_2=use_flash_attention_2,
                     torch_dtype=torch_dtype,
                     device_map=device_map,
                 )
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index c0e99097152..0b254bd73a2 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -615,7 +615,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -638,8 +637,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 config._attn_implementation_internal = None
         return super()._autoset_attn_implementation(
             config,
-            use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch.float16,
+            torch_dtype=torch_dtype,
             device_map=device_map,
             check_device_map=check_device_map,
         )
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index 137673cfa59..18f2bb8beb7 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -817,7 +817,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -840,8 +839,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 config._attn_implementation_internal = None
         return super()._autoset_attn_implementation(
             config,
-            use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch.float16,
+            torch_dtype=torch_dtype,
             device_map=device_map,
             check_device_map=check_device_map,
         )
diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py
index 50525a3ec4e..0b3ef078905 100644
--- a/tests/models/diffllama/test_modeling_diffllama.py
+++ b/tests/models/diffllama/test_modeling_diffllama.py
@@ -488,7 +488,7 @@ class DiffLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
             model.save_pretrained(tmp_dir)
 
             new_model = DiffLlamaForCausalLM.from_pretrained(
-                tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
+                tmp_dir, attn_implementation="flash_attention_2", torch_dtype=torch.float16
             ).to("cuda")
 
             self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
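
For downstream code, the migration mirrors the test change above: drop the removed `use_flash_attention_2=True` kwarg and pass `attn_implementation="flash_attention_2"` instead. A minimal sketch, assuming the `flash-attn` package and a supported CUDA GPU are available; the checkpoint id below is a placeholder, not one used in this patch:

```python
import torch
from transformers import AutoModelForCausalLM

# Before this patch (kwarg now removed):
#   model = AutoModelForCausalLM.from_pretrained(
#       "some/causal-lm-checkpoint", use_flash_attention_2=True, torch_dtype=torch.float16
#   )

# After: request the attention backend explicitly via `attn_implementation`.
model = AutoModelForCausalLM.from_pretrained(
    "some/causal-lm-checkpoint",  # placeholder checkpoint id
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")

# The resolved backend is recorded on the config, as the DiffLlama test asserts.
assert model.config._attn_implementation == "flash_attention_2"
```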