Generate: remove near-duplicate sample/greedy copy (#30773)
parent ce87dca1d7
commit de2f722172
@@ -1683,17 +1683,6 @@ class GenerationMixin:
                 streamer=streamer,
                 **model_kwargs,
             )
-        if generation_mode == GenerationMode.GREEDY_SEARCH:
-            # 11. run greedy search
-            result = self._greedy_search(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
-                **model_kwargs,
-            )

         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
             if not model_kwargs["use_cache"]:
@@ -1709,9 +1698,11 @@ class GenerationMixin:
                 **model_kwargs,
             )

-        elif generation_mode == GenerationMode.SAMPLE:
+        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )

             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1721,11 +1712,11 @@ class GenerationMixin:
                 **model_kwargs,
             )

-            # 13. run sample
+            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
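Side note, not part of the patch: the two branches can merge because the only sampling-specific pieces are the logits warper and the multinomial draw, and both are now gated on `generation_config.do_sample`. A minimal sketch of the conditional preparation using public `transformers` classes; the concrete warpers chosen here are illustrative, not what `_get_logits_warper` actually builds:

```python
from transformers import GenerationConfig, LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper


def prepare_warper(generation_config: GenerationConfig):
    # Mirror of the conditional above: greedy decoding gets no warper at all.
    if not generation_config.do_sample:
        return None
    # Illustrative warper stack; the real helper derives it from the full generation config.
    return LogitsProcessorList(
        [TemperatureLogitsWarper(generation_config.temperature), TopKLogitsWarper(generation_config.top_k)]
    )


assert prepare_warper(GenerationConfig(do_sample=False)) is None
assert len(prepare_warper(GenerationConfig(do_sample=True, temperature=0.7, top_k=50))) == 2
```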
@@ -1733,38 +1724,11 @@ class GenerationMixin:
                 **model_kwargs,
             )

-        elif generation_mode == GenerationMode.BEAM_SEARCH:
-            # 11. prepare beam search scorer
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=generation_config.num_beams,
-                device=inputs_tensor.device,
-                length_penalty=generation_config.length_penalty,
-                do_early_stopping=generation_config.early_stopping,
-                num_beam_hyps_to_keep=generation_config.num_return_sequences,
-                max_length=generation_config.max_length,
-            )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
-            result = self._beam_search(
-                input_ids,
-                beam_scorer,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                **model_kwargs,
-            )
-
-        elif generation_mode == GenerationMode.BEAM_SAMPLE:
+        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )

             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
@@ -1786,11 +1750,11 @@ class GenerationMixin:
             )

             # 14. run beam sample
-            result = self._beam_sample(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
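Side note, not part of the patch: from the public API nothing changes. `generate()` still dispatches on the generation mode, but greedy/sample and beam-search/beam-sample now share one implementation each. A quick usage sketch; the checkpoint name is only an example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The commit merges greedy and sampling", return_tensors="pt")

# GREEDY_SEARCH and SAMPLE both run through `_sample`; `do_sample` picks argmax vs. multinomial.
greedy_ids = model.generate(**inputs, do_sample=False, max_new_tokens=20)
sampled_ids = model.generate(**inputs, do_sample=True, top_p=0.9, temperature=0.8, max_new_tokens=20)
print(tokenizer.decode(greedy_ids[0], skip_special_tokens=True))
```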
@@ -2284,162 +2248,32 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
-        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            streamer (`BaseStreamer`, *optional*):
-                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
-                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            model_kwargs:
-                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
-                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._sample()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
-
-        # init attention / hidden states / scores tuples
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        # keep track of which sequences are already finished
-        batch_size = input_ids.shape[0]
-        this_peer_finished = False
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_tokens_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
-            # finished sentences should have their next token be a padding token
-            if has_eos_stopping_criteria:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            this_peer_finished = unfinished_sequences.max() == 0
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
+
+        logger.warning_once(
+            "Calling `._greedy_search()` directly is deprecated and will be removed in v4.42. Use `._sample()` "
+            "instead, passing the same arguments."
+        )
+        return self._sample(
+            input_ids=input_ids,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            streamer=streamer,
+            **model_kwargs,
+        )

     def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        logits_warper: LogitsProcessorList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
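Side note, not part of the patch: `_greedy_search` is kept only as a thin shim so external code that called it directly keeps working until v4.42. The pattern is "warn once, then delegate"; a self-contained toy version with hypothetical names:

```python
import logging

logger = logging.getLogger(__name__)


class ToyDecoder:
    def _sample(self, prompt: str, do_sample: bool = False) -> str:
        # Stand-in for the real decoding loop; greedy is just do_sample=False.
        return f"{prompt} -> {'sampled' if do_sample else 'greedy'}"

    def _greedy_search(self, prompt: str) -> str:
        logger.warning("`_greedy_search` is deprecated, call `_sample` instead.")
        return self._sample(prompt, do_sample=False)


assert ToyDecoder()._greedy_search("hi") == ToyDecoder()._sample("hi")
```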
@@ -2455,10 +2289,6 @@ class GenerationMixin:
             stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
             generation_config ([`~generation.GenerationConfig`]):
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
@@ -2466,6 +2296,11 @@ class GenerationMixin:
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2485,6 +2320,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )

         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
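Side note, not part of the patch: because `logits_warper` is now optional, `_sample` validates the pairing up front — sampling requires a prepared `LogitsProcessorList`, greedy requires nothing. A small sketch of the guard in isolation:

```python
from transformers import LogitsProcessorList, TopPLogitsWarper


def check_warper(do_sample: bool, logits_warper) -> None:
    # Same condition as the added guard: sampling without a warper list is a usage error.
    if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
        raise ValueError(
            f"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is {logits_warper})."
        )


check_warper(False, None)  # greedy path: fine without a warper
check_warper(True, LogitsProcessorList([TopPLogitsWarper(0.9)]))  # sampling path: fine with one
try:
    check_warper(True, None)
except ValueError as err:
    print(f"rejected as expected: {err}")
```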
@@ -2525,7 +2366,8 @@ class GenerationMixin:

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores = logits_warper(input_ids, next_token_scores)

             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -2547,9 +2389,12 @@ class GenerationMixin:
                         else (outputs.hidden_states,)
                     )

-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)

             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
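Side note, not part of the patch: the token-selection step above is the whole behavioural difference between the two former methods. A runnable toy version in plain PyTorch; shapes are illustrative:

```python
import torch
import torch.nn.functional as F


def select_next_tokens(next_token_scores: torch.Tensor, do_sample: bool) -> torch.Tensor:
    # Merged selection step: multinomial draw when sampling, plain argmax when greedy.
    if do_sample:
        probs = F.softmax(next_token_scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    return torch.argmax(next_token_scores, dim=-1)


scores = torch.randn(2, 5)  # (batch_size, vocab_size)
assert torch.equal(select_next_tokens(scores, do_sample=False), scores.argmax(dim=-1))
print(select_next_tokens(scores, do_sample=True))  # stochastic: one token id per sequence
```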
@@ -2622,6 +2467,7 @@ class GenerationMixin:
             past_key_values.reorder_cache(beam_idx)
         return past_key_values

+    # TODO (joao, v4.42): remove default for `logits_warper`
     def _beam_search(
         self,
         input_ids: torch.LongTensor,
@@ -2630,6 +2476,7 @@ class GenerationMixin:
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -2652,6 +2499,11 @@ class GenerationMixin:
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2672,6 +2524,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         sequential = generation_config.low_memory
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )

         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
@@ -2768,6 +2626,8 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)

             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
@@ -2795,11 +2655,20 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

-            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+            # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
+            # non eos token per beam.
             n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
-            next_token_scores, next_tokens = torch.topk(
-                next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
-            )
+            n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
+                next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+                next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, _indices)
+            else:
+                next_token_scores, next_tokens = torch.topk(
+                    next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
+                )

             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
             next_tokens = next_tokens % vocab_size
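Side note, not part of the patch: for beams the merge works the same way, except that `n_tokens_to_keep` candidates are drawn (or top-k'd) per batch entry, and sampled candidates are re-sorted so downstream beam bookkeeping still sees best-first order. A toy reproduction of just this step:

```python
import torch
import torch.nn.functional as F


def pick_beam_candidates(scores: torch.Tensor, num_beams: int, do_sample: bool, n_eos_tokens: int = 1):
    # `scores` is the flattened (batch_size, num_beams * vocab_size) matrix used above.
    n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
    if do_sample:
        probs = F.softmax(scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
        next_scores = torch.gather(scores, -1, next_tokens)
        next_scores, order = torch.sort(next_scores, descending=True, dim=1)  # keep best-first order
        next_tokens = torch.gather(next_tokens, -1, order)
    else:
        next_scores, next_tokens = torch.topk(scores, n_tokens_to_keep, dim=1, largest=True, sorted=True)
    return next_scores, next_tokens


scores = torch.log_softmax(torch.randn(2, 3 * 11), dim=-1)  # batch=2, num_beams=3, vocab=11
greedy_scores, _ = pick_beam_candidates(scores, num_beams=3, do_sample=False)
sampled_scores, _ = pick_beam_candidates(scores, num_beams=3, do_sample=True)
assert greedy_scores.shape == sampled_scores.shape == (2, 6)
```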
@@ -2897,219 +2766,24 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
-        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            beam_scorer (`BeamScorer`):
-                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
-                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
-                an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._beam_search()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        eos_token_id = generation_config.eos_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-
-        batch_size = len(beam_scorer._beam_hyps)
-        num_beams = beam_scorer.num_beams
-
-        batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        beam_indices = (
-            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
+
+        logger.warning_once(
+            "Calling `._beam_sample()` directly is deprecated and will be removed in v4.42. Use `._beam_search()` "
+            "instead, passing the same arguments."
         )
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        beam_scores = beam_scores.view((batch_size * num_beams,))
-
-        this_peer_finished = False
-
-        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            next_token_scores = nn.functional.log_softmax(
-                next_token_logits, dim=-1
-            )  # (batch_size * num_beams, vocab_size)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
-                next_token_scores_processed
-            )
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores_processed,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # reshape for beam search
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
-
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-
-            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
-            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-
-            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, _indices)
-
-            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
-            next_tokens = next_tokens % vocab_size
-
-            # stateless
-            beam_outputs = beam_scorer.process(
-                input_ids,
-                next_token_scores,
-                next_tokens,
-                next_indices,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                beam_indices=beam_indices,
-                decoder_prompt_len=decoder_prompt_len,
-            )
-            beam_scores = beam_outputs["next_beam_scores"]
-            beam_next_tokens = beam_outputs["next_beam_tokens"]
-            beam_idx = beam_outputs["next_beam_indices"]
-
-            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-            if model_kwargs.get("past_key_values", None) is not None:
-                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
-                    model_kwargs["past_key_values"], beam_idx
-                )
-
-            if return_dict_in_generate and output_scores:
-                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
-
-            # increase cur_len
-            cur_len = cur_len + 1
-
-            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                this_peer_finished = True
-
-        sequence_outputs = beam_scorer.finalize(
-            input_ids,
-            beam_scores,
-            next_tokens,
-            next_indices,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            max_length=stopping_criteria.max_length,
-            beam_indices=beam_indices,
-            decoder_prompt_len=decoder_prompt_len,
+        return self._beam_search(
+            input_ids=input_ids,
+            beam_scorer=beam_scorer,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            logits_warper=logits_warper,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            **model_kwargs,
         )
-
-        if return_dict_in_generate:
-            if not output_scores:
-                sequence_outputs["sequence_scores"] = None
-
-            if self.config.is_encoder_decoder:
-                return GenerateBeamEncoderDecoderOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateBeamDecoderOnlyOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return sequence_outputs["sequences"]

     def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
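Side note, not part of the patch: beam-sample users are unaffected at the `generate()` level. The BEAM_SAMPLE mode is still selected by `num_beams > 1` together with `do_sample=True`; it just lands in the unified `_beam_search` with a prepared warper. Illustrative call (the checkpoint name is an example only):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")

# num_beams > 1 with do_sample=True is beam-sample; it now runs through the merged `_beam_search`.
output_ids = model.generate(**inputs, num_beams=4, do_sample=True, temperature=0.8, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```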
@@ -1739,7 +1739,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,

@@ -2832,7 +2832,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,

@@ -1676,7 +1676,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,

@@ -2691,7 +2691,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
             )

             # 11. run greedy search
-            outputs = self._greedy_search(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
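Side note, not part of the patch: the Musicgen model classes call the decoding loop directly from their custom `generate()` implementations, so their call sites move from the removed `_greedy_search` to `_sample`; greedy behaviour is preserved because no warper is passed and `do_sample` stays `False`. Public usage is unchanged; a sketch along the lines of the MusicGen docs, with an illustrative checkpoint:

```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")  # illustrative checkpoint
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
inputs = processor(text=["lo-fi hip hop beat"], padding=True, return_tensors="pt")

# With do_sample=False this exercises the updated call site: Musicgen's generate() -> self._sample(...)
audio_values = model.generate(**inputs, do_sample=False, max_new_tokens=64)
print(audio_values.shape)
```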
@@ -1550,7 +1550,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                     f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                     " greedy search."
                 )
-            return self._greedy_search(
+            return self._sample(
                 input_ids,
                 logits_processor=pre_processor,
                 stopping_criteria=prepared_stopping_criteria,