mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[ImageGPT] Deprecate pixel_values input name to input_ids (#14801)
* [ImageGPT] Deprecate pixel_values input name to input_ids * up * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * correct * finish Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
parent
c4a96cecbc
commit
84ea427f46
@ -16,6 +16,7 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
@ -550,22 +551,22 @@ IMAGEGPT_START_DOCSTRING = r"""
|
||||
|
||||
IMAGEGPT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (:obj:`torch.LongTensor` of shape :obj:`(batch_size, pixel_values_length)`):
|
||||
:obj:`pixel_values_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
||||
``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
||||
sequence tokens in the vocabulary.
|
||||
|
||||
If :obj:`past_key_values` is used, only ``pixel_values`` that do not have their past calculated should be
|
||||
passed as ``pixel_values``.
|
||||
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
|
||||
passed as ``input_ids``.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.ImageGPTFeatureExtractor`. See
|
||||
:meth:`transformers.ImageGPTFeatureExtractor.__call__` for details.
|
||||
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
|
||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``pixel_values``
|
||||
which have their past given to this model should not be passed as ``pixel_values`` as they have already
|
||||
been computed.
|
||||
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
||||
have their past given to this model should not be passed as ``input_ids`` as they have already been
|
||||
computed.
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
@ -573,7 +574,7 @@ IMAGEGPT_INPUTS_DOCSTRING = r"""
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, pixel_values_length)`, `optional`):
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||
1]``:
|
||||
|
||||
@ -593,9 +594,9 @@ IMAGEGPT_INPUTS_DOCSTRING = r"""
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`pixel_values` you can choose to directly pass an embedded
|
||||
representation. This is useful if you want more control over how to convert :obj:`pixel_values` indices
|
||||
into associated vectors than the model's internal embedding lookup matrix.
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
|
||||
If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
|
||||
:obj:`past_key_values`).
|
||||
@ -656,7 +657,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@ -669,11 +670,12 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
``labels = pixel_values`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
|
||||
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
|
||||
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Returns:
|
||||
@ -695,6 +697,20 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
if "pixel_values" in kwargs:
|
||||
warnings.warn(
|
||||
"The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass both `pixel_values` and `input_ids`. "
|
||||
"Please make sure to only pass `input_ids`."
|
||||
)
|
||||
|
||||
input_ids = kwargs.pop("pixel_values")
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -702,19 +718,19 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both pixel_values and inputs_embeds at the same time")
|
||||
elif pixel_values is not None:
|
||||
input_shape = pixel_values.size()
|
||||
pixel_values = pixel_values.view(-1, input_shape[-1])
|
||||
batch_size = pixel_values.shape[0]
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
batch_size = input_ids.shape[0]
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("You have to specify either pixel_values or inputs_embeds")
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = pixel_values.device if pixel_values is not None else inputs_embeds.device
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
@ -768,7 +784,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(pixel_values)
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds
|
||||
|
||||
@ -901,11 +917,11 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, pixel_values, past=None, **kwargs):
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past:
|
||||
pixel_values = pixel_values[:, -1].unsqueeze(-1)
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
@ -921,7 +937,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
@ -933,7 +949,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@ -947,11 +963,12 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
``labels = pixel_values`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
|
||||
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
|
||||
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Returns:
|
||||
@ -972,7 +989,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
|
||||
>>> batch_size = 8
|
||||
>>> context = torch.full((batch_size, 1), model.config.vocab_size - 1) #initialize with SOS token
|
||||
>>> context = torch.tensor(context).to(device)
|
||||
>>> output = model.generate(pixel_values=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40)
|
||||
>>> output = model.generate(input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40)
|
||||
|
||||
>>> clusters = feature_extractor.clusters
|
||||
>>> n_px = feature_extractor.size
|
||||
@ -986,10 +1003,24 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
|
||||
... ax.imshow(img)
|
||||
"""
|
||||
|
||||
if "pixel_values" in kwargs:
|
||||
warnings.warn(
|
||||
"The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass both `pixel_values` and `input_ids`. "
|
||||
"Please make sure to only pass `input_ids`."
|
||||
)
|
||||
|
||||
input_ids = kwargs.pop("pixel_values")
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
pixel_values,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
@ -1066,7 +1097,7 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
|
||||
@replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
@ -1078,6 +1109,7 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
@ -1103,10 +1135,25 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
|
||||
if "pixel_values" in kwargs:
|
||||
warnings.warn(
|
||||
"The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if input_ids is not None:
|
||||
raise ValueError(
|
||||
"You cannot pass both `pixel_values` and `input_ids`. "
|
||||
"Please make sure to only pass `input_ids`."
|
||||
)
|
||||
|
||||
input_ids = kwargs.pop("pixel_values")
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
pixel_values,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
|
@ -20,7 +20,7 @@ import unittest
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_modeling_common import floats_tensor
|
||||
from .test_modeling_common import floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -31,6 +31,7 @@ if is_torch_available():
|
||||
BartTokenizer,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
ImageGPTForCausalImageModeling,
|
||||
Speech2TextForConditionalGeneration,
|
||||
SpeechEncoderDecoderModel,
|
||||
VisionEncoderDecoderModel,
|
||||
@ -1766,6 +1767,18 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
||||
self.assertEqual(output_sequences.shape, (1, 15))
|
||||
|
||||
def test_generate_non_nlp_input_ids_as_kwarg(self):
|
||||
model = ImageGPTForCausalImageModeling.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-imagegpt", max_length=10
|
||||
).to(torch_device)
|
||||
input_ids = ids_tensor((3, 5), vocab_size=10)
|
||||
|
||||
output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
|
||||
output_sequences = model.generate(input_ids).cpu()
|
||||
|
||||
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
|
||||
self.assertEqual(output_sequences.shape, (3, 10))
|
||||
|
||||
def test_generate_input_ids_as_encoder_kwarg(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
@ -314,7 +314,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
expected_arg_names = ["input_ids"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
|
Loading…
Reference in New Issue
Block a user