From 4fed806b68d7446de8719c9684282648ecaa107b Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 3 Jul 2025 15:29:20 +0200
Subject: [PATCH] Update masking_utils.py

---
 src/transformers/masking_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py
index 0777cb24928..b194540bc37 100644
--- a/src/transformers/masking_utils.py
+++ b/src/transformers/masking_utils.py
@@ -627,7 +627,7 @@ def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.T
     # cannot be part of the end of the first batch dim and the start of the 2nd one for example
     first_dummy_value = position_ids[:, :1] - 1  # We just need the diff on this first value to be 1
     position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
-    packed_sequence_mask = (position_diff != 1).cumsum(1)
+    packed_sequence_mask = (position_diff != 1).cumsum(-1)

     # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
     # but it causes issues with export
@@ -702,9 +702,10 @@ def _preprocess_mask_arguments(
     else:
         kv_length, kv_offset = input_embeds.shape[1], 0

-    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None)
+    # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
+    # and we don't have past_key_values, i.e. generally a training setup)
     packed_sequence_mask = None
-    if position_ids is not None and attention_mask is None:
+    if position_ids is not None and attention_mask is None and past_key_values is None:
         packed_sequence_mask = find_packed_sequence_indices(position_ids)

     return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
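
Context for the first hunk: the detection works by differencing consecutive position ids along the sequence dimension; any step whose diff is not 1 marks the start of a new packed sub-sequence, and the cumulative sum turns those markers into per-token sub-sequence indices. Since `position_ids` here is 2D of shape (batch, seq_len), `cumsum(1)` and `cumsum(-1)` are equivalent; the change just makes the dimension consistent with the `dim=-1` used in `torch.diff` above it. The snippet below is a minimal standalone sketch of that logic (the function name and the example input are illustrative, not taken from the patch):

    import torch

    def packed_sequence_indices_sketch(position_ids: torch.Tensor) -> torch.Tensor:
        # Sketch of the detection logic touched by the first hunk, assuming
        # `position_ids` has shape (batch_size, seq_len).
        first_dummy_value = position_ids[:, :1] - 1  # makes the first diff always equal to 1
        position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
        # Each position where the diff is not 1 starts a new packed sub-sequence;
        # the cumulative sum along the last dim assigns an index to every token.
        return (position_diff != 1).cumsum(-1)

    # Two sequences packed into one row: positions restart at 0 after the first one.
    position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
    print(packed_sequence_indices_sketch(position_ids))
    # tensor([[0, 0, 0, 1, 1, 1, 1]])

The second hunk then restricts this detection to calls where `past_key_values is None` as well, i.e. typically training on packed sequences rather than incremental decoding, where restarting position ids should not be reinterpreted as packing.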