Fix all import errors based on older torch versions (#38370)

* Update masking_utils.py

* fix

* fix

* fix

* Update masking_utils.py

* Update executorch.py

* fix
Cyril Vallez 2025-05-26 12:11:54 +02:00 committed by GitHub
parent d03a3ca692
commit 9f0402bc4d
2 changed files with 20 additions and 15 deletions

src/transformers/masking_utils.py

@@ -25,11 +25,16 @@ from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal
 if is_torch_flex_attn_available():
-    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
     from torch.nn.attention.flex_attention import BlockMask, create_block_mask
 else:
     # Register a fake type to avoid crashing for annotations and `isinstance` checks
     BlockMask = torch.Tensor

 _is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
+_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
+
+if _is_torch_greater_or_equal_than_2_6:
+    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex


 def and_masks(*mask_functions: list[Callable]) -> Callable:
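The gating above is the usual "only import what the installed torch can provide" pattern. A minimal standalone sketch of that pattern, using packaging.version directly instead of the library's is_torch_greater_or_equal helper (the fallback name below is illustrative, not part of the commit):

    import torch
    from packaging import version  # assumed available; transformers itself depends on it

    # Compare against the base version so dev/nightly builds like "2.6.0.dev20241112" still pass.
    _torch_at_least_2_6 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.6")

    if _torch_at_least_2_6:
        from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
    else:
        # Placeholder so the name exists; anything that truly needs it fails at use time, not at import time.
        TransformGetItemToIndex = None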
@@ -415,14 +420,14 @@ def sdpa_mask_older_torch(
     # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
     # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
-    if allow_torch_fix:
+    if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
         causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
     return causal_mask


 # We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
 # (especially mask_function indexing a tensor, such as the padding mask function)
-sdpa_mask = sdpa_mask_recent_torch if is_torch_flex_attn_available() else sdpa_mask_older_torch
+sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch


 def eager_mask(
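The allow_torch_fix branch exists because, on torch < 2.5, a query row that attends to nothing (all False after padding) makes scaled_dot_product_attention return NaNs; flipping such rows to all True sidesteps that. A toy illustration of the single line involved (values invented for the example):

    import torch

    # Row 0 is a normal causal row; row 1 attends to nothing at all (e.g. a fully padded query).
    causal_mask = torch.tensor([[True, False, False],
                                [False, False, False]])

    # The line from the diff: rows that are entirely False get unmasked entirely,
    # which prevents NaN outputs from SDPA on torch < 2.5 (pytorch#110213).
    causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)

    print(causal_mask)
    # tensor([[ True, False, False],
    #         [ True,  True,  True]])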
@@ -522,7 +527,7 @@ def flex_attention_mask(
     mask_function: Callable = causal_mask_function,
     attention_mask: Optional[torch.Tensor] = None,
     **kwargs,
-) -> "BlockMask":
+) -> BlockMask:
     """
     Create a 4D block mask which is a compressed representation of the full 4D block causal mask. BlockMask is essential
     for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/
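For readers unfamiliar with BlockMask: it is torch's compressed, block-level mask for flex attention, built from a mask function. A small hedged sketch of constructing one directly (requires a torch version on the >= 2.6 path gated above; shapes and device are arbitrary):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    def causal(batch_idx, head_idx, q_idx, kv_idx):
        # True means "this query position may attend to this key position".
        return q_idx >= kv_idx

    # Block-level compressed representation of a 128x128 causal mask,
    # broadcast over batch and heads (B=None, H=None).
    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cpu")
    print(block_mask)  # prints the block sparsity layout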
@@ -652,7 +657,7 @@ def create_causal_mask(
     past_key_values: Optional[Cache],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
-) -> Optional[Union[torch.Tensor, "BlockMask"]]:
+) -> Optional[Union[torch.Tensor, BlockMask]]:
     """
     Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
     has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
@@ -700,12 +705,12 @@ def create_causal_mask(

     # Allow slight deviations from causal mask
     if or_mask_function is not None:
-        if not _is_torch_greater_or_equal_than_2_5:
+        if not _is_torch_greater_or_equal_than_2_6:
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5")
         mask_factory_function = or_masks(mask_factory_function, or_mask_function)
         allow_is_causal_skip = False
     if and_mask_function is not None:
-        if not _is_torch_greater_or_equal_than_2_5:
+        if not _is_torch_greater_or_equal_than_2_6:
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5")
         mask_factory_function = and_masks(mask_factory_function, and_mask_function)
         allow_is_causal_skip = False
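or_masks and and_masks compose mask functions with the signature (batch_idx, head_idx, q_idx, kv_idx) -> bool. The sketch below is a plausible re-implementation of and_masks together with the "padding mask function" case mentioned above sdpa_mask; it is illustrative, not code copied from the library:

    import torch

    def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
        # Plain causal rule: a query may look at keys up to its own position.
        return kv_idx <= q_idx

    def padding_mask_function(padding_mask):
        # The "mask_function indexing a tensor" case: `padding_mask` is a
        # (batch, kv_len) bool tensor marking non-padded positions.
        def inner(batch_idx, head_idx, q_idx, kv_idx):
            return padding_mask[batch_idx, kv_idx]
        return inner

    def and_masks(*mask_functions):
        # Keep a position only if every supplied mask function keeps it.
        def combined(batch_idx, head_idx, q_idx, kv_idx):
            result = mask_functions[0](batch_idx, head_idx, q_idx, kv_idx)
            for fn in mask_functions[1:]:
                result = result & fn(batch_idx, head_idx, q_idx, kv_idx)
            return result
        return combined

    padding = torch.tensor([[True, True, False]])  # last key position is padding
    combined = and_masks(causal_mask_function, padding_mask_function(padding))
    print(combined(0, None, 2, 1))  # tensor(True):  causal and not padded
    print(combined(0, None, 2, 2))  # tensor(False): causal but padded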
@@ -733,7 +738,7 @@ def create_sliding_window_causal_mask(
     past_key_values: Optional[Cache],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
-) -> Optional[Union[torch.Tensor, "BlockMask"]]:
+) -> Optional[Union[torch.Tensor, BlockMask]]:
     """
     Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
     of attention pattern was mostly democratized by Mistral. If `past_key_values` has an HybridCache structure, this
@@ -786,12 +791,12 @@ def create_sliding_window_causal_mask(

     # Allow slight deviations from sliding causal mask
     if or_mask_function is not None:
-        if not _is_torch_greater_or_equal_than_2_5:
+        if not _is_torch_greater_or_equal_than_2_6:
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5")
         mask_factory_function = or_masks(mask_factory_function, or_mask_function)
         allow_is_causal_skip = False
     if and_mask_function is not None:
-        if not _is_torch_greater_or_equal_than_2_5:
+        if not _is_torch_greater_or_equal_than_2_6:
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5")
         mask_factory_function = and_masks(mask_factory_function, and_mask_function)
         allow_is_causal_skip = False
@@ -820,7 +825,7 @@ def create_chunked_causal_mask(
     past_key_values: Optional[Cache],
     or_mask_function: Optional[Callable] = None,
     and_mask_function: Optional[Callable] = None,
-) -> Optional[Union[torch.Tensor, "BlockMask"]]:
+) -> Optional[Union[torch.Tensor, BlockMask]]:
     """
     Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
     of attention pattern was mostly democratized by Llama4. If `past_key_values` has an HybridCache structure, this
@@ -880,12 +885,12 @@ def create_chunked_causal_mask(

     # Allow slight deviations from chunked causal mask
     if or_mask_function is not None:
-        if not _is_torch_greater_or_equal_than_2_5:
+        if not _is_torch_greater_or_equal_than_2_6:
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5")
         mask_factory_function = or_masks(mask_factory_function, or_mask_function)
         allow_is_causal_skip = False
     if and_mask_function is not None:
-        if not _is_torch_greater_or_equal_than_2_5:
+        if not _is_torch_greater_or_equal_than_2_6:
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.5")
         mask_factory_function = and_masks(mask_factory_function, and_mask_function)
         allow_is_causal_skip = False

src/transformers/modeling_utils.py

@@ -2078,7 +2078,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
             if plan := getattr(module, "_tp_plan", None):
                 self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})

-        if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
+        if self._tp_plan is not None and is_torch_greater_or_equal("2.5"):
             for _, v in self._tp_plan.items():
                 if v not in ALL_PARALLEL_STYLES:
                     raise ValueError(
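For reference, the loop above checks every value of the merged tensor-parallel plan against ALL_PARALLEL_STYLES. A hypothetical plan of the shape being validated (module patterns invented for illustration; "colwise"/"rowwise" are among the accepted styles):

    # Hypothetical merged _tp_plan; every value must be one of ALL_PARALLEL_STYLES.
    tp_plan = {
        "model.layers.*.self_attn.q_proj": "colwise",
        "model.layers.*.self_attn.o_proj": "rowwise",
        "model.layers.*.mlp.gate_proj": "colwise",
        "model.layers.*.mlp.down_proj": "rowwise",
    }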