[Generate] Make generate multi-modal (#14784)

* finish refactor

* refactor

* add tests

* add more tests

* up

* finish tests

* finish

* up

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* improve docstring

* fix docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Patrick von Platen 2021-12-16 18:03:55 +01:00 committed by GitHub
parent 48463ebb33
commit b18d8534ea
2 changed files with 243 additions and 97 deletions
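With this change, `generate` accepts the encoder input either positionally as `inputs` or as a single model-specific keyword argument (`input_ids`, `inputs_embeds`, `input_values`, `input_features`, or `pixel_values`). A minimal usage sketch, reusing the tiny checkpoints and tensor shapes from the tests added in this commit (random tensors stand in for real audio and images, so the outputs are meaningless):

# Sketch of the call patterns enabled by this commit; checkpoint names and
# tensor shapes are taken from the new tests below.
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    SpeechEncoderDecoderModel,
    VisionEncoderDecoderModel,
)

# decoder-only: a positional `inputs` tensor is interpreted as `input_ids`, as before
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
gpt2 = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer("I need input_ids to generate", return_tensors="pt").input_ids
gpt2.generate(input_ids, max_length=15)             # positional
gpt2.generate(input_ids=input_ids, max_length=15)   # keyword, now equivalent

# encoder-decoder speech model: the raw waveform goes in as `input_values`
speech = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
speech.generate(input_values=torch.randn(2, 250), max_length=5)

# encoder-decoder vision model: images go in as `pixel_values`
vision = VisionEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-vision-encoder-decoder")
vision.generate(pixel_values=torch.randn(2, 3, 30, 30), max_length=5)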


@@ -359,12 +359,72 @@ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]

+ENCODER_MODEL_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]
+

class GenerationMixin:
    """
    A class containing all of the functions supporting generation, to be used as a mixin in
    :class:`~transformers.PreTrainedModel`.
    """

+    def _prepare_model_inputs(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[str]]:
+        """
+        This function extracts the model-specific `inputs` for generation.
+        """
+        # filter model input names that are `None`
+        model_kwargs = {k: v for k, v in model_kwargs.items() if k not in ENCODER_MODEL_INPUT_NAMES or v is not None}
+
+        # extract keyword arguments that are model input specific
+        model_input_kwarg_names = set(ENCODER_MODEL_INPUT_NAMES) & set(model_kwargs.keys())
+
+        # There are 5 possible scenarios
+        if inputs is not None and len(model_input_kwarg_names) == 0:
+            # 1. `inputs` are passed and no model-specific keyword inputs
+            # -> return input
+            model_input_name = None
+            return inputs, model_input_name, model_kwargs
+        elif inputs is not None and len(model_input_kwarg_names) > 0:
+            # 2. `inputs` are passed as well as model-specific keyword inputs
+            # -> not allowed, raise Error
+            raise ValueError(
+                f"`inputs`: {inputs}` were passed alongside "
+                f"{model_input_kwarg_names} which is not allowed."
+                f"Make sure to not pass any of {model_input_kwarg_names} "
+                "when `inputs` is defined."
+            )
+        elif inputs is None and len(model_input_kwarg_names) == 0:
+            # 3. no `inputs` and no model-specific keyword inputs are passed
+            # -> try to create `input_ids` from BOS
+            input_tensor = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+            return input_tensor, "input_ids", model_kwargs
+        elif inputs is None and len(model_input_kwarg_names) == 1:
+            # 4. no `inputs` are passed and exactly one model-specific keyword input
+            # -> return that model-specific keyword input tensor
+            model_input_name = model_input_kwarg_names.pop()
+            input_tensor = model_kwargs.pop(model_input_name)
+
+            # make sure model is encoder decoder if not `input_ids`
+            if not self.config.is_encoder_decoder and model_input_name != "input_ids":
+                raise ValueError(
+                    f"If {model_input_name} is passed as model-specific keyword "
+                    "input then model has to be an encoder-decoder and not a "
+                    f"{self.__class__.__name__}."
+                )
+
+            return input_tensor, model_input_name, model_kwargs
+        else:
+            # 5. no `inputs` are passed and multiple model-specific keyword inputs
+            # -> not allowed, raise Error
+            raise ValueError(
+                f"Can only pass one of {ENCODER_MODEL_INPUT_NAMES}, "
+                f"but passed {model_input_kwarg_names}."
+                f"Make sure to only pass one of {model_input_kwarg_names}."
+            )
+
    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        """
        Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
@@ -393,47 +453,63 @@ class GenerationMixin:

    def _prepare_attention_mask_for_generation(
        self,
-        input_ids: torch.Tensor,
+        inputs: torch.Tensor,
        pad_token_id: int,
        eos_token_id: int,
-        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.LongTensor:
+        is_input_ids = isinstance(inputs, torch.LongTensor) and len(inputs.shape) == 2
-        # First if `inputs_embeds` are given, but no `attention_mask` assume that full attention_mask is used
-        if inputs_embeds is not None:
-            return torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), dtype=torch.long, device=self.device)
-        # Otherwise, use `input_ids`
-        is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
+        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
            (eos_token_id is not None) and (pad_token_id != eos_token_id)
        )
-        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
-            return input_ids.ne(pad_token_id).long()
+        # Check if input is input_ids and padded -> only then is attention_mask defined
+        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
+            return inputs.ne(pad_token_id).long()
        else:
-            return input_ids.new_ones(input_ids.shape, dtype=torch.long)
+            return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)

    def _prepare_encoder_decoder_kwargs_for_generation(
-        self, input_ids: torch.LongTensor, model_kwargs
+        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if "encoder_outputs" not in model_kwargs:
-            # retrieve encoder hidden states
+            # 1. get encoder
            encoder = self.get_encoder()
+
+            # 2. prepare encoder args and encoder kwargs from model kwargs
+            encoder_args = (inputs_tensor,)
+            irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
            encoder_kwargs = {
                argument: value
                for argument, value in model_kwargs.items()
-                if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
+                if not any(argument.startswith(p) for p in irrelevant_prefix)
            }
-            model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
+
+            # 3. make sure that encoder returns `ModelOutput`
+            encoder_kwargs["return_dict"] = True
+
+            # 4. if model_input_name is not defined then pass input_tensor as
+            # first input argument and remove from args
+            if model_input_name is not None:
+                # make sure inputs_tensor is None in case model
+                # accepts multiple model input arguments
+                encoder_kwargs[model_input_name] = inputs_tensor
+                encoder_args = ()
+
+            model_kwargs["encoder_outputs"]: ModelOutput = encoder(*encoder_args, **encoder_kwargs)

        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
-        self, batch_size: int, decoder_start_token_id: int = None, bos_token_id: int = None
+        self,
+        batch_size: int,
+        decoder_start_token_id: int = None,
+        bos_token_id: int = None,
+        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.LongTensor:
-        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
-        decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
-        return decoder_input_ids
+        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+            return model_kwargs.pop("decoder_input_ids")
+        else:
+            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+            return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id

    def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
        if pad_token_id is None and eos_token_id is not None:
@@ -649,7 +725,7 @@ class GenerationMixin:

    @torch.no_grad()
    def generate(
        self,
-        input_ids: Optional[torch.LongTensor] = None,
+        inputs: Optional[torch.Tensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
@@ -688,18 +764,20 @@ class GenerationMixin:

        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
        multinomial sampling, beam-search decoding, and beam-search multinomial sampling.

-        Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
-        attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
-        indicated are the default values of those config.
+        Apart from :obj:`inputs`, all the arguments below will default to the value of the attribute of the same name
+        inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the default
+        values of those config.

        Most of these parameters are explained in more detail in `this blog post
        <https://huggingface.co/blog/how-to-generate>`__.

        Parameters:
-            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
-                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it with
-                :obj:`bos_token_id` and a batch size of 1.
+            inputs (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, :obj:`(batch_size, sequence_length, feature_dim)` or :obj:`(batch_size, num_channels, height, width)`, `optional`):
+                The sequence used as a prompt for the generation or as model inputs to the encoder. If :obj:`None` the
+                method initializes it with :obj:`bos_token_id` and a batch size of 1. For decoder-only models
+                :obj:`inputs` should be in the format of :obj:`input_ids`. For encoder-decoder models `inputs` can
+                represent any of :obj:`input_ids`, :obj:`input_values`, :obj:`input_features`, or :obj:`pixel_values`.
            max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`):
                The maximum length of the sequence to be generated.
            max_new_tokens (:obj:`int`, `optional`, defaults to None):
@@ -870,8 +948,11 @@ class GenerationMixin:

            >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
            >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
        """
+        # 1. Set generation parameters if not already defined
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        num_beams = num_beams if num_beams is not None else self.config.num_beams
+        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        num_return_sequences = (
@@ -879,7 +960,6 @@ class GenerationMixin:
        )

        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -891,55 +971,52 @@ class GenerationMixin:
            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
        )

-        model_kwargs["output_attentions"] = output_attentions
-        model_kwargs["output_hidden_states"] = output_hidden_states
-
-        if input_ids is None and "inputs_embeds" not in model_kwargs:
-            # init `input_ids` with bos_token_id
-            input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
-
-        if model_kwargs.get("attention_mask", None) is None:
-            # init `attention_mask` depending on `pad_token_id`
-            inputs_embeds = model_kwargs.get("inputs_embeds", None)
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, pad_token_id, eos_token_id, inputs_embeds
-            )
-
-        # special case if pad_token_id is not defined
        if pad_token_id is None and eos_token_id is not None:
+            # special case if pad_token_id is not defined
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            pad_token_id = eos_token_id

-        # Storing encoder_input_ids for logits_processor that could use them
-        encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
+        # 2. Define model inputs
+        # inputs_tensor has to be defined
+        # model_input_name is defined if model-specific keyword input is passed
+        # otherwise model_input_name is None
+        # all model-specific keyword inputs are removed from `model_kwargs`
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
+        batch_size = inputs_tensor.shape[0]
+
+        # 3. Define other model kwargs
+        model_kwargs["output_attentions"] = output_attentions
+        model_kwargs["output_hidden_states"] = output_hidden_states
+        model_kwargs["use_cache"] = use_cache
+
+        if model_kwargs.get("attention_mask", None) is None:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                inputs_tensor, pad_token_id, eos_token_id
+            )

        if self.config.is_encoder_decoder:
-            # add encoder_outputs to model_kwargs
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
+            # if model is encoder decoder encoder_outputs are created
+            # and added to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+                inputs_tensor, model_kwargs, model_input_name
+            )

-            # set input_ids as decoder_input_ids
-            if "decoder_input_ids" in model_kwargs:
-                input_ids = model_kwargs.pop("decoder_input_ids")
-            else:
-                # if word embeddings are provided directly, infere the batch size from it
-                batch_size = input_ids.shape[0] if input_ids is not None else model_kwargs["inputs_embeds"].shape[0]
-                input_ids = self._prepare_decoder_input_ids_for_generation(
-                    batch_size, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
-                )
-
-            if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
-                raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
+        # 4. Prepare `input_ids` which will be used for auto-regressive generation
+        if self.config.is_encoder_decoder:
+            input_ids = self._prepare_decoder_input_ids_for_generation(
+                batch_size,
+                decoder_start_token_id=decoder_start_token_id,
+                bos_token_id=bos_token_id,
+                model_kwargs=model_kwargs,
+            )
        else:
-            if "inputs_embeds" in model_kwargs and input_ids is None:
-                raise ValueError("For decoder-only generation, one must pass `input_ids`.")
+            # if decoder-only then inputs_tensor has to be `input_ids`
+            input_ids = inputs_tensor

+        # 5. Prepare `max_length` depending on other stopping criteria
        # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
        if max_length is None and max_new_tokens is not None:
-            max_length = (
-                max_new_tokens + input_ids.shape[-1]
-                if input_ids is not None
-                else max_length + model_kwargs["inputs_embeds"].shape[1]
-            )
+            max_length = max_new_tokens + input_ids.shape[-1]
        elif max_length is not None and max_new_tokens is not None:
            # Both are set, this is odd, raise a warning
            warnings.warn(
@@ -948,7 +1025,6 @@ class GenerationMixin:
                f"will take priority over `max_new_tokens` {max_new_tokens}.",
                UserWarning,
            )

        # default to config if still None
        max_length = max_length if max_length is not None else self.config.max_length
@@ -959,12 +1035,13 @@ class GenerationMixin:
                "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
            )

-        # determine generation mode
+        # 6. determine generation mode
        is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
        is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
        is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
        is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
        is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)

        if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
        if is_group_beam_gen_mode and do_sample is True:
@@ -972,15 +1049,12 @@ class GenerationMixin:
                "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
            )

-        # set model_kwargs
-        model_kwargs["use_cache"] = use_cache
-
-        # get distribution pre_processing samplers
+        # 7. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
-            encoder_input_ids=encoder_input_ids,
+            encoder_input_ids=inputs_tensor,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            max_length=max_length,
@@ -994,15 +1068,17 @@ class GenerationMixin:
            remove_invalid_values=remove_invalid_values,
        )

+        # 8. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)

+        # 9. go into different generation modes
        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )

-            # greedy search
+            # 10. run greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
@@ -1016,12 +1092,12 @@ class GenerationMixin:
            )

        elif is_sample_gen_mode:
-            # get probability distribution warper
+            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
            )

-            # expand input_ids with `num_return_sequences` additional sequences per batch
+            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_return_sequences,
@@ -1029,7 +1105,7 @@ class GenerationMixin:
                **model_kwargs,
            )

-            # sample
+            # 12. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
@@ -1044,17 +1120,13 @@ class GenerationMixin:
            )

        elif is_beam_gen_mode:
-            batch_size = input_ids.shape[0]
-
-            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
-
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

+            # 10. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
@@ -1063,10 +1135,11 @@ class GenerationMixin:
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
-            # interleave with `num_beams`
+            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
+            # 12. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
@@ -1081,24 +1154,23 @@ class GenerationMixin:
            )

        elif is_beam_sample_gen_mode:
+            # 10. prepare logits warper
            logits_warper = self._get_logits_warper(
                top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
            )

-            batch_size = input_ids.shape[0] * num_return_sequences
-
-            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

+            # 11. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
+                batch_size=batch_size * num_return_sequences,
                num_beams=num_beams,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
            )

-            # interleave with `num_beams * num_return_sequences`
+            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids,
                expand_size=num_beams * num_return_sequences,
@@ -1106,6 +1178,7 @@ class GenerationMixin:
                **model_kwargs,
            )

+            # 13. run beam sample
            return self.beam_sample(
                input_ids,
                beam_scorer,
@@ -1121,11 +1194,6 @@ class GenerationMixin:
            )

        elif is_group_beam_gen_mode:
-            batch_size = input_ids.shape[0]
-
-            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
-
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
@@ -1135,7 +1203,8 @@ class GenerationMixin:
            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

-            diverse_beam_scorer = BeamSearchScorer(
+            # 10. prepare beam search scorer
+            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=num_beams,
                max_length=stopping_criteria.max_length,
@@ -1145,13 +1214,14 @@ class GenerationMixin:
                num_beam_hyps_to_keep=num_return_sequences,
                num_beam_groups=num_beam_groups,
            )
-            # interleave with `num_beams`
+            # 11. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
            )
+            # 12. run beam search
            return self.group_beam_search(
                input_ids,
-                diverse_beam_scorer,
+                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=pad_token_id,
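For readers skimming the first file: the heart of the change is the five-way dispatch in `_prepare_model_inputs`. A simplified, standalone sketch of that rule follows (the helper name `resolve_model_input` is illustrative only; the real method additionally falls back to building `input_ids` from `bos_token_id` and rejects non-`input_ids` keywords for decoder-only models):

# Simplified, standalone sketch of the input-resolution rule introduced above.
# It mirrors the five scenarios in `_prepare_model_inputs` but drops the BOS
# fallback and the encoder-decoder check for brevity.
from typing import Any, Dict, Optional, Tuple

MODALITY_KEYS = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]


def resolve_model_input(inputs: Optional[Any], model_kwargs: Dict[str, Any]) -> Tuple[Any, Optional[str]]:
    # keep only modality kwargs that were actually passed (i.e. not None)
    passed = [k for k in MODALITY_KEYS if model_kwargs.get(k) is not None]

    if inputs is not None and not passed:
        return inputs, None                          # 1. positional tensor only
    if inputs is not None and passed:
        raise ValueError(f"pass either `inputs` or one of {passed}, not both")  # 2. both -> error
    if not passed:
        raise ValueError("no input given (the real method builds `input_ids` from BOS here)")  # 3.
    if len(passed) == 1:
        name = passed[0]
        return model_kwargs.pop(name), name          # 4. exactly one modality kwarg
    raise ValueError(f"only one of {MODALITY_KEYS} may be passed, got {passed}")  # 5. several -> error


# e.g. resolve_model_input(None, {"pixel_values": images}) -> (images, "pixel_values")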


@@ -20,6 +20,8 @@ import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

+from .test_modeling_common import floats_tensor

if is_torch_available():
    import torch
@@ -29,6 +31,9 @@ if is_torch_available():
        BartTokenizer,
        GPT2LMHeadModel,
        GPT2Tokenizer,
+        Speech2TextForConditionalGeneration,
+        SpeechEncoderDecoderModel,
+        VisionEncoderDecoderModel,
        top_k_top_p_filtering,
    )
    from transformers.generation_beam_search import BeamSearchScorer
@@ -1724,3 +1729,74 @@ class GenerationIntegrationTests(unittest.TestCase):
        # cannot generate from `inputs_embeds` for decoder only
        with self.assertRaises(ValueError):
            model.generate(inputs_embeds=inputs_embeds)

+    def test_generate_input_ids_as_kwarg(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
+        output_sequences = model.generate(input_ids).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (1, 15))
+
+    def test_generate_input_ids_as_encoder_kwarg(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
+            torch_device
+        )
+        model.config.eos_token_id = None
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
+        output_sequences = model.generate(input_ids).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (1, 5))
+
+    def test_generate_inputs_and_encoder_kwargs(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        with self.assertRaises(ValueError):
+            model.generate(input_ids, input_ids=input_ids)
+
+    def test_generate_too_many_encoder_kwargs(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        with self.assertRaises(ValueError):
+            model.generate(input_ids=input_ids, input_values=input_ids)
+
+    def test_generate_input_values_as_encoder_kwarg(self):
+        input_values = floats_tensor((2, 250))
+        model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(input_values=input_values, max_length=5).cpu()
+        output_sequences = model.generate(input_values, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (2, 5))
+
+    def test_generate_input_features_as_encoder_kwarg(self):
+        input_features = floats_tensor((3, 20, 24))
+        model = Speech2TextForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-speech_to_text")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(input_features=input_features, max_length=5).cpu()
+        output_sequences = model.generate(input_features, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (3, 5))
+
+    def test_generate_pixel_values_as_encoder_kwarg(self):
+        pixel_values = floats_tensor((2, 3, 30, 30))
+        model = VisionEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-vision-encoder-decoder")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(pixel_values=pixel_values, max_length=5).cpu()
+        output_sequences = model.generate(pixel_values, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (2, 5))
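The tests above feed float feature tensors to `generate` without any `attention_mask`; that works because of the reworked `_prepare_attention_mask_for_generation` in the first file, which only derives a padding mask when the input really is `input_ids` and otherwise defaults to a full mask over the first two dimensions. A simplified standalone sketch of that default (the function name `default_attention_mask` is illustrative only):

# Sketch of the attention-mask default added in the first file: a padding-aware
# mask is derived only when the input is 2D `input_ids`; for embeddings, audio
# features or images a full mask over the first two dimensions is returned.
import torch


def default_attention_mask(inputs: torch.Tensor, pad_token_id, eos_token_id) -> torch.LongTensor:
    is_input_ids = isinstance(inputs, torch.LongTensor) and inputs.dim() == 2
    pad_in_inputs = pad_token_id is not None and is_input_ids and (inputs == pad_token_id).any()
    pad_is_not_eos = eos_token_id is None or pad_token_id != eos_token_id

    if is_input_ids and pad_in_inputs and pad_is_not_eos:
        return inputs.ne(pad_token_id).long()
    return torch.ones(inputs.shape[:2], dtype=torch.long)


print(default_attention_mask(torch.tensor([[5, 6, 0, 0]]), pad_token_id=0, eos_token_id=2))
# tensor([[1, 1, 0, 0]])
print(default_attention_mask(torch.randn(2, 250), pad_token_id=0, eos_token_id=2).shape)
# torch.Size([2, 250])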