Generate: Simplify is_pad_token_not_equal_to_eos_token_id (#18933)

This commit is contained in:
Ekagra Ranjan 2022-09-09 21:14:56 +05:30 committed by GitHub
parent 85125fcffd
commit f1a6df3210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 6 deletions

View File

@ -1739,9 +1739,7 @@ class TFGenerationMixin:
) -> tf.Tensor:
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id)
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
# Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:

View File

@ -495,9 +495,8 @@ class GenerationMixin:
) -> torch.LongTensor:
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id)
)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
# Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).long()