mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Update masking_utils.py
This commit is contained in:
parent
111a3eac38
commit
4fed806b68
@@ -627,7 +627,7 @@ def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.T
|
|||||||
# cannot be part of the end of the first batch dim and the start of the 2nd one for example
|
# cannot be part of the end of the first batch dim and the start of the 2nd one for example
|
||||||
first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1
|
first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1
|
||||||
position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
|
position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
|
||||||
packed_sequence_mask = (position_diff != 1).cumsum(1)
|
packed_sequence_mask = (position_diff != 1).cumsum(-1)
|
||||||
|
|
||||||
# Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
|
# Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
|
||||||
# but it causes issues with export
|
# but it causes issues with export
|
||||||
@@ -702,9 +702,10 @@ def _preprocess_mask_arguments(
|
|||||||
else:
|
else:
|
||||||
kv_length, kv_offset = input_embeds.shape[1], 0
|
kv_length, kv_offset = input_embeds.shape[1], 0
|
||||||
|
|
||||||
# We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None)
|
# We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
|
||||||
|
# and we don't have past_key_values, i.e. generally a training setup)
|
||||||
packed_sequence_mask = None
|
packed_sequence_mask = None
|
||||||
if position_ids is not None and attention_mask is None:
|
if position_ids is not None and attention_mask is None and past_key_values is None:
|
||||||
packed_sequence_mask = find_packed_sequence_indices(position_ids)
|
packed_sequence_mask = find_packed_sequence_indices(position_ids)
|
||||||
|
|
||||||
return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
|
return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
|
||||||
|
Loading…
Reference in New Issue
Block a user