mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Fix masking utils (#38783)
* fix
* Update masking_utils.py
* Update masking_utils.py
This commit is contained in:
parent
7c58336949
commit
887054c714
@@ -684,8 +684,8 @@ def create_causal_mask(
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
"""
|
||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
||||
if past_key_values is not None and hasattr(past_key_values, "is_sliding"):
|
||||
layer_idx = past_key_values.is_sliding.index(False) if False in past_key_values.is_sliding else 0
|
||||
if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
|
||||
layer_idx = past_key_values.is_sliding.index(False)
|
||||
else:
|
||||
layer_idx = 0
|
||||
|
||||
@@ -765,9 +765,9 @@ def create_sliding_window_causal_mask(
|
||||
An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is
|
||||
useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
|
||||
"""
|
||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
||||
if past_key_values is not None and hasattr(past_key_values, "is_sliding"):
|
||||
layer_idx = past_key_values.is_sliding.index(False) if False in past_key_values.is_sliding else 0
|
||||
# If we have an HybridCache structure, here we want to create the mask for the sliding layers
|
||||
if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
|
||||
layer_idx = past_key_values.is_sliding.index(True)
|
||||
else:
|
||||
layer_idx = 0
|
||||
|
||||
@@ -852,9 +852,9 @@ def create_chunked_causal_mask(
|
||||
An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
|
||||
useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
|
||||
"""
|
||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
||||
if past_key_values is not None and hasattr(past_key_values, "is_sliding"):
|
||||
layer_idx = past_key_values.is_sliding.index(False) if False in past_key_values.is_sliding else 0
|
||||
# If we have an HybridCache structure, here we want to create the mask for the sliding layers
|
||||
if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
|
||||
layer_idx = past_key_values.is_sliding.index(True)
|
||||
else:
|
||||
layer_idx = 0
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user