Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 21:00:08 +06:00
Remove deprecated use_flash_attention_2 parameter (#37131)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
parent 51d732709e
commit fde1120b6c
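In practice, callers that previously passed use_flash_attention_2=True to from_pretrained should now request the backend explicitly with attn_implementation="flash_attention_2", as the removed deprecation warning already recommended. A minimal migration sketch; the checkpoint name and dtype are illustrative placeholders, and flash attention still requires the flash-attn package and a supported GPU:

import torch
from transformers import AutoModelForCausalLM

# Before this commit (deprecated, now removed):
# model = AutoModelForCausalLM.from_pretrained(
#     "meta-llama/Llama-2-7b-hf", use_flash_attention_2=True, torch_dtype=torch.float16
# )

# After: select the attention backend explicitly.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint, any supported causal LM works
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
assert model.config._attn_implementation == "flash_attention_2"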
@@ -2157,8 +2157,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         if isinstance(torch_dtype, str):
             torch_dtype = getattr(torch, torch_dtype)
 
-        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
-
         # override default dtype if needed
         dtype_orig = None
         if torch_dtype is not None:
@@ -2177,7 +2175,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         if not getattr(config, "_attn_implementation_autoset", False):
             config = cls._autoset_attn_implementation(
                 config,
-                use_flash_attention_2=use_flash_attention_2,
                 check_device_map=False,
                 torch_dtype=torch_dtype,
             )
@@ -2205,7 +2202,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -2213,21 +2209,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         """
         Automatically checks and dispatches to a default attention implementation. In order of priority:
             1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
-            2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
-            3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
-            4. The default model's implementation otherwise (`LlamaAttention` for example) .
+            2. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
+            3. The default model's implementation otherwise (`LlamaAttention` for example) .
         """
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
         requested_attn_implementation = None
         if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
-            if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
-                raise ValueError(
-                    f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.'
-                    ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
-                )
-
             if isinstance(config._attn_implementation, str) and re.match(
                 r"^[^/:]+/[^/:]+:[^/:]+$", config._attn_implementation
             ):
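The updated docstring collapses the old four-step priority into three: an explicit request on the config, then SDPA when the model type supports it, then the model's default eager attention. A minimal sketch of that dispatch order, using a hypothetical standalone helper rather than the real _autoset_attn_implementation:

from typing import Optional

def pick_attn_implementation(requested: Optional[str], supports_sdpa: bool) -> str:
    """Illustrative only: mirrors the priority described in the docstring above."""
    # 1. An implementation explicitly set on the config (e.g. attn_implementation="sdpa"
    #    passed to from_pretrained) always wins.
    if requested is not None:
        return requested
    # 2. Otherwise prefer SDPA when the model type supports it.
    if supports_sdpa:
        return "sdpa"
    # 3. Fall back to the model's default (eager) attention.
    return "eager"

assert pick_attn_implementation(None, supports_sdpa=True) == "sdpa"
assert pick_attn_implementation("flash_attention_2", supports_sdpa=True) == "flash_attention_2"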
@@ -2292,12 +2281,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             if sub_config is not None:
                 sub_config._attn_implementation_internal = curr_attn_implementation
 
-        if use_flash_attention_2:
-            logger.warning_once(
-                'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
-            )
-            config._attn_implementation = "flash_attention_2"
-
         if config._attn_implementation == "flash_attention_2":
             cls._check_and_enable_flash_attn_2(
                 config,
@@ -2309,10 +2292,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         elif requested_attn_implementation == "flex_attention":
             config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
         elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
-            # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
+            # flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
             config = cls._check_and_enable_sdpa(
                 config,
-                hard_check_only=False if requested_attn_implementation is None else True,
+                hard_check_only=requested_attn_implementation is not None,
             )
 
         if (
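Besides dropping the deprecated flag, this hunk simplifies hard_check_only from the ternary False if requested_attn_implementation is None else True to the plain boolean requested_attn_implementation is not None. A quick check that the two expressions agree for representative inputs:

# The old ternary and the new expression produce the same value in every case.
for requested in (None, "sdpa", "flash_attention_2", "eager"):
    old = False if requested is None else True
    new = requested is not None
    assert old == new, (requested, old, new)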
@@ -4256,7 +4239,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         variant = kwargs.pop("variant", None)
         adapter_kwargs = kwargs.pop("adapter_kwargs", {})
         adapter_name = kwargs.pop("adapter_name", "default")
-        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
         generation_config = kwargs.pop("generation_config", None)
         gguf_file = kwargs.pop("gguf_file", None)
         tp_plan = kwargs.pop("tp_plan", None)
@@ -4618,7 +4600,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         if not getattr(config, "_attn_implementation_autoset", False):
             config = cls._autoset_attn_implementation(
                 config,
-                use_flash_attention_2=use_flash_attention_2,
                 torch_dtype=torch_dtype,
                 device_map=device_map,
             )
@@ -615,7 +615,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -638,8 +637,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 config._attn_implementation_internal = None
         return super()._autoset_attn_implementation(
             config,
-            use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch.float16,
+            torch_dtype=torch_dtype,
             device_map=device_map,
             check_device_map=check_device_map,
         )
@@ -817,7 +817,6 @@ class ModernBertPreTrainedModel(PreTrainedModel):
     def _autoset_attn_implementation(
         cls,
         config,
-        use_flash_attention_2: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         check_device_map: bool = True,
@@ -840,8 +839,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
                 config._attn_implementation_internal = None
         return super()._autoset_attn_implementation(
             config,
-            use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch.float16,
+            torch_dtype=torch_dtype,
            device_map=device_map,
             check_device_map=check_device_map,
         )
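Note that in both copies of the ModernBERT override, the super() call now forwards the caller's torch_dtype instead of pinning torch.float16, in addition to dropping the removed keyword. A hedged usage sketch; the checkpoint name and the availability of a flash-attention or SDPA backend are assumptions, not part of this diff:

import torch
from transformers import AutoModel

# Whatever dtype the caller passes now reaches the attention dispatch unchanged.
model = AutoModel.from_pretrained(
    "answerdotai/ModernBERT-base",  # placeholder ModernBERT checkpoint
    torch_dtype=torch.bfloat16,
)
print(model.config._attn_implementation)  # e.g. "flash_attention_2" or "sdpa"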
@@ -488,7 +488,7 @@ class DiffLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
             model.save_pretrained(tmp_dir)
 
             new_model = DiffLlamaForCausalLM.from_pretrained(
-                tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
+                tmp_dir, attn_implementation="flash_attention_2", torch_dtype=torch.float16
             ).to("cuda")
 
             self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")