Generate: contrastive search cosmetic tweaks (#19871)

Joao Gante 2022-10-25 14:43:06 +01:00 committed by GitHub
parent 0a77249178
commit d39f794eda

@@ -100,8 +100,9 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
@dataclass
class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
"""
Args:
Base class for outputs of encoder-decoder generation models using contrastive search.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
@@ -110,7 +111,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`
decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`:
is passed or when `config.output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
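For orientation, a minimal sketch of how the fields documented above can be inspected on a returned output object; the checkpoint, prompt, and generation settings are only illustrative examples:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: How are you?", return_tensors="pt")

# penalty_alpha > 0 together with top_k > 1 selects contrastive search
output = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=32,
    return_dict_in_generate=True,
    output_scores=True,
    output_hidden_states=True,
)
print(output.sequences.shape)                      # (batch_size, sequence_length)
print(len(output.scores), output.scores[0].shape)  # one (batch_size, vocab_size) tensor per generated token
print(len(output.decoder_hidden_states))           # one tuple of per-layer hidden states per generated token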
@@ -124,8 +125,9 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
@dataclass
class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
"""
Args:
Base class for outputs of decoder-only generation models using contrastive search.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
@@ -433,6 +435,8 @@ GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoder
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]
class GenerationMixin:
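As a quick illustration of what the widened `GenerateOutput` annotation means for callers: `generate()` returns a plain tensor of token ids by default, and one of the dataclasses from the union when `return_dict_in_generate=True`. A minimal sketch (the checkpoint and prompt are only examples):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

# default: a plain LongTensor of generated token ids
plain = model.generate(**inputs, max_new_tokens=5)
assert isinstance(plain, torch.Tensor)

# with return_dict_in_generate=True: one of the ModelOutput dataclasses in the union
rich = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True)
print(type(rich).__name__)   # e.g. GreedySearchDecoderOnlyOutput
print(rich.sequences.shape)  # the generated ids live in .sequences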
@@ -1010,7 +1014,7 @@ class GenerationMixin:
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head. The method supports the following
@@ -1766,7 +1770,7 @@ class GenerationMixin:
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
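In most cases this method is reached through `generate()` rather than called directly; a minimal sketch of triggering it that way (checkpoint, prompt, and hyperparameter values are only examples):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Company is", return_tensors="pt")

# passing penalty_alpha > 0 together with top_k > 1 routes generate() to contrastive search
output_ids = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))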
@@ -1781,6 +1785,10 @@ class GenerationMixin:
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
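A short sketch of how these arguments are typically constructed before being handed to the method; the concrete processor and criterion below are only examples:

from transformers import (
    AutoTokenizer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    StoppingCriteriaList,
    MaxLengthCriteria,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# force at least 10 generated tokens before eos is allowed, and stop at 64 tokens total
logits_processor = LogitsProcessorList([MinLengthLogitsProcessor(10, eos_token_id=tokenizer.eos_token_id)])
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])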
@@ -1817,7 +1825,6 @@ class GenerationMixin:
>>> from transformers import (
... AutoTokenizer,
... AutoModelForCausalLM,
... MinLengthLogitsProcessor,
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
@@ -1859,7 +1866,6 @@ class GenerationMixin:
this_peer_finished = False # used by synced_gpus only
step_counter = 0
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
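The synchronization the comment above refers to follows the same pattern as the other decoding loops in this file: every rank keeps stepping until an all-reduce shows that all peers are done. Roughly, as a sketch of the loop body using its own `this_peer_finished` and `input_ids` variables:

import torch
import torch.distributed as dist

# each rank contributes 1.0 while it still has work; a sum of 0.0 means every peer is done
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
if this_peer_finished_flag.item() == 0.0:
    break  # all peers finished, so the decoding loop can stop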
@@ -1875,20 +1881,23 @@ class GenerationMixin:
model_kwargs["use_cache"] = True
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# if this is the first step in the loop, encode the whole prefix and obtain three parameters: (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step
if step_counter == 0:
# encode the given prefix and prepare model inputs; encoder-decoder models process the prefix and save the `encoder_outputs`
# if this is the first step in the loop, encode the whole prefix and obtain three parameters: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step
if model_kwargs.get("past") is None:
# encode the given prefix and prepare model inputs; encoder-decoder models process the prefix and save
# the `encoder_outputs`
output = self(**model_inputs, output_hidden_states=True, output_attentions=True)
# past_key_values is activated for fast decoding
# past_key_values is required for fast decoding
if "past_key_values" not in output:
raise ValueError(
"self.__class__ cannot return `past_key_values` and can therefore **not** be used for"
" contrastive search."
f"{self.__class__.__name__} cannot return `past_key_values` and can therefore **not** be used "
"for contrastive search."
)
past_key_values = output.past_key_values
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens)
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
# previous tokens)
if self.config.is_encoder_decoder:
last_hidden_states = output.decoder_hidden_states[-1]
else:
@@ -1897,7 +1906,8 @@ class GenerationMixin:
logit_for_next_step = output.logits[:, -1, :]
# contrastive_search main logic start:
# contrastive search decoding consists of two steps: (1) candidate token recall; (2) candidate re-ranking by degeneration penalty
# contrastive search decoding consists of two steps: (1) candidate token recall; (2) candidate re-ranking by
# degeneration penalty
bsz, seqlen, embed_dim = last_hidden_states.size()
# logits processor
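Step (1), candidate recall, amounts to applying the logits processors/warpers to the next-step logits and keeping the `top_k` most probable tokens as candidates. A rough sketch of the code that follows this comment, using the loop's local variables:

import torch
import torch.nn.functional as F

logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
next_probs = F.softmax(logit_for_next_step, dim=-1)
top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)  # each of shape (bsz, top_k)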
@@ -1949,12 +1959,6 @@ class GenerationMixin:
)
# compute the candidate tokens with the language model and collect their hidden_states
output = self(output_hidden_states=True, **next_model_inputs)
if "past_key_values" not in output:
raise ValueError(
"self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive"
" search."
)
past_key_values = output.past_key_values
logits = output.logits[:, -1, :]
@@ -1969,13 +1973,16 @@ class GenerationMixin:
last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
)
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the model confidence
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence
# the scores and index of the selected tokens are returned
selected_scores, selected_idx = ranking_fast(
context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k
)
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
# (model confidence minus degeneration penalty); (6) decoder hidden_states
next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
next_hidden = next_hidden[range(bsz), selected_idx, :]
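For reference, the re-ranking performed here is the contrastive score from the contrastive search formulation: model confidence minus the scaled maximum cosine similarity to the previous hidden states. A standalone sketch of that computation (the helper name `rank_candidates` is illustrative, not the name used in this file):

import torch

def rank_candidates(context_hidden, next_hidden, next_top_k_probs, alpha, top_k):
    # context_hidden: (bsz*top_k, seqlen, dim); next_hidden: (bsz*top_k, 1, dim)
    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    # cosine similarity of each candidate against every previous token
    cosine = torch.bmm(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # (bsz*top_k, seqlen)
    degeneration_penalty, _ = torch.max(cosine, dim=-1)                      # (bsz*top_k,)
    # contrastive score: (1 - alpha) * model confidence - alpha * degeneration penalty
    score = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
    score = torch.stack(torch.split(score, top_k))                           # (bsz, top_k)
    selected_scores, selected_idx = score.max(dim=-1)
    return selected_scores, selected_idx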
@@ -2003,7 +2010,8 @@ class GenerationMixin:
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :]
# contrastive_search main logic end:
# after running the code above, we update the following parameters: next_tokens, past_key_values, logit_for_next_step, selected_score, decoder_hidden_states_one_step
# after running the code above, we update the following parameters: next_tokens, past_key_values,
# logit_for_next_step, selected_score, decoder_hidden_states_one_step
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
@@ -2047,10 +2055,6 @@ class GenerationMixin:
else:
this_peer_finished = True
# prepare model inputs
model_kwargs["past_key_values"] = past_key_values
step_counter += 1
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return ContrastiveSearchEncoderDecoderOutput(