Fix masking utils (#38783)

* fix

* Update masking_utils.py

* Update masking_utils.py
This commit is contained in:
Cyril Vallez 2025-06-12 11:00:46 +02:00 committed by GitHub
parent 7c58336949
commit 887054c714
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -684,8 +684,8 @@ def create_causal_mask(
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the full layers
- if past_key_values is not None and hasattr(past_key_values, "is_sliding"):
- layer_idx = past_key_values.is_sliding.index(False) if False in past_key_values.is_sliding else 0
+ if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
+ layer_idx = past_key_values.is_sliding.index(False)
else:
layer_idx = 0
@@ -765,9 +765,9 @@ def create_sliding_window_causal_mask(
An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is
useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
"""
- # If we have an HybridCache structure, here we want to create the mask for the full layers
- if past_key_values is not None and hasattr(past_key_values, "is_sliding"):
- layer_idx = past_key_values.is_sliding.index(False) if False in past_key_values.is_sliding else 0
+ # If we have an HybridCache structure, here we want to create the mask for the sliding layers
+ if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
+ layer_idx = past_key_values.is_sliding.index(True)
else:
layer_idx = 0
@@ -852,9 +852,9 @@ def create_chunked_causal_mask(
An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
"""
- # If we have an HybridCache structure, here we want to create the mask for the full layers
- if past_key_values is not None and hasattr(past_key_values, "is_sliding"):
- layer_idx = past_key_values.is_sliding.index(False) if False in past_key_values.is_sliding else 0
+ # If we have an HybridCache structure, here we want to create the mask for the sliding layers
+ if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
+ layer_idx = past_key_values.is_sliding.index(True)
else:
layer_idx = 0