[generate] only require an attention mask for mps with torch<2.4 (#32367)

* up

* style

* stopping
Sanchit Gandhi 2024-08-02 17:32:50 +08:00 committed by GitHub
parent 083e13b7c4
commit c1aa0edb48
3 changed files with 9 additions and 4 deletions
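For context: before torch 2.4 the mps backend had no kernel for torch.isin, which `generate` uses both to infer a missing attention mask and to test EOS membership; torch 2.4 added the op, so the mps-specific guards below only need to fire on older versions. A minimal sketch of the underlying limitation (not part of this commit):

import torch

ids = torch.tensor([5, 7, 9])
eos = torch.tensor([7, 2])
print(torch.isin(ids, eos))  # tensor([False,  True, False]) on cpu/cuda
# With torch<2.4, the same call on tensors moved to "mps" raises
# NotImplementedError, because aten::isin had no mps implementation:
# torch.isin(ids.to("mps"), eos.to("mps"))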

src/transformers/generation/stopping_criteria.py

@@ -9,6 +9,8 @@ import numpy as np
 import torch
 from torch.nn import functional as F
 
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
+
 from ..tokenization_utils_base import PreTrainedTokenizerBase
 from ..utils import add_start_docstrings, logging
@@ -485,7 +487,8 @@ class EosTokenCriteria(StoppingCriteria):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         self.eos_token_id = self.eos_token_id.to(input_ids.device)
-        if input_ids.device.type == "mps":
+        if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+            # TODO: remove this workaround when we stop supporting torch<=2.3
             # https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
             is_done = (
                 input_ids[:, -1]
src/transformers/generation/utils.py

@@ -47,6 +47,7 @@ from ..models.auto import (
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
 )
+from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
 from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
     ModelOutput,
@@ -488,10 +489,10 @@ class GenerationMixin:
             return default_attention_mask
 
         # Otherwise we may have information -> try to infer the attention mask
-        if inputs.device.type == "mps":
-            # mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
+        if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+            # mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
             raise ValueError(
-                "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
+                "Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
             )
 
         is_pad_token_in_inputs = (pad_token_id is not None) and (

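The user-visible effect, as a hedged usage sketch (gpt2 is only a placeholder checkpoint):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to("mps")
inputs = tok("Hello world", return_tensors="pt").to("mps")

# torch>=2.4: the mask can be inferred, so this no longer raises
out = model.generate(inputs["input_ids"], max_new_tokens=5)

# torch<2.4: an explicit attention_mask is still required on mps
out = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=5)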
src/transformers/pytorch_utils.py

@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
 
 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
 
+is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
 is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
 is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
 is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
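Note the flag is computed from base_version, so pre-release and local builds (e.g. nightlies) gate the same way as their corresponding release; a small sketch:

from packaging import version

v = version.parse("2.4.0.dev20240601+cu121")  # hypothetical nightly tag
print(v.base_version)                          # 2.4.0
print(version.parse(v.base_version) >= version.parse("2.4"))  # True
print(v >= version.parse("2.4"))               # False: dev pre-releases sort below 2.4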