mirror of https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00

[Generate] Make generate multi-modal (#14784)

* finish refactor
* refactor
* add tests
* add more tests
* up
* finish tests
* finish
* up
* Apply suggestions from code review
* improve docstring
* fix docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

This commit is contained in:
parent 48463ebb33
commit b18d8534ea
src/transformers/generation_utils.py

@@ -359,12 +359,72 @@ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]

 BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]


+ENCODER_MODEL_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]
+
+
 class GenerationMixin:
     """
     A class containing all of the functions supporting generation, to be used as a mixin in
     :class:`~transformers.PreTrainedModel`.
     """

+    def _prepare_model_inputs(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        bos_token_id: Optional[int] = None,
+        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
+        """
+        This function extracts the model-specific `inputs` for generation.
+        """
+        # filter model input names that are `None`
+        model_kwargs = {k: v for k, v in model_kwargs.items() if k not in ENCODER_MODEL_INPUT_NAMES or v is not None}
+        # extract keyword arguments that are model input specific
+        model_input_kwarg_names = set(ENCODER_MODEL_INPUT_NAMES) & set(model_kwargs.keys())
+
+        # There are 5 possible scenarios
+        if inputs is not None and len(model_input_kwarg_names) == 0:
+            # 1. `inputs` are passed and no model-specific keyword inputs
+            # -> return `inputs`
+            model_input_name = None
+            return inputs, model_input_name, model_kwargs
+        elif inputs is not None and len(model_input_kwarg_names) > 0:
+            # 2. `inputs` are passed as well as model-specific keyword inputs
+            # -> not allowed, raise Error
+            raise ValueError(
+                f"`inputs`: {inputs} were passed alongside "
+                f"{model_input_kwarg_names} which is not allowed. "
+                f"Make sure to not pass any of {model_input_kwarg_names} "
+                "when `inputs` is defined."
+            )
+        elif inputs is None and len(model_input_kwarg_names) == 0:
+            # 3. no `inputs` and no model-specific keyword inputs are passed
+            # -> try to create `input_ids` from BOS
+            input_tensor = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
+            return input_tensor, "input_ids", model_kwargs
+        elif inputs is None and len(model_input_kwarg_names) == 1:
+            # 4. no `inputs` are passed and exactly one model-specific keyword input
+            # -> return that model-specific keyword input tensor
+            model_input_name = model_input_kwarg_names.pop()
+            input_tensor = model_kwargs.pop(model_input_name)
+
+            # make sure the model is an encoder-decoder model if the input is not `input_ids`
+            if not self.config.is_encoder_decoder and model_input_name != "input_ids":
+                raise ValueError(
+                    f"If {model_input_name} is passed as a model-specific keyword "
+                    "input then the model has to be an encoder-decoder model and not a "
+                    f"{self.__class__.__name__}."
+                )
+            return input_tensor, model_input_name, model_kwargs
+        else:
+            # 5. no `inputs` are passed and multiple model-specific keyword inputs
+            # -> not allowed, raise Error
+            raise ValueError(
+                f"Can only pass one of {ENCODER_MODEL_INPUT_NAMES}, "
+                f"but passed {model_input_kwarg_names}. "
+                f"Make sure to only pass one of {model_input_kwarg_names}."
+            )

     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
         """
         Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the
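The five scenarios above cover every way `generate` can now receive its inputs. A minimal sketch of what each call pattern looks like from the caller's side, using one of the tiny test checkpoints exercised in the tests below (the commented-out calls are the two error cases):

# A sketch of the five `_prepare_model_inputs` scenarios.
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
input_ids = tokenizer("some text", return_tensors="pt").input_ids

model.generate(input_ids, max_length=5)            # 1. `inputs` passed, no model-specific kwarg
model.generate(max_length=5)                       # 3. nothing passed -> `input_ids` built from BOS
model.generate(input_ids=input_ids, max_length=5)  # 4. exactly one model-specific keyword input

# 2. and 5. are rejected with a ValueError:
# model.generate(input_ids, input_ids=input_ids)               # `inputs` alongside a model-specific kwarg
# model.generate(input_ids=input_ids, input_values=input_ids)  # more than one model-specific kwarg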
@@ -393,47 +453,63 @@ class GenerationMixin:

     def _prepare_attention_mask_for_generation(
         self,
-        input_ids: torch.Tensor,
+        inputs: torch.Tensor,
         pad_token_id: int,
         eos_token_id: int,
-        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.LongTensor:
-
-        # First if `inputs_embeds` are given, but no `attention_mask` assume that full attention_mask is used
-        if inputs_embeds is not None:
-            return torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), dtype=torch.long, device=self.device)
-
-        # Otherwise, use `input_ids`
-        is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
+        is_input_ids = isinstance(inputs, torch.LongTensor) and len(inputs.shape) == 2
+        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
             (eos_token_id is not None) and (pad_token_id != eos_token_id)
         )
-        if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
-            return input_ids.ne(pad_token_id).long()
+        # Check if input is input_ids and padded -> only then is attention_mask defined
+        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
+            return inputs.ne(pad_token_id).long()
         else:
-            return input_ids.new_ones(input_ids.shape, dtype=torch.long)
+            return torch.ones(inputs.shape[:2], dtype=torch.long, device=self.device)

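With `inputs_embeds` handling gone from this helper, a non-trivial mask is only inferred when the input really is a 2D `LongTensor` of token ids containing the pad token; feature and pixel inputs fall back to an all-ones mask over the first two dimensions. A standalone sketch of that decision, with made-up pad/eos ids:

import torch

pad_token_id, eos_token_id = 0, 2

def sketch_attention_mask(inputs):
    # mirrors the checks in `_prepare_attention_mask_for_generation`
    is_input_ids = isinstance(inputs, torch.LongTensor) and len(inputs.shape) == 2
    is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
    if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
        return inputs.ne(pad_token_id).long()
    return torch.ones(inputs.shape[:2], dtype=torch.long)

print(sketch_attention_mask(torch.tensor([[5, 6, 7, 0, 0]])))  # tensor([[1, 1, 1, 0, 0]])
print(sketch_attention_mask(torch.rand(1, 5, 80)))             # tensor([[1, 1, 1, 1, 1]])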
     def _prepare_encoder_decoder_kwargs_for_generation(
-        self, input_ids: torch.LongTensor, model_kwargs
+        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
     ) -> Dict[str, Any]:
         if "encoder_outputs" not in model_kwargs:
-            # retrieve encoder hidden states
+            # 1. get encoder
             encoder = self.get_encoder()
+            # 2. prepare encoder args and encoder kwargs from model kwargs
+            encoder_args = (inputs_tensor,)
+            irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
             encoder_kwargs = {
                 argument: value
                 for argument, value in model_kwargs.items()
-                if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
+                if not any(argument.startswith(p) for p in irrelevant_prefix)
             }
-            model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
+            # 3. make sure that encoder returns `ModelOutput`
+            encoder_kwargs["return_dict"] = True
+
+            # 4. if `model_input_name` is defined, pass `inputs_tensor` as a keyword
+            # argument under that name and clear the positional args
+            if model_input_name is not None:
+                # pass `inputs_tensor` by its model-specific name in case the
+                # model accepts multiple model input arguments
+                encoder_kwargs[model_input_name] = inputs_tensor
+                encoder_args = ()
+
+            model_kwargs["encoder_outputs"]: ModelOutput = encoder(*encoder_args, **encoder_kwargs)
+
         return model_kwargs

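Because the encoder is now invoked through the model-specific input name, the same code path drives text, speech, and vision encoders. As was already true before this change, a precomputed `encoder_outputs` in `model_kwargs` short-circuits the whole helper, so the encoder runs at most once. A hedged sketch of that reuse, again on the tiny test checkpoint:

# Sketch: precompute encoder outputs once and let `generate` skip the encoder.
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
input_ids = tokenizer("reuse the encoder", return_tensors="pt").input_ids

# run the encoder by hand, with `return_dict=True` as step 3 above enforces
encoder_outputs = model.get_encoder()(input_ids, return_dict=True)

# "encoder_outputs" is already in `model_kwargs`, so the helper returns immediately
output_ids = model.generate(input_ids, encoder_outputs=encoder_outputs, max_length=5)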
     def _prepare_decoder_input_ids_for_generation(
-        self, batch_size: int, decoder_start_token_id: int = None, bos_token_id: int = None
+        self,
+        batch_size: int,
+        decoder_start_token_id: int = None,
+        bos_token_id: int = None,
+        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.LongTensor:
-        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
-
-        decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
-        return decoder_input_ids
+        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+            return model_kwargs.pop("decoder_input_ids")
+        else:
+            decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
+            return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id

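The new `model_kwargs` parameter lets a user-supplied `decoder_input_ids` take precedence over the synthetic `decoder_start_token_id` prompt. A sketch of both branches from the caller's side (the custom prompt tokens are made up):

# Sketch: seeding the decoder with a custom prompt, tiny test checkpoint again.
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
input_ids = tokenizer("some text", return_tensors="pt").input_ids

# without `decoder_input_ids`, generation starts from `decoder_start_token_id`
default_start = model.generate(input_ids, max_length=5)

# with `decoder_input_ids` in the kwargs, the helper pops and uses them as-is
decoder_prompt = torch.tensor([[model.config.decoder_start_token_id, 4, 5]])
custom_start = model.generate(input_ids, decoder_input_ids=decoder_prompt, max_length=5)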
     def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
         if pad_token_id is None and eos_token_id is not None:
@@ -649,7 +725,7 @@ class GenerationMixin:
     @torch.no_grad()
     def generate(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
+        inputs: Optional[torch.Tensor] = None,
         max_length: Optional[int] = None,
         min_length: Optional[int] = None,
         do_sample: Optional[bool] = None,
@@ -688,18 +764,20 @@ class GenerationMixin:
         Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
         multinomial sampling, beam-search decoding, and beam-search multinomial sampling.

-        Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the
-        attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values
-        indicated are the default values of those config.
+        Apart from :obj:`inputs`, all the arguments below will default to the value of the attribute of the same name
+        inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the default
+        values of those config attributes.

         Most of these parameters are explained in more detail in `this blog post
         <https://huggingface.co/blog/how-to-generate>`__.

         Parameters:

-            input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
-                The sequence used as a prompt for the generation. If :obj:`None` the method initializes it with
-                :obj:`bos_token_id` and a batch size of 1.
+            inputs (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, :obj:`(batch_size, sequence_length, feature_dim)` or :obj:`(batch_size, num_channels, height, width)`, `optional`):
+                The sequence used as a prompt for the generation or as model inputs to the encoder. If :obj:`None` the
+                method initializes it with :obj:`bos_token_id` and a batch size of 1. For decoder-only models
+                :obj:`inputs` should be in the format of :obj:`input_ids`. For encoder-decoder models :obj:`inputs` can
+                represent any of :obj:`input_ids`, :obj:`input_values`, :obj:`input_features`, or :obj:`pixel_values`.
             max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`):
                 The maximum length of the sequence to be generated.
             max_new_tokens (:obj:`int`, `optional`, defaults to None):
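The three admissible shapes of `inputs` map one-to-one onto the modalities exercised by the new tests below. A hedged sketch with random tensors standing in for real features (the checkpoints are the tiny test models; real use would run a feature extractor instead of `torch.rand`):

import torch
from transformers import Speech2TextForConditionalGeneration, VisionEncoderDecoderModel

# (batch_size, sequence_length, feature_dim): speech features
speech_model = Speech2TextForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-speech_to_text"
)
transcription_ids = speech_model.generate(torch.rand(3, 20, 24), max_length=5)

# (batch_size, num_channels, height, width): images
vision_model = VisionEncoderDecoderModel.from_pretrained(
    "hf-internal-testing/tiny-random-vision-encoder-decoder"
)
caption_ids = vision_model.generate(torch.rand(2, 3, 30, 30), max_length=5)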
@@ -870,8 +948,11 @@ class GenerationMixin:
         >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
         >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
         """

+        # 1. Set generation parameters if not already defined
+        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         num_beams = num_beams if num_beams is not None else self.config.num_beams
+        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
         num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
         do_sample = do_sample if do_sample is not None else self.config.do_sample
         num_return_sequences = (
@@ -879,7 +960,6 @@ class GenerationMixin:
         )

         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -891,55 +971,52 @@ class GenerationMixin:
             return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
         )

-        model_kwargs["output_attentions"] = output_attentions
-        model_kwargs["output_hidden_states"] = output_hidden_states
-
-        if input_ids is None and "inputs_embeds" not in model_kwargs:
-            # init `input_ids` with bos_token_id
-            input_ids = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
-
-        if model_kwargs.get("attention_mask", None) is None:
-            # init `attention_mask` depending on `pad_token_id`
-            inputs_embeds = model_kwargs.get("inputs_embeds", None)
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, pad_token_id, eos_token_id, inputs_embeds
-            )

-        # special case if pad_token_id is not defined
         if pad_token_id is None and eos_token_id is not None:
+            # special case if pad_token_id is not defined
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id

-        # Storing encoder_input_ids for logits_processor that could use them
-        encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
+        # 2. Define model inputs
+        # inputs_tensor has to be defined
+        # model_input_name is defined if model-specific keyword input is passed
+        # otherwise model_input_name is None
+        # all model-specific keyword inputs are removed from `model_kwargs`
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
+        batch_size = inputs_tensor.shape[0]
+
+        # 3. Define other model kwargs
+        model_kwargs["output_attentions"] = output_attentions
+        model_kwargs["output_hidden_states"] = output_hidden_states
+        model_kwargs["use_cache"] = use_cache
+
+        if model_kwargs.get("attention_mask", None) is None:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                inputs_tensor, pad_token_id, eos_token_id
+            )

         if self.config.is_encoder_decoder:
-            # add encoder_outputs to model_kwargs
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
-
-            # set input_ids as decoder_input_ids
-            if "decoder_input_ids" in model_kwargs:
-                input_ids = model_kwargs.pop("decoder_input_ids")
-            else:
-                # if word embeddings are provided directly, infere the batch size from it
-                batch_size = input_ids.shape[0] if input_ids is not None else model_kwargs["inputs_embeds"].shape[0]
-                input_ids = self._prepare_decoder_input_ids_for_generation(
-                    batch_size, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id
+            # if model is encoder decoder encoder_outputs are created
+            # and added to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+                inputs_tensor, model_kwargs, model_input_name
             )

-            if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput):
-                raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.")
+        # 4. Prepare `input_ids` which will be used for auto-regressive generation
+        if self.config.is_encoder_decoder:
+            input_ids = self._prepare_decoder_input_ids_for_generation(
+                batch_size,
+                decoder_start_token_id=decoder_start_token_id,
+                bos_token_id=bos_token_id,
+                model_kwargs=model_kwargs,
+            )
         else:
-            if "inputs_embeds" in model_kwargs and input_ids is None:
-                raise ValueError("For decoder-only generation, one must pass `input_ids`.")
+            # if decoder-only then inputs_tensor has to be `input_ids`
+            input_ids = inputs_tensor

+        # 5. Prepare `max_length` depending on other stopping criteria
+        # if `max_new_tokens` is passed, but not `max_length` -> set `max_length` from `max_new_tokens`
         if max_length is None and max_new_tokens is not None:
-            max_length = (
-                max_new_tokens + input_ids.shape[-1]
-                if input_ids is not None
-                else max_length + model_kwargs["inputs_embeds"].shape[1]
-            )
+            max_length = max_new_tokens + input_ids.shape[-1]
         elif max_length is not None and max_new_tokens is not None:
             # Both are set, this is odd, raise a warning
             warnings.warn(
@@ -948,7 +1025,6 @@ class GenerationMixin:
                 f"will take priority over `max_new_tokens` {max_new_tokens}.",
                 UserWarning,
             )

         # default to config if still None
         max_length = max_length if max_length is not None else self.config.max_length

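Because step 4 guarantees that `input_ids` is defined, the `max_new_tokens` conversion now collapses to a single addition on top of the prompt length:

import torch

input_ids = torch.ones((1, 8), dtype=torch.long)  # a prompt of 8 tokens
max_new_tokens = 20

# as in step 5 above: total length = prompt length + newly generated tokens
max_length = max_new_tokens + input_ids.shape[-1]
assert max_length == 28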
@@ -959,12 +1035,13 @@ class GenerationMixin:
                 "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
             )

-        # determine generation mode
+        # 6. determine generation mode
         is_greedy_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is False
         is_sample_gen_mode = (num_beams == 1) and (num_beam_groups == 1) and do_sample is True
         is_beam_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is False
         is_beam_sample_gen_mode = (num_beams > 1) and (num_beam_groups == 1) and do_sample is True
         is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1)

         if num_beam_groups > num_beams:
             raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
         if is_group_beam_gen_mode and do_sample is True:
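The five flags are mutually exclusive by construction. A small sketch (not library API) that replays the step-6 conditions as a single lookup:

def generation_mode(num_beams=1, num_beam_groups=1, do_sample=False):
    # mirrors the is_*_gen_mode flags above
    if num_beams == 1 and num_beam_groups == 1:
        return "sample" if do_sample else "greedy_search"
    if num_beams > 1 and num_beam_groups == 1:
        return "beam_sample" if do_sample else "beam_search"
    return "group_beam_search"  # num_beams > 1 and num_beam_groups > 1

assert generation_mode() == "greedy_search"
assert generation_mode(do_sample=True) == "sample"
assert generation_mode(num_beams=4) == "beam_search"
assert generation_mode(num_beams=4, do_sample=True) == "beam_sample"
assert generation_mode(num_beams=4, num_beam_groups=2) == "group_beam_search"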
@@ -972,15 +1049,12 @@ class GenerationMixin:
                 "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
             )

-        # set model_kwargs
-        model_kwargs["use_cache"] = use_cache
-
-        # get distribution pre_processing samplers
+        # 7. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             repetition_penalty=repetition_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
             encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
-            encoder_input_ids=encoder_input_ids,
+            encoder_input_ids=inputs_tensor,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             max_length=max_length,
@@ -994,15 +1068,17 @@ class GenerationMixin:
             remove_invalid_values=remove_invalid_values,
         )

+        # 8. prepare stopping criteria
         stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)

+        # 9. go into different generation modes
         if is_greedy_gen_mode:
             if num_return_sequences > 1:
                 raise ValueError(
                     f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                 )

-            # greedy search
+            # 10. run greedy search
             return self.greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
@@ -1016,12 +1092,12 @@ class GenerationMixin:
             )

         elif is_sample_gen_mode:
-            # get probability distribution warper
+            # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
                 top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
             )

-            # expand input_ids with `num_return_sequences` additional sequences per batch
+            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids,
                 expand_size=num_return_sequences,
@@ -1029,7 +1105,7 @@ class GenerationMixin:
                 **model_kwargs,
             )

-            # sample
+            # 12. run sample
             return self.sample(
                 input_ids,
                 logits_processor=logits_processor,
@@ -1044,17 +1120,13 @@ class GenerationMixin:
             )

         elif is_beam_gen_mode:
-            batch_size = input_ids.shape[0]
-
-            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
-
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

             if stopping_criteria.max_length is None:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")

+            # 10. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=num_beams,
@@ -1063,10 +1135,11 @@ class GenerationMixin:
                 do_early_stopping=early_stopping,
                 num_beam_hyps_to_keep=num_return_sequences,
             )
-            # interleave with `num_beams`
+            # 11. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
             )
+            # 12. run beam search
             return self.beam_search(
                 input_ids,
                 beam_scorer,
@@ -1081,24 +1154,23 @@ class GenerationMixin:
             )

         elif is_beam_sample_gen_mode:
+            # 10. prepare logits warper
             logits_warper = self._get_logits_warper(
                 top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
             )

-            batch_size = input_ids.shape[0] * num_return_sequences
-
-            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
             if stopping_criteria.max_length is None:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
+            # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
+                batch_size=batch_size * num_return_sequences,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
             )

-            # interleave with `num_beams * num_return_sequences`
+            # 12. interleave input_ids with `num_beams * num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids,
                 expand_size=num_beams * num_return_sequences,
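Both beam modes rely on `_expand_inputs_for_generation` to interleave each batch item `expand_size` times: `num_beams` for plain beam search, `num_beams * num_return_sequences` for beam sample. The interleaving is equivalent to `repeat_interleave`; a sketch:

import torch

input_ids = torch.tensor([[1, 2], [3, 4]])  # batch_size = 2
expand_size = 3                             # e.g. num_beams

# equivalent to what `_expand_inputs_for_generation` does to `input_ids`
expanded = input_ids.repeat_interleave(expand_size, dim=0)
print(expanded.tolist())
# [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]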
@@ -1106,6 +1178,7 @@ class GenerationMixin:
                 **model_kwargs,
             )

+            # 13. run beam sample
             return self.beam_sample(
                 input_ids,
                 beam_scorer,
@@ -1121,11 +1194,6 @@ class GenerationMixin:
             )

         elif is_group_beam_gen_mode:
-            batch_size = input_ids.shape[0]
-
-            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
-
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

@@ -1135,7 +1203,8 @@ class GenerationMixin:
             if stopping_criteria.max_length is None:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")

-            diverse_beam_scorer = BeamSearchScorer(
+            # 10. prepare beam search scorer
+            beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=num_beams,
                 max_length=stopping_criteria.max_length,
@@ -1145,13 +1214,14 @@ class GenerationMixin:
                 num_beam_hyps_to_keep=num_return_sequences,
                 num_beam_groups=num_beam_groups,
             )
-            # interleave with `num_beams`
+            # 11. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
             )
+            # 12. run beam search
             return self.group_beam_search(
                 input_ids,
-                diverse_beam_scorer,
+                beam_scorer,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=pad_token_id,
tests/test_generation_utils.py

@@ -20,6 +20,8 @@ import unittest
 from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device

+from .test_modeling_common import floats_tensor
+

 if is_torch_available():
     import torch
@@ -29,6 +31,9 @@ if is_torch_available():
         BartTokenizer,
         GPT2LMHeadModel,
         GPT2Tokenizer,
+        Speech2TextForConditionalGeneration,
+        SpeechEncoderDecoderModel,
+        VisionEncoderDecoderModel,
         top_k_top_p_filtering,
     )
     from transformers.generation_beam_search import BeamSearchScorer
@@ -1724,3 +1729,74 @@ class GenerationIntegrationTests(unittest.TestCase):
         # cannot generate from `inputs_embeds` for decoder only
         with self.assertRaises(ValueError):
             model.generate(inputs_embeds=inputs_embeds)
+
+    def test_generate_input_ids_as_kwarg(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
+        output_sequences = model.generate(input_ids).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (1, 15))
+
+    def test_generate_input_ids_as_encoder_kwarg(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5).to(
+            torch_device
+        )
+        model.config.eos_token_id = None
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        output_sequences_kwargs = model.generate(input_ids=input_ids).cpu()
+        output_sequences = model.generate(input_ids).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (1, 5))
+
+    def test_generate_inputs_and_encoder_kwargs(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        with self.assertRaises(ValueError):
+            model.generate(input_ids, input_ids=input_ids)
+
+    def test_generate_too_many_encoder_kwargs(self):
+        article = """I need input_ids to generate"""
+        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        with self.assertRaises(ValueError):
+            model.generate(input_ids=input_ids, input_values=input_ids)
+
+    def test_generate_input_values_as_encoder_kwarg(self):
+        input_values = floats_tensor((2, 250))
+        model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(input_values=input_values, max_length=5).cpu()
+        output_sequences = model.generate(input_values, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (2, 5))
+
+    def test_generate_input_features_as_encoder_kwarg(self):
+        input_features = floats_tensor((3, 20, 24))
+        model = Speech2TextForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-speech_to_text")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(input_features=input_features, max_length=5).cpu()
+        output_sequences = model.generate(input_features, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (3, 5))
+
+    def test_generate_pixel_values_as_encoder_kwarg(self):
+        pixel_values = floats_tensor((2, 3, 30, 30))
+        model = VisionEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-vision-encoder-decoder")
+        model = model.to(torch_device)
+        output_sequences_kwargs = model.generate(pixel_values=pixel_values, max_length=5).cpu()
+        output_sequences = model.generate(pixel_values, max_length=5).cpu()
+
+        self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
+        self.assertEqual(output_sequences.shape, (2, 5))