mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[generate] only require an attention mask for mps with torch<2.4 (#32367)
* up * style * stopping
This commit is contained in:
parent
083e13b7c4
commit
c1aa0edb48
@ -9,6 +9,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||
|
||||
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from ..utils import add_start_docstrings, logging
|
||||
|
||||
@ -485,7 +487,8 @@ class EosTokenCriteria(StoppingCriteria):
|
||||
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
|
||||
self.eos_token_id = self.eos_token_id.to(input_ids.device)
|
||||
if input_ids.device.type == "mps":
|
||||
if input_ids.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
|
||||
# TODO: remove this workaround when we stop supporting torch<=2.3
|
||||
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
|
||||
is_done = (
|
||||
input_ids[:, -1]
|
||||
|
@ -47,6 +47,7 @@ from ..models.auto import (
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
)
|
||||
from ..pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||
from ..tokenization_utils import ExtensionsTrie
|
||||
from ..utils import (
|
||||
ModelOutput,
|
||||
@ -488,10 +489,10 @@ class GenerationMixin:
|
||||
return default_attention_mask
|
||||
|
||||
# Otherwise we may have information -> try to infer the attention mask
|
||||
if inputs.device.type == "mps":
|
||||
# mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
|
||||
if inputs.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
|
||||
# mps does not support torch.isin for torch<2.4 (https://github.com/pytorch/pytorch/issues/77764)
|
||||
raise ValueError(
|
||||
"Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
|
||||
"Can't infer missing attention mask on `mps` device for torch<2.4. Please provide an `attention_mask` or upgrade to torch>=2.4"
|
||||
)
|
||||
|
||||
is_pad_token_in_inputs = (pad_token_id is not None) and (
|
||||
|
@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
|
||||
|
||||
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
|
||||
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
|
||||
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
|
||||
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
|
||||
|
Loading…
Reference in New Issue
Block a user