[generate] fix eos/pad id check on mps devices (#31695)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Authored by Sanchit Gandhi on 2024-07-22 21:18:48 +08:00; committed by GitHub
parent f2a1e3ca68
commit 5a649ff3ec

@@ -1542,10 +1542,7 @@ class GenerationMixin:
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
         # we can't infer attn mask if pad token is set to be eos token in model's generation config
-        if (
-            eos_token_tensor is not None
-            and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
-        ):
+        if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor:
             if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
                 logger.warning_once(
                     "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."