[Generate] Make generate multi-modal (#14784)

* finish refactor

* refactor

* add tests

* add more tests

* up

* finish tests

* finish

* up

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* improve docstring

* fix docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Patrick von Platen 2021-12-16 18:03:55 +01:00 committed by GitHub
parent 48463ebb33
commit b18d8534ea
2 changed files with 243 additions and 97 deletions


@ -359,12 +359,72 @@ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOu
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ENCODER_MODEL_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]
class GenerationMixin:
"""
A class containing all of the functions supporting generation, to be used as a mixin in
:class:`~transformers.PreTrainedModel`.
"""
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# filter model input names that are `None`
model_kwargs = {k: v for k, v in model_kwargs.items() if k not in ENCODER_MODEL_INPUT_NAMES or v is not None}
# extract keyword arguments that are model input specific
model_input_kwarg_names = set(ENCODER_MODEL_INPUT_NAMES) & set(model_kwargs.keys())
# There are 5 possible scenarios
if inputs is not None and len(model_input_kwarg_names) == 0:
# 1. `inputs` are passed and no model-specific keyword inputs
# -> return input
model_input_name = None
return inputs, model_input_name, model_kwargs
elif inputs is not None and len(model_input_kwarg_names) > 0:
# 2. `inputs` are passed as well as model-specific keyword inputs
# -> not allowed, raise Error
raise ValueError(
f"`inputs`: {inputs} were passed alongside "
f"{model_input_kwarg_names} which is not allowed. "
f"Make sure to not pass any of {model_input_kwarg_names} "
"when `inputs` is defined."
)
elif inputs is None and len(model_input_kwarg_names) == 0:
# 3. no `inputs` and no model-specific keyword inputs are passed
# -> try to create `input_ids` from BOS
input_tensor = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
return input_tensor, "input_ids", model_kwargs
elif inputs is None and len(model_input_kwarg_names) == 1:
# 4. no `inputs` are passed and exactly one model-specific keyword input
# -> return that model-specific keyword input tensor
model_input_name = model_input_kwarg_names.pop()
input_tensor = model_kwargs.pop(model_input_name)
# make sure model is encoder decoder if not `input_ids`
if not self.config.is_encoder_decoder and model_input_name != "input_ids":
raise ValueError(
f"If {model_input_name} is passed as model-specific keyword "
"input then model has to be an encoder-decoder and not a "
f"{self.__class__.__name__}."
)
return input_tensor, model_input_name, model_kwargs
else:
# 5. no `inputs` are passed and multiple model-specific keyword inputs
# -> not allowed, raise Error
raise ValueError(
f"Can only pass one of {ENCODER_MODEL_INPUT_NAMES}, "
f"but passed {model_input_kwarg_names}."
f"Make sure to only pass one of {model_input_kwarg_names}."
)
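# Illustrative sketch (not part of the diff above): a self-contained
# approximation of the dispatch implemented by `_prepare_model_inputs`.
# The helper name and example values below are hypothetical.
from typing import Any, Dict, Optional, Tuple

import torch

ENCODER_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]


def pick_model_input(
    inputs: Optional[torch.Tensor], model_kwargs: Dict[str, Any]
) -> Tuple[Optional[torch.Tensor], Optional[str]]:
    # keep only the model-specific keyword inputs that were actually passed
    passed = {k: v for k, v in model_kwargs.items() if k in ENCODER_INPUT_NAMES and v is not None}
    if inputs is not None and passed:
        # scenario 2: positional `inputs` plus model-specific keyword inputs -> error
        raise ValueError(f"`inputs` cannot be combined with {sorted(passed)}.")
    if inputs is not None:
        # scenario 1: only the positional tensor was given
        return inputs, None
    if len(passed) > 1:
        # scenario 5: more than one model-specific keyword input -> error
        raise ValueError(f"Pass only one of {ENCODER_INPUT_NAMES}, got {sorted(passed)}.")
    if len(passed) == 1:
        # scenario 4: exactly one model-specific keyword input
        name, tensor = passed.popitem()
        return tensor, name
    # scenario 3: nothing was passed; the caller falls back to a BOS-initialized `input_ids`
    return None, "input_ids"


# e.g. pick_model_input(None, {"pixel_values": torch.rand(2, 3, 30, 30)}) -> (tensor, "pixel_values")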
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
"""
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
@ -393,47 +453,63 @@ class GenerationMixin:
def _prepare_attention_mask_for_generation(
self,
input_ids: torch.Tensor,
inputs: torch.Tensor,
pad_token_id: int,
eos_token_id: int,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.LongTensor:
# First, if `inputs_embeds` is given but no `attention_mask`, assume that a full attention mask is used
if inputs_embeds is not None:
return torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), dtype=torch.long, device=self.device)
# Otherwise, use `input_ids`
is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
is_input_ids = isinstance(inputs, torch.LongTensor) and len(inputs.shape) == 2
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id)
)
if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
return input_ids.ne(pad_token_id).long()
# Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).long()
else:
return input_ids.new_ones(input_ids.shape, dtype=torch.long)
return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)
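# Illustrative sketch (not part of the diff above): the rewritten helper only
# derives a padding-aware mask when the input really is 2-D `input_ids`; for
# embeddings, audio features or pixel values it falls back to an all-ones mask.
# A rough standalone equivalent, with hypothetical names:
import torch


def default_attention_mask(inputs, pad_token_id=None, eos_token_id=None):
    is_input_ids = inputs.dtype in (torch.int64, torch.int32) and inputs.dim() == 2
    has_pad = pad_token_id is not None and is_input_ids and bool((inputs == pad_token_id).any())
    pad_differs_from_eos = eos_token_id is None or pad_token_id != eos_token_id
    if is_input_ids and has_pad and pad_differs_from_eos:
        # mask out padded positions
        return inputs.ne(pad_token_id).long()
    # otherwise an all-ones mask over the first two dimensions
    return torch.ones(inputs.shape[:2], dtype=torch.long)


# default_attention_mask(torch.tensor([[5, 5, 0]]), pad_token_id=0) -> tensor([[1, 1, 0]])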
def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: torch.LongTensor, model_kwargs
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
) -> Dict[str, Any]:
if "encoder_outputs" not in model_kwargs:
# retrieve encoder hidden states
# 1. get encoder
encoder = self.get_encoder()
# 2. prepare encoder args and encoder kwargs from model kwargs
encoder_args = (inputs_tensor,)
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
if not any(argument.startswith(p) for p in irrelevant_prefix)
}
model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
# 3. make sure that encoder returns `ModelOutput`
encoder_kwargs["return_dict"] = True
# 4. if `model_input_name` is defined, pass `inputs_tensor` as a keyword
# argument under that name and drop it from the positional args
if model_input_name is not None:
# pass the input tensor under its explicit name in case the encoder
# accepts multiple model input arguments
encoder_kwargs[model_input_name] = inputs_tensor
encoder_args = ()
model_kwargs["encoder_outputs"]: ModelOutput = encoder(*encoder_args, **encoder_kwargs)
return model_kwargs
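# Illustrative sketch (not part of the diff above): only encoder-relevant
# kwargs reach the encoder forward pass; decoder-, cross-attention- and
# cache-related arguments are filtered out, and the input tensor is passed
# under its model-specific name. Hypothetical values, just to show the filter:
model_kwargs_example = {
    "attention_mask": "mask",
    "decoder_input_ids": "decoder ids",
    "cross_attn_head_mask": "cross-attention mask",
    "use_cache": True,
}
irrelevant_prefix_example = ["decoder_", "cross_attn", "use_cache"]
encoder_kwargs_example = {
    k: v for k, v in model_kwargs_example.items() if not any(k.startswith(p) for p in irrelevant_prefix_example)
}
# encoder_kwargs_example == {"attention_mask": "mask"}; the encoder is then
# called e.g. as encoder(pixel_values=..., return_dict=True, **encoder_kwargs_example)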
def _prepare_decoder_input_ids_for_generation(
self, batch_size: int, decoder_start_token_id: int = None, bos_token_id: int = None
self,
batch_size: int,
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.LongTensor:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
return decoder_input_ids
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
return model_kwargs.pop("decoder_input_ids")
else:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
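# Illustrative sketch (not part of the diff above): unless the caller already
# supplied `decoder_input_ids`, decoding starts from a single start token per
# batch item. Hypothetical values:
import torch

batch_size_example, decoder_start_token_id_example = 2, 0
decoder_input_ids_example = (
    torch.ones((batch_size_example, 1), dtype=torch.long) * decoder_start_token_id_example
)
# tensor([[0], [0]]) -> one decoder start token per encoder input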
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
if pad_token_id is None and eos_token_id is not None:
@ -649,7 +725,7 @@ class GenerationMixin:
@torch.no_grad()
def generate(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs: Optional[torch.Tensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
@ -688,18 +764,20 @@ class GenerationMixin:
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
multinomial sampling, beam-search decoding, and beam-search multinomial sampling.
Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
indicated are the default values of those config.
Apart from :obj:`inputs`, all the arguments below will default to the value of the attribute of the same name
inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the default
values of those config attributes.
Most of these parameters are explained in more detail in `this blog post
<https://huggingface.co/blog/how-to-generate>`__.
Parameters:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it with
:obj:`bos_token_id` and a batch size of 1.
inputs (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, :obj:`(batch_size, sequence_length, feature_dim)` or :obj:`(batch_size, num_channels, height, width)`, `optional`):
The sequence used as a prompt for the generation or as model inputs to the encoder. If :obj:`None` the
method initializes it with :obj:`bos_token_id` and a batch size of 1. For decoder-only models
:obj:`inputs` should be in the format of :obj:`input_ids`. For encoder-decoder models :obj:`inputs` can
represent any of :obj:`input_ids`, :obj:`input_values`, :obj:`input_features`, or :obj:`pixel_values`.
max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`):
The maximum length of the sequence to be generated.
max_new_tokens (:obj:`int`, `optional`, defaults to None):
@ -870,8 +948,11 @@ class GenerationMixin:
>>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
>>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
"""
# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
num_beams = num_beams if num_beams is not None else self.config.num_beams
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_return_sequences = (
@ -879,7 +960,6 @@ class GenerationMixin:
)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
@ -891,55 +971,52 @@ class GenerationMixin:
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
if input_ids is None and "inputs_embeds" not in model_kwargs:
# init `input_ids` with bos_token_id
input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
if model_kwargs.get("attention_mask", None) is None:
# init `attention_mask` depending on `pad_token_id`
inputs_embeds = model_kwargs.get("inputs_embeds", None)
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
input_ids, pad_token_id, eos_token_id, inputs_embeds
)
# special case if pad_token_id is not defined
if pad_token_id is None and eos_token_id is not None:
# special case if pad_token_id is not defined
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
# Storing encoder_input_ids for logits_processor that could use them
encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
# 2. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
batch_size = inputs_tensor.shape[0]
# 3. Define other model kwargs
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache
if model_kwargs.get("attention_mask", None) is None:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, pad_token_id, eos_token_id
)
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name
)
# set input_ids as decoder_input_ids
if "decoder_input_ids" in model_kwargs:
input_ids = model_kwargs.pop("decoder_input_ids")
else:
# if word embeddings are provided directly, infer the batch size from them
batch_size = input_ids.shape[0] if input_ids is not None else model_kwargs["inputs_embeds"].shape[0]
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
)
if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
)
else:
if "inputs_embeds" in model_kwargs and input_ids is None:
raise ValueError("For decoder-only generation, one must pass `input_ids`.")
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
# 5. Prepare `max_length` depending on other stopping criteria
# if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens + input_ids.shape[-1]`
if max_length is None and max_new_tokens is not None:
max_length = (
max_new_tokens + input_ids.shape[-1]
if input_ids is not None
else max_length + model_kwargs["inputs_embeds"].shape[1]
)
max_length = max_new_tokens + input_ids.shape[-1]
elif max_length is not None and max_new_tokens is not None:
# Both are set, this is odd, raise a warning
warnings.warn(
@ -948,7 +1025,6 @@ class GenerationMixin:
f"will take priority over `max_new_tokens` {max_new_tokens}.",
UserWarning,
)
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
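# Illustrative sketch (not part of the diff above): `max_new_tokens` is counted
# on top of the prompt that auto-regressive generation starts from
# (`input_ids` for decoder-only models, the single decoder start token for
# encoder-decoder models). Hypothetical numbers:
prompt_length_example = 8
max_new_tokens_example = 20
max_length_example = max_new_tokens_example + prompt_length_example  # -> 28 total positions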
@ -959,12 +1035,13 @@ class GenerationMixin:
"This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
)
# determine generation mode
# 6. determine generation mode
is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)
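# Illustrative sketch (not part of the diff above): the five flags partition
# (num_beams, num_beam_groups, do_sample) into mutually exclusive modes.
# `num_beam_groups > num_beams` is rejected just below, so more than one beam
# group implies more than one beam. Hypothetical helper:
def generation_mode_example(num_beams: int, num_beam_groups: int, do_sample: bool) -> str:
    if num_beam_groups > 1:
        return "group_beam_search"  # diverse beam search; sampling is not allowed here
    if num_beams > 1:
        return "beam_sample" if do_sample else "beam_search"
    return "sample" if do_sample else "greedy_search"


# generation_mode_example(1, 1, False) -> "greedy_search"
# generation_mode_example(4, 2, False) -> "group_beam_search"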
if num_beam_groups > num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and do_sample is True:
@ -972,15 +1049,12 @@ class GenerationMixin:
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)
# set model_kwargs
model_kwargs["use_cache"] = use_cache
# get distribution pre_processing samplers
# 7. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
encoder_input_ids=encoder_input_ids,
encoder_input_ids=inputs_tensor,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
@ -994,15 +1068,17 @@ class GenerationMixin:
remove_invalid_values=remove_invalid_values,
)
# 8. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
# 9. go into different generation modes
if is_greedy_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
# greedy search
# 10. run greedy search
return self.greedy_search(
input_ids,
logits_processor=logits_processor,
@ -1016,12 +1092,12 @@ class GenerationMixin:
)
elif is_sample_gen_mode:
# get probability distribution warper
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
)
# expand input_ids with `num_return_sequences` additional sequences per batch
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_return_sequences,
@ -1029,7 +1105,7 @@ class GenerationMixin:
**model_kwargs,
)
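# Illustrative sketch (not part of the diff above): "expanding" repeats every
# batch item `expand_size` times so that each returned sequence (or beam)
# decodes against its own copy of the inputs, roughly:
import torch

input_ids_example = torch.tensor([[11, 12], [21, 22]])            # batch of 2 prompts
expanded_example = input_ids_example.repeat_interleave(3, dim=0)  # expand_size = 3
# tensor([[11, 12], [11, 12], [11, 12], [21, 22], [21, 22], [21, 22]])
# attention_mask and encoder_outputs are interleaved the same way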
# sample
# 12. run sample
return self.sample(
input_ids,
logits_processor=logits_processor,
@ -1044,17 +1120,13 @@ class GenerationMixin:
)
elif is_beam_gen_mode:
batch_size = input_ids.shape[0]
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 10. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
@ -1063,10 +1135,11 @@ class GenerationMixin:
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
)
# interleave with `num_beams`
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
# 12. run beam search
return self.beam_search(
input_ids,
beam_scorer,
@ -1081,24 +1154,23 @@ class GenerationMixin:
)
elif is_beam_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
)
batch_size = input_ids.shape[0] * num_return_sequences
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
batch_size=batch_size * num_return_sequences,
num_beams=num_beams,
device=self.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
)
# interleave with `num_beams * num_return_sequences`
# 12. interleave input_ids with `num_beams * num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_beams * num_return_sequences,
@ -1106,6 +1178,7 @@ class GenerationMixin:
**model_kwargs,
)
# 13. run beam sample
return self.beam_sample(
input_ids,
beam_scorer,
@ -1121,11 +1194,6 @@ class GenerationMixin:
)
elif is_group_beam_gen_mode:
batch_size = input_ids.shape[0]
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
@ -1135,7 +1203,8 @@ class GenerationMixin:
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
diverse_beam_scorer = BeamSearchScorer(
# 10. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
max_length=stopping_criteria.max_length,
@ -1145,13 +1214,14 @@ class GenerationMixin:
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
)
# interleave with `num_beams`
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
# 12. run beam search
return self.group_beam_search(
input_ids,
diverse_beam_scorer,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,


@ -20,6 +20,8 @@ import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_common import floats_tensor
if is_torch_available():
import torch
@ -29,6 +31,9 @@ if is_torch_available():
BartTokenizer,
GPT2LMHeadModel,
GPT2Tokenizer,
Speech2TextForConditionalGeneration,
SpeechEncoderDecoderModel,
VisionEncoderDecoderModel,
top_k_top_p_filtering,
)
from transformers.generation_beam_search import BeamSearchScorer
@ -1724,3 +1729,74 @@ class GenerationIntegrationTests(unittest.TestCase):
# cannot generate from `inputs_embeds` for decoder only
with self.assertRaises(ValueError):
model.generate(inputs_embeds=inputs_embeds)
def test_generate_input_ids_as_kwarg(self):
article = """I need input_ids to generate"""
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15).to(torch_device)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
output_sequences = model.generate(input_ids).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (1, 15))
def test_generate_input_ids_as_encoder_kwarg(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
torch_device
)
model.config.eos_token_id = None
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
output_sequences = model.generate(input_ids).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (1, 5))
def test_generate_inputs_and_encoder_kwargs(self):
article = """I need input_ids to generate"""
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
with self.assertRaises(ValueError):
model.generate(input_ids, input_ids=input_ids)
def test_generate_too_many_encoder_kwargs(self):
article = """I need input_ids to generate"""
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
with self.assertRaises(ValueError):
model.generate(input_ids=input_ids, input_values=input_ids)
def test_generate_input_values_as_encoder_kwarg(self):
input_values = floats_tensor((2, 250))
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
model = model.to(torch_device)
output_sequences_kwargs = model.generate(input_values=input_values, max_length=5).cpu()
output_sequences = model.generate(input_values, max_length=5).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (2, 5))
def test_generate_input_features_as_encoder_kwarg(self):
input_features = floats_tensor((3, 20, 24))
model = Speech2TextForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-speech_to_text")
model = model.to(torch_device)
output_sequences_kwargs = model.generate(input_features=input_features, max_length=5).cpu()
output_sequences = model.generate(input_features, max_length=5).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (3, 5))
def test_generate_pixel_values_as_encoder_kwarg(self):
pixel_values = floats_tensor((2, 3, 30, 30))
model = VisionEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-vision-encoder-decoder")
model = model.to(torch_device)
output_sequences_kwargs = model.generate(pixel_values=pixel_values, max_length=5).cpu()
output_sequences = model.generate(pixel_values, max_length=5).cpu()
self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
self.assertEqual(output_sequences.shape, (2, 5))