Mirror of https://github.com/huggingface/transformers.git
Fix aggressive boolean conversion breaking packing implementations
This commit is contained in:
parent f8b88866f5
commit 675df92a58
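For context, a minimal sketch (not part of this diff) of the failure the commit title refers to: packing implementations may store non-binary values in the 2D attention_mask, and an unconditional cast to bool throws that information away. The values below are chosen only for illustration.

import torch

# Hypothetical packed-sequence mask: non-binary values carry packing information,
# 0.0 marks padding (illustrative values only).
packed_mask = torch.tensor([[1.5, 0.8, 1.2, 0.0]])

# The previous preprocessing cast every 2D mask to bool, collapsing every non-zero
# value to True and discarding the packing information:
print(packed_mask.to(torch.bool))  # tensor([[ True,  True,  True, False]])

# After this commit, only masks whose values are all 0/1 are cast to bool;
# any other mask keeps its original dtype and values.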
@@ -203,6 +203,14 @@ def _ignore_causal_mask_sdpa(
         mask_indices += kv_offset
         padding_mask = padding_mask[:, mask_indices]
 
+    # Check if all tokens are non-padding, handling both boolean and non-boolean masks
+    all_tokens_valid = True
+    if padding_mask is not None:
+        if padding_mask.dtype == torch.bool:
+            all_tokens_valid = padding_mask.all()
+        else:
+            all_tokens_valid = not (padding_mask == 0).any()
+
     # When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
     # hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
     # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
@@ -214,7 +222,7 @@ def _ignore_causal_mask_sdpa(
         # in this case we need to add special patterns to the mask so cannot be skipped otherwise
         and (local_attention_size is None or kv_length < local_attention_size)
         # In this case, we need to add padding to the mask, so cannot be skipped otherwise
-        and (padding_mask is None or padding_mask.all())
+        and (padding_mask is None or all_tokens_valid)
     ):
         return True
 
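As a quick illustration (again not part of the diff), the new `all_tokens_valid` flag evaluates to True exactly when the sliced `padding_mask` contains no zero entries, whichever dtype the mask has:

import torch

bool_mask = torch.tensor([[True, True, False]])
float_mask = torch.tensor([[1.5, 0.8, 0.0]])

print(bool_mask.all())              # tensor(False) -> a padding position exists
print(not (float_mask == 0).any())  # False         -> a padding position exists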
@@ -513,9 +521,14 @@ def flash_attention_mask(
         # Here we need to slice from the right if using sliding or chunked (for full attention, this is equivalent to doing nothing)
         attention_mask = attention_mask[:, -kv_length:]
         # We only return an actual mask if there is at least 1 padding token, otherwise we return `None` and use `is_causal` in FA2
-        # (note that the attention_mask is a boolean dtype here)
-        if attention_mask.all():
-            attention_mask = None
+        # For boolean masks, check if all values are True
+        # For non-boolean masks, check if there are any zero values (indicating padding)
+        if attention_mask.dtype == torch.bool:
+            if attention_mask.all():
+                attention_mask = None
+        else:
+            if not (attention_mask == 0).any():
+                attention_mask = None
 
     return attention_mask
 
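For illustration (not part of the diff), the dtype-aware check in `flash_attention_mask` only drops the mask when no padding is present, and it now detects padding in non-boolean masks via zero entries:

import torch

all_valid_bool = torch.tensor([[True, True, True]])
packed_with_padding = torch.tensor([[1.5, 0.8, 0.0]])

print(all_valid_bool.all())              # tensor(True) -> mask becomes None, FA2 relies on is_causal
print((packed_with_padding == 0).any())  # tensor(True) -> padding exists, so the mask is kept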
@@ -645,7 +658,14 @@ def _preprocess_mask_arguments(
 
     # Move the mask to correct device, and potentially switch dtype for efficiency
     if attention_mask is not None and attention_mask.ndim == 2:
-        attention_mask = attention_mask.to(device=cache_position.device, dtype=torch.bool)
+        unique_values = torch.unique(attention_mask)
+        is_binary_mask = len(unique_values) <= 2 and all(
+            torch.isclose(val, torch.tensor(0.0)) or torch.isclose(val, torch.tensor(1.0)) for val in unique_values
+        )
+        if is_binary_mask:
+            attention_mask = attention_mask.to(device=cache_position.device, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.to(device=cache_position.device)
 
     # If using a cache, it can give all informations about mask sizes based on seen tokens
     if past_key_values is not None:
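A small sketch (not part of the diff) of which 2D masks `_preprocess_mask_arguments` now treats as binary, and therefore still converts to bool for efficiency:

import torch

binary_mask = torch.tensor([[1.0, 1.0, 0.0]])
packed_mask = torch.tensor([[1.5, 0.8, 0.0]])

print(torch.unique(binary_mask))  # tensor([0., 1.])                 -> cast to torch.bool as before
print(torch.unique(packed_mask))  # tensor([0.0000, 0.8000, 1.5000]) -> dtype and values preserved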
@@ -2318,6 +2318,114 @@ class AttentionMaskTester(unittest.TestCase):
         is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
         assert torch.logical_or(is_inf, is_min).all()
 
+    def test_non_binary_attention_masks_preserved(self):
+        """
+        Test that non-binary attention masks (like those used in packing) are preserved
+        and not converted to boolean dtype, which would break the packing information.
+
+        This is a regression test for PR #37866 where aggressive bool conversion
+        broke packing implementations that use fractional values in attention masks.
+        """
+        from transformers import LlamaConfig
+        from transformers.masking_utils import _preprocess_mask_arguments, flash_attention_mask
+
+        # Create a non-binary attention mask with fractional values (simulating packing)
+        batch_size, seq_len = 2, 8
+        attention_mask = torch.tensor(
+            [
+                [1.5, 0.8, 1.2, 0.0, 1.0, 0.9, 1.3, 0.7],  # Non-binary values for packed sequences
+                [1.0, 1.0, 0.5, 0.0, 0.0, 1.8, 1.1, 0.2],  # Mix of binary and non-binary values
+            ],
+            dtype=torch.float32,
+            device=torch_device,
+        )
+
+        # Set up minimal inputs
+        config = LlamaConfig(
+            vocab_size=1000,
+            hidden_size=128,
+            num_attention_heads=4,
+            num_hidden_layers=2,
+        )
+        input_embeds = torch.randn(batch_size, seq_len, config.hidden_size, device=torch_device)
+        cache_position = torch.arange(seq_len, device=torch_device)
+
+        # Test _preprocess_mask_arguments preserves non-binary values
+        early_exit, processed_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+            config, input_embeds, attention_mask, cache_position, None, None
+        )
+
+        # The original mask dtype and values should be preserved (not converted to bool)
+        self.assertEqual(
+            processed_mask.dtype, torch.float32, "Non-binary attention mask should not be converted to bool dtype"
+        )
+
+        # Check that the original fractional values are preserved
+        self.assertTrue(
+            torch.allclose(processed_mask, attention_mask), "Non-binary attention mask values should be preserved"
+        )
+
+        # Test flash_attention_mask properly handles non-binary masks
+        flash_mask = flash_attention_mask(
+            batch_size,
+            cache_position,
+            kv_length,
+            kv_offset=kv_offset,
+            attention_mask=processed_mask,
+        )
+
+        # For non-binary masks with zeros, we should get a mask (not None)
+        # since there are padding tokens (0.0 values)
+        self.assertIsNotNone(flash_mask, "Flash attention mask should not be None when there are padding tokens")
+
+        # Test that binary masks still work correctly
+        binary_mask = torch.tensor(
+            [
+                [1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
+                [1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+            ],
+            dtype=torch.float32,
+            device=torch_device,
+        )
+
+        early_exit_binary, processed_binary_mask, kv_length_binary, kv_offset_binary = _preprocess_mask_arguments(
+            config, input_embeds, binary_mask, cache_position, None, None
+        )
+
+        # Binary masks should be converted to bool for efficiency
+        self.assertEqual(
+            processed_binary_mask.dtype,
+            torch.bool,
+            "Binary attention mask should be converted to bool dtype for efficiency",
+        )
+
+        # Test edge case: all non-padding tokens (should return None for flash attention)
+        all_valid_non_binary = torch.tensor(
+            [
+                [1.5, 0.8, 1.2, 1.0, 1.0, 0.9, 1.3, 0.7],
+                [1.0, 1.0, 0.5, 1.2, 1.8, 1.8, 1.1, 0.2],
+            ],
+            dtype=torch.float32,
+            device=torch_device,
+        )
+
+        early_exit_valid, processed_valid_mask, kv_length_valid, kv_offset_valid = _preprocess_mask_arguments(
+            config, input_embeds, all_valid_non_binary, cache_position, None, None
+        )
+
+        flash_mask_valid = flash_attention_mask(
+            batch_size,
+            cache_position,
+            kv_length_valid,
+            kv_offset=kv_offset_valid,
+            attention_mask=processed_valid_mask,
+        )
+
+        # When all tokens are valid (no padding), flash attention should return None
+        self.assertIsNone(
+            flash_mask_valid, "Flash attention mask should be None when all tokens are valid (no padding)"
+        )
+
     def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
         mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)
 