[ImageGPT] Deprecate pixel_values input name to input_ids (#14801)

* [ImageGPT] Deprecate pixel_values input name to input_ids

* up

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* correct

* finish

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Patrick von Platen 2021-12-17 20:05:22 +01:00 committed by GitHub
parent c4a96cecbc
commit 84ea427f46
3 changed files with 93 additions and 33 deletions

src/transformers/models/imagegpt/modeling_imagegpt.py

@@ -16,6 +16,7 @@
 import math
 import os
+import warnings
 from typing import Tuple
 import torch
@@ -550,22 +551,22 @@ IMAGEGPT_START_DOCSTRING = r"""
 IMAGEGPT_INPUTS_DOCSTRING = r"""
     Args:
-        pixel_values (:obj:`torch.LongTensor` of shape :obj:`(batch_size, pixel_values_length)`):
-            :obj:`pixel_values_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+            :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
             ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
             sequence tokens in the vocabulary.
-            If :obj:`past_key_values` is used, only ``pixel_values`` that do not have their past calculated should be
-            passed as ``pixel_values``.
+            If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
+            passed as ``input_ids``.
             Indices can be obtained using :class:`~transformers.ImageGPTFeatureExtractor`. See
             :meth:`transformers.ImageGPTFeatureExtractor.__call__` for details.
         past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
             Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
-            :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``pixel_values``
-            which have their past given to this model should not be passed as ``pixel_values`` as they have already
-            been computed.
+            :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
+            have their past given to this model should not be passed as ``input_ids`` as they have already been
+            computed.
         attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
@@ -573,7 +574,7 @@ IMAGEGPT_INPUTS_DOCSTRING = r"""
             - 0 for tokens that are **masked**.
             `What are attention masks? <../glossary.html#attention-mask>`__
-        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, pixel_values_length)`, `optional`):
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
             1]``:
@@ -593,9 +594,9 @@ IMAGEGPT_INPUTS_DOCSTRING = r"""
             - 0 indicates the head is **masked**.
         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`pixel_values` you can choose to directly pass an embedded
-            representation. This is useful if you want more control over how to convert :obj:`pixel_values` indices
-            into associated vectors than the model's internal embedding lookup matrix.
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+            vectors than the model's internal embedding lookup matrix.
             If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
             :obj:`past_key_values`).
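These two arguments behave as they do for GPT-2. As a rough usage sketch (the checkpoint name and shapes below are illustrative, not part of this diff):

```python
import torch
from transformers import ImageGPTModel

model = ImageGPTModel.from_pretrained("openai/imagegpt-small")
input_ids = torch.randint(0, model.config.vocab_size, (1, 8))

# inputs_embeds: skip the internal lookup by passing the same embeddings the model would compute
embeds = model.get_input_embeddings()(input_ids)
out = model(inputs_embeds=embeds)

# past_key_values: cache attention keys/values, then feed only the new ids on the next step
out = model(input_ids=input_ids, use_cache=True)
next_ids = torch.randint(0, model.config.vocab_size, (1, 1))
out = model(input_ids=next_ids, past_key_values=out.past_key_values)
```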
@@ -656,7 +657,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
+        input_ids=None,
         past_key_values=None,
         attention_mask=None,
         token_type_ids=None,
@@ -669,11 +670,12 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            ``labels = pixel_values`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
+            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
             ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
         Returns:
@@ -695,6 +697,20 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
         >>> last_hidden_states = outputs.last_hidden_state
         """
+        if "pixel_values" in kwargs:
+            warnings.warn(
+                "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+                FutureWarning,
+            )
+            if input_ids is not None:
+                raise ValueError(
+                    "You cannot pass both `pixel_values` and `input_ids`. "
+                    "Please make sure to only pass `input_ids`."
+                )
+            input_ids = kwargs.pop("pixel_values")
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
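This shim is repeated verbatim in all three `forward` methods touched by the commit. Stripped of the model plumbing, the caller-visible behavior is the following (a minimal standalone sketch; `forward` here is a stand-in function, not the real method):

```python
import warnings

def forward(input_ids=None, **kwargs):
    if "pixel_values" in kwargs:
        # the old keyword still works, but emits a deprecation warning
        warnings.warn(
            "The `pixel_values` argument is deprecated and will be removed in a future version, "
            "use `input_ids` instead.",
            FutureWarning,
        )
        if input_ids is not None:
            # passing both spellings is ambiguous, so it is rejected outright
            raise ValueError(
                "You cannot pass both `pixel_values` and `input_ids`. "
                "Please make sure to only pass `input_ids`."
            )
        input_ids = kwargs.pop("pixel_values")
    return input_ids

forward(pixel_values=[1, 2, 3])  # FutureWarning, still returns [1, 2, 3]
forward(input_ids=[1, 2, 3])     # new spelling, no warning
```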
@@ -702,19 +718,19 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both pixel_values and inputs_embeds at the same time")
-        elif pixel_values is not None:
-            input_shape = pixel_values.size()
-            pixel_values = pixel_values.view(-1, input_shape[-1])
-            batch_size = pixel_values.shape[0]
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size = input_ids.shape[0]
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
             batch_size = inputs_embeds.shape[0]
         else:
-            raise ValueError("You have to specify either pixel_values or inputs_embeds")
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
-        device = pixel_values.device if pixel_values is not None else inputs_embeds.device
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
@@ -768,7 +784,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
         if inputs_embeds is None:
-            inputs_embeds = self.wte(pixel_values)
+            inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         hidden_states = inputs_embeds + position_embeds
@@ -901,11 +917,11 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
-    def prepare_inputs_for_generation(self, pixel_values, past=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
-            pixel_values = pixel_values[:, -1].unsqueeze(-1)
+            input_ids = input_ids[:, -1].unsqueeze(-1)
             if token_type_ids is not None:
                 token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
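With a cache in hand, only the newest token has to be embedded and attended; the slicing above keeps the last position and `unsqueeze(-1)` restores the dropped sequence dimension. For example:

```python
import torch

input_ids = torch.tensor([[3, 1, 4, 1, 5],
                          [9, 2, 6, 5, 3]])  # (batch=2, seq_len=5)
last = input_ids[:, -1].unsqueeze(-1)        # (2, 1)
print(last.tolist())                         # [[5], [3]]
```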
@@ -921,7 +937,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
         else:
             position_ids = None
         return {
-            "pixel_values": pixel_values,
+            "input_ids": input_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
             "position_ids": position_ids,
@@ -933,7 +949,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
+        input_ids=None,
         past_key_values=None,
         attention_mask=None,
         token_type_ids=None,
@@ -947,11 +963,12 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            ``labels = pixel_values`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
+            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
             ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
         Returns:
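Because the shift happens inside the model, a caller can reuse the inputs directly as targets. A minimal loss-computation sketch (checkpoint name illustrative):

```python
import torch
from transformers import ImageGPTForCausalImageModeling

model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small")
input_ids = torch.randint(0, model.config.vocab_size, (2, 16))

outputs = model(input_ids=input_ids, labels=input_ids)  # shift handled internally
print(outputs.loss)  # scalar cross-entropy over the image tokens
```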
@@ -972,7 +989,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
         >>> batch_size = 8
         >>> context = torch.full((batch_size, 1), model.config.vocab_size - 1)  # initialize with SOS token
         >>> context = torch.tensor(context).to(device)
-        >>> output = model.generate(pixel_values=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40)
+        >>> output = model.generate(input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40)
         >>> clusters = feature_extractor.clusters
         >>> n_px = feature_extractor.size
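The docstring example goes on to decode the sampled cluster indices back into RGB images; roughly, the continuation looks like this (a sketch: `clusters` holds the color-palette centroids in `[-1, 1]`, and the leading SOS token is dropped):

```python
import numpy as np

samples = output[:, 1:].cpu().numpy()  # strip the SOS token prepended above
samples_img = [
    np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8)
    for s in samples
]  # one (n_px, n_px, 3) uint8 image per generated sequence
```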
@@ -986,10 +1003,24 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
         ...     ax.imshow(img)
         """
+        if "pixel_values" in kwargs:
+            warnings.warn(
+                "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+                FutureWarning,
+            )
+            if input_ids is not None:
+                raise ValueError(
+                    "You cannot pass both `pixel_values` and `input_ids`. "
+                    "Please make sure to only pass `input_ids`."
+                )
+            input_ids = kwargs.pop("pixel_values")
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         transformer_outputs = self.transformer(
-            pixel_values,
+            input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -1066,7 +1097,7 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
     @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values=None,
+        input_ids=None,
         past_key_values=None,
         attention_mask=None,
         token_type_ids=None,
@@ -1078,6 +1109,7 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        **kwargs,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1103,10 +1135,25 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
         >>> outputs = model(**inputs)
         >>> logits = outputs.logits
         """
+        if "pixel_values" in kwargs:
+            warnings.warn(
+                "The `pixel_values` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+                FutureWarning,
+            )
+            if input_ids is not None:
+                raise ValueError(
+                    "You cannot pass both `pixel_values` and `input_ids`. "
+                    "Please make sure to only pass `input_ids`."
+                )
+            input_ids = kwargs.pop("pixel_values")
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         transformer_outputs = self.transformer(
-            pixel_values,
+            input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,

tests/test_generation_utils.py

@@ -20,7 +20,7 @@ import unittest
 from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
-from .test_modeling_common import floats_tensor
+from .test_modeling_common import floats_tensor, ids_tensor
 if is_torch_available():
@@ -31,6 +31,7 @@ if is_torch_available():
         BartTokenizer,
         GPT2LMHeadModel,
         GPT2Tokenizer,
+        ImageGPTForCausalImageModeling,
         Speech2TextForConditionalGeneration,
         SpeechEncoderDecoderModel,
         VisionEncoderDecoderModel,
@@ -1766,6 +1767,18 @@ class GenerationIntegrationTests(unittest.TestCase):
         self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
         self.assertEqual(output_sequences.shape, (1, 15))
+    def test_generate_non_nlp_input_ids_as_kwarg(self):
+        model = ImageGPTForCausalImageModeling.from_pretrained(
+            "hf-internal-testing/tiny-random-imagegpt", max_length=10
+        ).to(torch_device)
+        input_ids = ids_tensor((3, 5), vocab_size=10)
+        output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
+        output_sequences = model.generate(input_ids).cpu()
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (3, 10))
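The `ids_tensor` helper newly imported from `test_modeling_common` produces random token ids; behaviorally it is close to this sketch (not the exact implementation):

```python
import torch

def ids_tensor(shape, vocab_size):
    # random integer ids in [0, vocab_size), shaped like the requested tensor
    return torch.randint(0, vocab_size, tuple(shape), dtype=torch.long)
```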
     def test_generate_input_ids_as_encoder_kwarg(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")

tests/test_modeling_imagegpt.py

@@ -314,7 +314,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         # signature.parameters is an OrderedDict => so arg_names order is deterministic
         arg_names = [*signature.parameters.keys()]
-        expected_arg_names = ["pixel_values"]
+        expected_arg_names = ["input_ids"]
         self.assertListEqual(arg_names[:1], expected_arg_names)
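The updated assertion pins the first parameter of `forward` to the new name, so a regression back to `pixel_values` fails immediately. In isolation, the check amounts to:

```python
import inspect

signature = inspect.signature(model.forward)  # `model` is the ImageGPT model under test
arg_names = [*signature.parameters.keys()]
assert arg_names[:1] == ["input_ids"]         # would be ["pixel_values"] before this commit
```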
     def test_resize_tokens_embeddings(self):