Update masking_utils.py

commit 4fed806b68
parent 111a3eac38
Author: Cyril Vallez
Date:   2025-07-03 15:29:20 +02:00

@@ -627,7 +627,7 @@ def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
     # cannot be part of the end of the first batch dim and the start of the 2nd one for example
     first_dummy_value = position_ids[:, :1] - 1  # We just need the diff on this first value to be 1
     position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
-    packed_sequence_mask = (position_diff != 1).cumsum(1)
+    packed_sequence_mask = (position_diff != 1).cumsum(-1)
     # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
     # but it causes issues with export
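
To make the change concrete: the detection flags every index where the position ids do not increase by 1, then turns those restart points into per-token sequence indices with a cumulative sum. Below is a minimal, self-contained sketch of that idea, assuming only torch; the names mirror the diff, but the function is an illustration rather than a verbatim copy of masking_utils.py.

import torch

def packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
    # Prepend a dummy value so the diff at the first position is exactly 1
    first_dummy_value = position_ids[:, :1] - 1
    position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
    # Every position where the ids do not increase by 1 starts a new packed
    # sequence; the cumulative sum turns those restarts into sequence indices
    return (position_diff != 1).cumsum(-1)

# Two sequences of lengths 3 and 4 packed into one row of length 7
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
print(packed_sequence_indices(position_ids))
# tensor([[0, 0, 0, 1, 1, 1, 1]]) -> tokens 0-2 are sequence 0, tokens 3-6 are sequence 1

Note that for the 2D position_ids handled here, cumsum(1) and cumsum(-1) select the same axis; the new spelling simply matches the dim=-1 convention of the surrounding torch.diff call.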
@@ -702,9 +702,10 @@ def _preprocess_mask_arguments(
     else:
         kv_length, kv_offset = input_embeds.shape[1], 0

-    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None)
+    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
+    # and we don't have past_key_values, i.e. generally a training setup)
     packed_sequence_mask = None
-    if position_ids is not None and attention_mask is None:
+    if position_ids is not None and attention_mask is None and past_key_values is None:
         packed_sequence_mask = find_packed_sequence_indices(position_ids)

     return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
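
The second hunk restricts detection to cache-free calls, i.e. generally a training setup. A hedged sketch of the guarded flow follows; past_key_values stands in for any cache object, and the rationale comment is an inference from the diff's new comment rather than a quote from the file.

import torch
from typing import Optional

def maybe_detect_packed_sequences(
    position_ids: Optional[torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_values: Optional[object],
) -> Optional[torch.Tensor]:
    # Only attempt detection without a cache: the resulting mask indexes the
    # current query tokens only, so it cannot describe tokens already cached
    if position_ids is not None and attention_mask is None and past_key_values is None:
        first_dummy_value = position_ids[:, :1] - 1
        position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
        return (position_diff != 1).cumsum(-1)
    return None

# Training-style call (no cache): detection runs
print(maybe_detect_packed_sequences(torch.tensor([[0, 1, 0, 1, 2]]), None, None))
# tensor([[0, 0, 1, 1, 1]])

# Decoding step with a cache: detection is skipped
print(maybe_detect_packed_sequences(torch.tensor([[5]]), None, object()))
# None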