[generate] fix eos/pad id check on mps devices (#31695)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Sanchit Gandhi 2024-07-22 21:18:48 +08:00 committed by GitHub
parent f2a1e3ca68
commit 5a649ff3ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1542,10 +1542,7 @@ class GenerationMixin:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
# we can't infer attn mask if pad token is set to be eos token in model's generation config
if (
eos_token_tensor is not None
and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
):
if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."