From 5a649ff3ecd70599dd0fea7ee430ba47b51a4556 Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Mon, 22 Jul 2024 21:18:48 +0800
Subject: [PATCH] [generate] fix eos/pad id check on mps devices (#31695)

Co-authored-by: Joao Gante
---
 src/transformers/generation/utils.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ebe968b7ac4..51019da9a6b 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1542,10 +1542,7 @@ class GenerationMixin:
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
 
         # we can't infer attn mask if pad token is set to be eos token in model's generation config
-        if (
-            eos_token_tensor is not None
-            and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
-        ):
+        if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor:
             if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
                 logger.warning_once(
                     "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
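
For context, a minimal sketch of what the one-line rewrite trades: as the PR title suggests, `torch.isin` apparently lacked an MPS kernel in the PyTorch releases this patch targets, whereas the `in` operator dispatches to `Tensor.__contains__`, which reduces to an elementwise equality plus `any()` that every backend supports. The token-id values below are hypothetical; only the variable names come from the diff.

import torch

# Hypothetical token ids; only the variable names are taken from the diff above.
eos_token_tensor = torch.tensor([2, 32000])  # a model may define several eos ids
pad_token_tensor = torch.tensor(2)           # pad id set equal to one of the eos ids

# Before the patch: torch.isin masks the eos ids that equal the pad id.
# In some PyTorch releases this op had no MPS kernel, so on Apple-silicon
# ("mps") devices it could raise NotImplementedError during generate().
before = torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()

# After the patch: `in` dispatches to Tensor.__contains__, effectively
# (eos_token_tensor == pad_token_tensor).any() -- plain elementwise equality
# that all backends, including MPS, implement.
after = pad_token_tensor in eos_token_tensor

# Both spellings agree on whether the pad id is among the eos ids.
assert bool(before) == after == True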