Generate: Simplify is_pad_token_not_equal_to_eos_token_id (#18933)
commit f1a6df3210
parent 85125fcffd
@@ -1739,9 +1739,7 @@ class TFGenerationMixin:
     ) -> tf.Tensor:
         is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
         is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
-        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
-            (eos_token_id is not None) and (pad_token_id != eos_token_id)
-        )
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
 
         # Check if input is input_ids and padded -> only then is attention_mask defined
         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
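Both hunks apply the same propositional identity: A or ((not A) and B) is equivalent to A or B, here with A = (eos_token_id is None) and B = (pad_token_id != eos_token_id). Python's short-circuiting `or` only evaluates the second disjunct when the first is False, so the `(eos_token_id is not None)` guard was redundant. A minimal sketch checking the equivalence exhaustively over a few token-id values (the function names are illustrative, not part of the commit):

from itertools import product

def old_form(eos_token_id, pad_token_id):
    # Original expression: A or ((not A) and B)
    return (eos_token_id is None) or (
        (eos_token_id is not None) and (pad_token_id != eos_token_id)
    )

def new_form(eos_token_id, pad_token_id):
    # Simplified expression: A or B -- the second disjunct only runs
    # when A is False, i.e. when eos_token_id is not None.
    return (eos_token_id is None) or (pad_token_id != eos_token_id)

# Compare both forms over all combinations, including eos_token_id=None.
for eos_token_id, pad_token_id in product([None, 0, 1, 2], [0, 1, 2]):
    assert old_form(eos_token_id, pad_token_id) == new_form(eos_token_id, pad_token_id)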
@@ -495,9 +495,8 @@ class GenerationMixin:
     ) -> torch.LongTensor:
         is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
         is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
-        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
-            (eos_token_id is not None) and (pad_token_id != eos_token_id)
-        )
+        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
+
         # Check if input is input_ids and padded -> only then is attention_mask defined
         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
             return inputs.ne(pad_token_id).long()
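For context, the simplified PyTorch branch infers an attention mask from input_ids only when padding is both present and unambiguous (pad_token_id distinct from eos_token_id, since eos may legitimately appear inside unpadded sequences). A standalone sketch of that logic; the function name and the all-ones fallback branch are assumptions for illustration, not taken verbatim from the diff:

import torch

def prepare_attention_mask(inputs: torch.Tensor, pad_token_id, eos_token_id) -> torch.LongTensor:
    # Mirrors the simplified branch above (standalone sketch, not the
    # library method itself).
    is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
    is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)

    if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
        # Attend to every position that is not padding.
        return inputs.ne(pad_token_id).long()
    # Assumed fallback: attend everywhere (e.g. pad == eos, so padding
    # cannot be told apart from real tokens).
    return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)

# Left-padded batch with pad_token_id=0, eos_token_id=2:
input_ids = torch.tensor([[0, 0, 5, 6], [3, 4, 5, 6]])
print(prepare_attention_mask(input_ids, pad_token_id=0, eos_token_id=2))
# tensor([[0, 0, 1, 1],
#         [1, 1, 1, 1]])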