Fix aggressive boolean conversion breaking packing implementations

Kashif Rasul 2025-07-01 13:58:06 +02:00
parent f8b88866f5
commit 675df92a58
2 changed files with 133 additions and 5 deletions
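For context (not part of the commit): a minimal sketch of the failure mode the title refers to, assuming a hypothetical packing scheme that encodes sequence membership as non-binary values in the 2-D attention mask. An unconditional cast to torch.bool collapses every non-zero entry to True and erases that information.

    import torch

    # Hypothetical packed mask: 1.0 and 2.0 mark two packed sequences, 0.0 marks padding.
    packed_mask = torch.tensor([[1.0, 1.0, 2.0, 2.0, 0.0]])
    print(packed_mask.to(torch.bool))
    # tensor([[ True,  True,  True,  True, False]]) -- the boundary between the two sequences is lost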


@@ -203,6 +203,14 @@ def _ignore_causal_mask_sdpa(
         mask_indices += kv_offset
         padding_mask = padding_mask[:, mask_indices]
 
+    # Check if all tokens are non-padding, handling both boolean and non-boolean masks
+    all_tokens_valid = True
+    if padding_mask is not None:
+        if padding_mask.dtype == torch.bool:
+            all_tokens_valid = padding_mask.all()
+        else:
+            all_tokens_valid = not (padding_mask == 0).any()
+
     # When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
     # hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
     # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
@@ -214,7 +222,7 @@ def _ignore_causal_mask_sdpa(
         # in this case we need to add special patterns to the mask so cannot be skipped otherwise
         and (local_attention_size is None or kv_length < local_attention_size)
         # In this case, we need to add padding to the mask, so cannot be skipped otherwise
-        and (padding_mask is None or padding_mask.all())
+        and (padding_mask is None or all_tokens_valid)
     ):
         return True
@@ -513,9 +521,14 @@ def flash_attention_mask(
         # Here we need to slice from the right if using sliding or chunked (for full attention, this is equivalent to doing nothing)
         attention_mask = attention_mask[:, -kv_length:]
         # We only return an actual mask if there is at least 1 padding token, otherwise we return `None` and use `is_causal` in FA2
-        # (note that the attention_mask is a boolean dtype here)
-        if attention_mask.all():
-            attention_mask = None
+        # For boolean masks, check if all values are True
+        # For non-boolean masks, check if there are any zero values (indicating padding)
+        if attention_mask.dtype == torch.bool:
+            if attention_mask.all():
+                attention_mask = None
+        else:
+            if not (attention_mask == 0).any():
+                attention_mask = None
 
     return attention_mask
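Purely as an illustration (keep_or_drop is a local name for this sketch, not a transformers function): the same dtype-aware padding check now appears in both _ignore_causal_mask_sdpa and flash_attention_mask above, and it decides whether the mask can be dropped so the attention kernel can rely on is_causal instead.

    import torch

    def keep_or_drop(attention_mask):
        # Same decision as the hunk above: boolean masks keep the old `.all()` test,
        # any other dtype looks for explicit zeros (padding positions).
        if attention_mask.dtype == torch.bool:
            return None if attention_mask.all() else attention_mask
        return None if not (attention_mask == 0).any() else attention_mask

    print(keep_or_drop(torch.tensor([[True, True, False]])))  # mask kept, unchanged behaviour for bool masks
    print(keep_or_drop(torch.tensor([[1.5, 0.8, 1.2]])))      # None -> no zeros, safe to fall back to is_causal
    print(keep_or_drop(torch.tensor([[1.5, 0.0, 1.2]])))      # mask kept -> the 0.0 marks a padding position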
@@ -645,7 +658,14 @@ def _preprocess_mask_arguments(
     # Move the mask to correct device, and potentially switch dtype for efficiency
     if attention_mask is not None and attention_mask.ndim == 2:
-        attention_mask = attention_mask.to(device=cache_position.device, dtype=torch.bool)
+        unique_values = torch.unique(attention_mask)
+        is_binary_mask = len(unique_values) <= 2 and all(
+            torch.isclose(val, torch.tensor(0.0)) or torch.isclose(val, torch.tensor(1.0)) for val in unique_values
+        )
+        if is_binary_mask:
+            attention_mask = attention_mask.to(device=cache_position.device, dtype=torch.bool)
+        else:
+            attention_mask = attention_mask.to(device=cache_position.device)
 
     # If using a cache, it can give all informations about mask sizes based on seen tokens
     if past_key_values is not None:
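Again only as a standalone sketch (is_binary_mask here is a local helper, not library API): the detection step above keeps the old bool cast for plain 0/1 masks and preserves the original dtype for anything else, so fractional packing values survive _preprocess_mask_arguments.

    import torch

    def is_binary_mask(attention_mask):
        # Same test as above: at most two unique values, each close to 0 or 1.
        unique_values = torch.unique(attention_mask)
        return len(unique_values) <= 2 and all(
            torch.isclose(val, torch.tensor(0.0)) or torch.isclose(val, torch.tensor(1.0)) for val in unique_values
        )

    print(is_binary_mask(torch.tensor([[1.0, 1.0, 0.0]])))       # True  -> still cast to torch.bool
    print(is_binary_mask(torch.tensor([[1.5, 0.8, 0.0, 1.0]])))  # False -> dtype and values are preserved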


@@ -2318,6 +2318,114 @@ class AttentionMaskTester(unittest.TestCase):
         is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
         assert torch.logical_or(is_inf, is_min).all()
 
+    def test_non_binary_attention_masks_preserved(self):
+        """
+        Test that non-binary attention masks (like those used in packing) are preserved
+        and not converted to boolean dtype, which would break the packing information.
+
+        This is a regression test for PR #37866 where aggressive bool conversion
+        broke packing implementations that use fractional values in attention masks.
+        """
+        from transformers import LlamaConfig
+        from transformers.masking_utils import _preprocess_mask_arguments, flash_attention_mask
+
+        # Create a non-binary attention mask with fractional values (simulating packing)
+        batch_size, seq_len = 2, 8
+        attention_mask = torch.tensor(
+            [
+                [1.5, 0.8, 1.2, 0.0, 1.0, 0.9, 1.3, 0.7],  # Non-binary values for packed sequences
+                [1.0, 1.0, 0.5, 0.0, 0.0, 1.8, 1.1, 0.2],  # Mix of binary and non-binary values
+            ],
+            dtype=torch.float32,
+            device=torch_device,
+        )
+
+        # Set up minimal inputs
+        config = LlamaConfig(
+            vocab_size=1000,
+            hidden_size=128,
+            num_attention_heads=4,
+            num_hidden_layers=2,
+        )
+        input_embeds = torch.randn(batch_size, seq_len, config.hidden_size, device=torch_device)
+        cache_position = torch.arange(seq_len, device=torch_device)
+
+        # Test _preprocess_mask_arguments preserves non-binary values
+        early_exit, processed_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+            config, input_embeds, attention_mask, cache_position, None, None
+        )
+
+        # The original mask dtype and values should be preserved (not converted to bool)
+        self.assertEqual(
+            processed_mask.dtype, torch.float32, "Non-binary attention mask should not be converted to bool dtype"
+        )
+        # Check that the original fractional values are preserved
+        self.assertTrue(
+            torch.allclose(processed_mask, attention_mask), "Non-binary attention mask values should be preserved"
+        )
+
+        # Test flash_attention_mask properly handles non-binary masks
+        flash_mask = flash_attention_mask(
+            batch_size,
+            cache_position,
+            kv_length,
+            kv_offset=kv_offset,
+            attention_mask=processed_mask,
+        )
+
+        # For non-binary masks with zeros, we should get a mask (not None)
+        # since there are padding tokens (0.0 values)
+        self.assertIsNotNone(flash_mask, "Flash attention mask should not be None when there are padding tokens")
+
+        # Test that binary masks still work correctly
+        binary_mask = torch.tensor(
+            [
+                [1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
+                [1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+            ],
+            dtype=torch.float32,
+            device=torch_device,
+        )
+        early_exit_binary, processed_binary_mask, kv_length_binary, kv_offset_binary = _preprocess_mask_arguments(
+            config, input_embeds, binary_mask, cache_position, None, None
+        )
+
+        # Binary masks should be converted to bool for efficiency
+        self.assertEqual(
+            processed_binary_mask.dtype,
+            torch.bool,
+            "Binary attention mask should be converted to bool dtype for efficiency",
+        )
+
+        # Test edge case: all non-padding tokens (should return None for flash attention)
+        all_valid_non_binary = torch.tensor(
+            [
+                [1.5, 0.8, 1.2, 1.0, 1.0, 0.9, 1.3, 0.7],
+                [1.0, 1.0, 0.5, 1.2, 1.8, 1.8, 1.1, 0.2],
+            ],
+            dtype=torch.float32,
+            device=torch_device,
+        )
+        early_exit_valid, processed_valid_mask, kv_length_valid, kv_offset_valid = _preprocess_mask_arguments(
+            config, input_embeds, all_valid_non_binary, cache_position, None, None
+        )
+
+        flash_mask_valid = flash_attention_mask(
+            batch_size,
+            cache_position,
+            kv_length_valid,
+            kv_offset=kv_offset_valid,
+            attention_mask=processed_valid_mask,
+        )
+
+        # When all tokens are valid (no padding), flash attention should return None
+        self.assertIsNone(
+            flash_mask_valid, "Flash attention mask should be None when all tokens are valid (no padding)"
+        )
+
     def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
         mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)