From d39f794eda5f3007ef1a6312097ddfdba852f7f3 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 25 Oct 2022 14:43:06 +0100
Subject: [PATCH] Generate: contrastive search cosmetic tweaks (#19871)

---
 src/transformers/generation_utils.py | 60 +++++++++++++++-------------
 1 file changed, 32 insertions(+), 28 deletions(-)

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index bb6d5f438c0..d20b8c75dd7 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -100,8 +100,9 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
 @dataclass
 class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
     """
-    Args:
     Base class for outputs of encoder-decoder generation models using contrastive search.
+
+    Args:
         sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
             if all batches finished early due to the `eos_token_id`.
@@ -110,7 +111,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
             Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
             at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
             each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
-        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`:
+        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`
             is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
@@ -124,8 +125,9 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
 @dataclass
 class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
     """
-    Args:
     Base class for outputs of decoder-only generation models using contrastive search.
+
+    Args:
         sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
             if all batches finished early due to the `eos_token_id`.
@@ -433,6 +435,8 @@ GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoder
 SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
 BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
 BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
+ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
+GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]


 class GenerationMixin:
@@ -1010,7 +1014,7 @@ class GenerationMixin:
         begin_suppress_tokens: Optional[List[int]] = None,
         forced_decoder_ids: Optional[List[List[int]]] = None,
         **model_kwargs,
-    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
+    ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""

         Generates sequences of token ids for models with a language modeling head. The method supports the following
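Reviewer note: the new `ContrastiveSearchOutput` / `GenerateOutput` aliases only tighten the type hints; behavior is unchanged. A minimal sketch of what a caller sees under the new aliases (the checkpoint name, prompt, and generation settings below are illustrative assumptions, not part of the patch):

```python
# Sketch only: exercises the union aliases added above. Assumes a transformers
# build that includes this patch; "gpt2" and the prompt are arbitrary examples.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation_utils import ContrastiveSearchOutput

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Company is", return_tensors="pt")

# penalty_alpha > 0 together with top_k > 1 routes `generate` to contrastive search
outputs = model.generate(
    **inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=32, return_dict_in_generate=True
)

# with `return_dict_in_generate=True`, the result is one of the two output classes
# covered by the new `ContrastiveSearchOutput` union rather than a bare LongTensor
assert isinstance(outputs, ContrastiveSearchOutput.__args__)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```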
@@ -1766,7 +1770,7 @@ class GenerationMixin:
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: Optional[bool] = False,
         **model_kwargs,
-    ) -> Union[GreedySearchOutput, torch.LongTensor]:
+    ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
         Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
         be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
@@ -1781,6 +1785,10 @@ class GenerationMixin:
             logits_processor (`LogitsProcessorList`, *optional*):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step.
             stopping_criteria (`StoppingCriteriaList`, *optional*):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
@@ -1817,7 +1825,6 @@ class GenerationMixin:
         >>> from transformers import (
         ...     AutoTokenizer,
         ...     AutoModelForCausalLM,
-        ...     MinLengthLogitsProcessor,
         ...     StoppingCriteriaList,
         ...     MaxLengthCriteria,
         ... )
@@ -1859,7 +1866,6 @@ class GenerationMixin:

         this_peer_finished = False  # used by synced_gpus only

-        step_counter = 0
         while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1875,20 +1881,23 @@ class GenerationMixin:
             model_kwargs["use_cache"] = True
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            # if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step
-            if step_counter == 0:
-                # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs`
+            # if this is the first step in the loop, encode the whole prefix and obtain three parameters:
+            # (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step
+            if model_kwargs.get("past") is None:
+                # encode the given prefix and prepare model inputs; encoder-decoder models process the prefix and
+                # save the `encoder_outputs`
                 output = self(**model_inputs, output_hidden_states=True, output_attentions=True)

-                # past_key_values is activated for fast decoding
+                # past_key_values is required for fast decoding
                 if "past_key_values" not in output:
                     raise ValueError(
-                        "self.__class__ cannot return `past_key_values` and can therefore **not** be used for"
-                        " contrastive search."
+                        f"{self.__class__.__name__} cannot return `past_key_values` and can therefore **not** be used "
+                        "for contrastive search."
                     )
                 past_key_values = output.past_key_values

-                # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens)
+                # last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
+                # previous tokens)
                 if self.config.is_encoder_decoder:
                     last_hidden_states = output.decoder_hidden_states[-1]
                 else:
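Reviewer note: replacing `step_counter` with `model_kwargs.get("past") is None` derives "first step" from state the loop already maintains, which is what lets the counter bookkeeping at the bottom of the loop be deleted (see the last hunk). A toy sketch of the pattern, with `fake_forward` as a hypothetical stand-in for the model call:

```python
# Toy illustration only (not library code): detect the first decoding step from
# the absence of a cache instead of keeping a separate counter in sync.
def fake_forward(prompt_len, past=None):
    # stand-in for the model: the "cache" just counts tokens processed so far
    return {"past_key_values": prompt_len if past is None else past + 1}

model_kwargs = {}
for _ in range(3):
    if model_kwargs.get("past") is None:
        output = fake_forward(prompt_len=5)  # first step: encode the whole prefix
    else:
        output = fake_forward(5, past=model_kwargs["past"])  # later steps: reuse the cache
    model_kwargs["past"] = output["past_key_values"]
    print(model_kwargs["past"])  # 5, 6, 7
```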
@@ -1897,7 +1906,8 @@ class GenerationMixin:
                 logit_for_next_step = output.logits[:, -1, :]

             # contrastive_search main logic start:
-            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty
+            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
+            # degeneration penalty
             bsz, seqlen, embed_dim = last_hidden_states.size()

             # logits processor
@@ -1949,12 +1959,6 @@ class GenerationMixin:
             )
             # compute the candidate tokens by the language model and collects their hidden_states
             output = self(output_hidden_states=True, **next_model_inputs)
-
-            if "past_key_values" not in output:
-                raise ValueError(
-                    "self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive"
-                    " search."
-                )
             past_key_values = output.past_key_values
             logits = output.logits[:, -1, :]

@@ -1969,13 +1973,16 @@ class GenerationMixin:
                 last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
             )

-            # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence
+            # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
+            # model confidence
             # the scores and index of the selected tokens are returned
             selected_scores, selected_idx = ranking_fast(
                 context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k
             )

-            # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states
+            # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
+            # the degeneration penalty; (4) logits for selecting the next top-k candidates; (5) selected token scores
+            # (model confidence minus degeneration penalty); (6) decoder hidden_states
             next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
             next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
             next_hidden = next_hidden[range(bsz), selected_idx, :]
@@ -2003,7 +2010,8 @@ class GenerationMixin:
             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :]

             # contrastive_search main logic end:
-            # after running the above codes, we update following parameters: next_tokens, past_key_values, logit_for_next_step, selected_score, decoder_hidden_states_one_step
+            # after running the above code, we update the following parameters: next_tokens, past_key_values,
+            # logit_for_next_step, selected_score, decoder_hidden_states_one_step

             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -2047,10 +2055,6 @@ class GenerationMixin:
             else:
                 this_peer_finished = True

-            # prepare model inputs
-            model_kwargs["past_key_values"] = past_key_values
-            step_counter += 1
-
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return ContrastiveSearchEncoderDecoderOutput(
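Reviewer note: the `ranking_fast` call above scores each of the `top_k` candidates as `(1 - penalty_alpha) * model_confidence - penalty_alpha * degeneration_penalty`, where the penalty is the maximum cosine similarity between the candidate's hidden state and the hidden states of all previous tokens. A self-contained sketch of that scoring rule for a single sequence (simplified shapes; not the patch's `ranking_fast` verbatim, which batches this over `bsz * top_k`):

```python
# Illustrative sketch of the contrastive-search re-ranking rule for one sequence.
import torch
import torch.nn.functional as F


def rank_candidates(context_hidden, candidate_hidden, candidate_probs, penalty_alpha):
    # context_hidden: (seqlen, dim) hidden states of the tokens generated so far
    # candidate_hidden: (top_k, dim) hidden state each candidate token would add
    # candidate_probs: (top_k,) model confidence for each candidate
    ctx = F.normalize(context_hidden, dim=-1)
    cand = F.normalize(candidate_hidden, dim=-1)
    # degeneration penalty: highest cosine similarity to any previous token
    penalty = (cand @ ctx.T).max(dim=-1).values
    scores = (1.0 - penalty_alpha) * candidate_probs - penalty_alpha * penalty
    selected_score, selected_idx = scores.max(dim=-1)
    return selected_score, selected_idx


torch.manual_seed(0)
score, idx = rank_candidates(
    torch.randn(6, 8), torch.randn(4, 8), torch.tensor([0.5, 0.3, 0.1, 0.1]), penalty_alpha=0.6
)
print(int(idx), float(score))  # index of the selected candidate and its combined score
```

A confident candidate that closely repeats an earlier hidden state is demoted, which is how contrastive search suppresses degenerate repetition while staying close to greedy decoding when `penalty_alpha` is small.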