diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 2af0232902b..1d5d3b661e4 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -216,6 +216,9 @@ class GenerationConfig(PushToHubMixin): more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*): + Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for + more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. @@ -315,6 +318,7 @@ class GenerationConfig(PushToHubMixin): self.output_attentions = kwargs.pop("output_attentions", False) self.output_hidden_states = kwargs.pop("output_hidden_states", False) self.output_scores = kwargs.pop("output_scores", False) + self.output_logits = kwargs.pop("output_logits", None) self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) # Special tokens that can be used at generation time diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0c6740b3238..6fd2c752a0a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -110,6 +110,10 @@ class GenerateDecoderOnlyOutput(ModelOutput): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. @@ -127,6 +131,7 @@ class GenerateDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None @@ -145,6 +150,10 @@ class GenerateEncoderDecoderOutput(ModelOutput): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. 
+ logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. @@ -171,6 +180,7 @@ class GenerateEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -195,6 +205,10 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput): of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`. beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`. @@ -216,6 +230,7 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None beam_indices: Optional[torch.LongTensor] = None attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -238,6 +253,10 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. + logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. 
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`. @@ -269,6 +288,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): sequences: torch.LongTensor = None sequences_scores: Optional[torch.FloatTensor] = None scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None beam_indices: Optional[torch.LongTensor] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None @@ -1514,6 +1534,7 @@ class GenerationMixin: pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, + output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, @@ -1528,6 +1549,7 @@ class GenerationMixin: pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, + output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, @@ -1547,6 +1569,7 @@ class GenerationMixin: pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, + output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, @@ -1575,6 +1598,7 @@ class GenerationMixin: pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, + output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, streamer=streamer, @@ -1608,6 +1632,7 @@ class GenerationMixin: pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, + output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, sequential=generation_config.low_memory, @@ -1647,6 +1672,7 @@ class GenerationMixin: pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, + output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, @@ -1680,6 +1706,7 @@ class GenerationMixin: pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, + output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, @@ -1753,6 +1780,7 @@ class GenerationMixin: pad_token_id=generation_config.pad_token_id, eos_token_id=generation_config.eos_token_id, output_scores=generation_config.output_scores, + output_logits=generation_config.output_logits, return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, @@ -1772,6 +1800,7 @@ class GenerationMixin: output_attentions: Optional[bool] = None, 
output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, @@ -1819,6 +1848,9 @@ class GenerationMixin: for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors + for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): @@ -1872,6 +1904,7 @@ class GenerationMixin: eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) @@ -1885,6 +1918,7 @@ class GenerationMixin: ) # init attention / hidden states / scores tuples + raw_logits = () if (return_dict_in_generate and output_logits) else None scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None @@ -1967,15 +2001,18 @@ class GenerationMixin: # contrastive_search main logic start: # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by # degeneration penalty - logit_for_next_step = logits_processor(input_ids, logit_for_next_step) - logit_for_next_step = logits_warper(input_ids, logit_for_next_step) - next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) + processed_logit_for_next_step = logits_processor(input_ids, logit_for_next_step) + processed_logit_for_next_step = logits_warper(input_ids, processed_logit_for_next_step) + next_probs = nn.functional.softmax(processed_logit_for_next_step, dim=-1) + top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) # Store scores, attentions and hidden_states when required if return_dict_in_generate: + if output_logits: + raw_logits += (logit_for_next_step,) if output_scores: - scores += (logit_for_next_step,) + scores += (processed_logit_for_next_step,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) @@ -2172,6 +2209,7 @@ class GenerationMixin: return GenerateEncoderDecoderOutput( sequences=input_ids, scores=scores, + logits=raw_logits, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -2183,6 +2221,7 @@ class GenerationMixin: return GenerateDecoderOnlyOutput( sequences=input_ids, scores=scores, + logits=raw_logits, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), @@ -2201,6 +2240,7 @@ class GenerationMixin: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] 
= None, + output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, @@ -2244,6 +2284,9 @@ class GenerationMixin: for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors + for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): @@ -2327,6 +2370,7 @@ class GenerationMixin: ) # init attention / hidden states / scores tuples + raw_logits = () if (return_dict_in_generate and output_logits) else None scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None @@ -2377,6 +2421,8 @@ class GenerationMixin: if return_dict_in_generate: if output_scores: scores += (next_tokens_scores,) + if output_logits: + raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) @@ -2433,6 +2479,7 @@ class GenerationMixin: return GenerateEncoderDecoderOutput( sequences=input_ids, scores=scores, + logits=raw_logits, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -2444,6 +2491,7 @@ class GenerationMixin: return GenerateDecoderOnlyOutput( sequences=input_ids, scores=scores, + logits=raw_logits, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), @@ -2463,6 +2511,7 @@ class GenerationMixin: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, @@ -2508,6 +2557,9 @@ class GenerationMixin: for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for + more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
synced_gpus (`bool`, *optional*, defaults to `False`): @@ -2595,6 +2647,7 @@ class GenerationMixin: eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) @@ -2609,6 +2662,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None @@ -2660,6 +2714,8 @@ class GenerationMixin: if return_dict_in_generate: if output_scores: scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) @@ -2717,6 +2773,7 @@ class GenerationMixin: return GenerateEncoderDecoderOutput( sequences=input_ids, scores=scores, + logits=raw_logits, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -2728,6 +2785,7 @@ class GenerationMixin: return GenerateDecoderOnlyOutput( sequences=input_ids, scores=scores, + logits=raw_logits, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), @@ -2773,6 +2831,7 @@ class GenerationMixin: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, sequential: Optional[bool] = None, @@ -2815,6 +2874,9 @@ class GenerationMixin: output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for + more details. 
return_dict_in_generate (`bool`, *optional*, defaults to `False`): @@ -2906,6 +2968,7 @@ class GenerationMixin: if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) @@ -2930,6 +2993,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) @@ -3027,13 +3091,14 @@ class GenerationMixin: if return_dict_in_generate: if output_scores: scores += (next_token_scores_processed,) + if output_logits: + raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) - if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) @@ -3113,6 +3178,7 @@ class GenerationMixin: sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + logits=raw_logits, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -3126,6 +3192,7 @@ class GenerationMixin: sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + logits=raw_logits, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, @@ -3147,6 +3214,7 @@ class GenerationMixin: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, **model_kwargs, @@ -3194,6 +3262,9 @@ class GenerationMixin: for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for + more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
synced_gpus (`bool`, *optional*, defaults to `False`): @@ -3284,6 +3355,7 @@ class GenerationMixin: if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) @@ -3303,6 +3375,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) @@ -3363,6 +3436,8 @@ class GenerationMixin: if return_dict_in_generate: if output_scores: scores += (next_token_scores_processed,) + if output_logits: + raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) @@ -3450,6 +3525,7 @@ class GenerationMixin: sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + logits=raw_logits, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -3463,6 +3539,7 @@ class GenerationMixin: sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + logits=raw_logits, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, @@ -3483,6 +3560,7 @@ class GenerationMixin: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, **model_kwargs, @@ -3526,6 +3604,9 @@ class GenerationMixin: for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for + more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
synced_gpus (`bool`, *optional*, defaults to `False`): @@ -3614,6 +3695,7 @@ class GenerationMixin: if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) @@ -3646,6 +3728,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None @@ -3698,6 +3781,8 @@ class GenerationMixin: if output_scores: processed_score = torch.zeros_like(outputs.logits[:, -1, :]) + if output_logits: + raw_logit_score = outputs.logits[:, -1, :] for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams @@ -3780,6 +3865,8 @@ class GenerationMixin: if return_dict_in_generate: if output_scores: scores += (processed_score,) + if output_logits: + raw_logits += (raw_logit_score,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) @@ -3835,6 +3922,7 @@ class GenerationMixin: sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + logits=raw_logits, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -3848,6 +3936,7 @@ class GenerationMixin: sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + logits=raw_logits, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, @@ -3868,6 +3957,7 @@ class GenerationMixin: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = None, **model_kwargs, @@ -3916,6 +4006,9 @@ class GenerationMixin: for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for + more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
synced_gpus (`bool`, *optional*, defaults to `False`): @@ -4006,6 +4099,7 @@ class GenerationMixin: if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) @@ -4030,6 +4124,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) @@ -4094,6 +4189,8 @@ class GenerationMixin: if return_dict_in_generate: if output_scores: scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) @@ -4178,6 +4275,7 @@ class GenerationMixin: sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + logits=raw_logits, beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, @@ -4191,6 +4289,7 @@ class GenerationMixin: sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + logits=raw_logits, beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, @@ -4213,6 +4312,7 @@ class GenerationMixin: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: bool = False, streamer: Optional["BaseStreamer"] = None, @@ -4267,6 +4367,9 @@ class GenerationMixin: for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + output_logits (`bool`, *optional*, defaults to `False`): + Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for + more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
synced_gpus (`bool`, *optional*, defaults to `False`): @@ -4350,6 +4453,7 @@ class GenerationMixin: eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_logits = output_logits if output_logits is not None else self.generation_config.output_logits output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions ) @@ -4364,6 +4468,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None @@ -4432,6 +4537,7 @@ class GenerationMixin: # 2.3. Process the new logits new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present + next_token_logits = new_logits.clone() if len(logits_processor) > 0: for i in range(candidate_length + 1): new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) @@ -4498,6 +4604,8 @@ class GenerationMixin: if return_dict_in_generate: if output_scores: scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1)) + if output_logits: + raw_logits += (next_token_logits,) if "past_key_values" not in model_kwargs: added_len = new_cur_len @@ -4573,6 +4681,7 @@ class GenerationMixin: return GenerateEncoderDecoderOutput( sequences=input_ids, scores=scores, + logits=raw_logits, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -4584,6 +4693,7 @@ class GenerationMixin: return GenerateDecoderOnlyOutput( sequences=input_ids, scores=scores, + logits=raw_logits, attentions=decoder_attentions, hidden_states=decoder_hidden_states, past_key_values=model_kwargs.get("past_key_values"), diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b4e1a218a92..cb224c3c6a9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -269,6 +269,7 @@ class GenerationTesterMixin: attention_mask, max_length, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, @@ -293,6 +294,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, + output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, **logits_process_kwargs, **model_kwargs, @@ -317,6 +319,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, + output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, **kwargs, **model_kwargs, @@ -335,6 +338,7 @@ class GenerationTesterMixin: logits_warper_kwargs, process_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, @@ -348,6 +352,7 @@ class GenerationTesterMixin: max_length=max_length, num_return_sequences=num_return_sequences, output_scores=output_scores, + 
output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -379,6 +384,7 @@ class GenerationTesterMixin: logits_processor=logits_processor, logits_warper=logits_warper, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -399,6 +405,7 @@ class GenerationTesterMixin: logits_processor, logits_process_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, @@ -409,6 +416,7 @@ class GenerationTesterMixin: do_sample=False, max_length=max_length, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -440,6 +448,7 @@ class GenerationTesterMixin: max_length=max_length, logits_processor=logits_processor, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -459,6 +468,7 @@ class GenerationTesterMixin: logits_warper, logits_warper_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, @@ -470,6 +480,7 @@ class GenerationTesterMixin: do_sample=True, max_length=max_length, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -506,6 +517,7 @@ class GenerationTesterMixin: logits_warper=logits_warper, logits_processor=logits_processor, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -526,6 +538,7 @@ class GenerationTesterMixin: logits_processor, logits_process_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, @@ -536,6 +549,7 @@ class GenerationTesterMixin: do_sample=False, max_length=max_length, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -567,6 +581,7 @@ class GenerationTesterMixin: max_length=max_length, logits_processor=logits_processor, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -587,6 +602,7 @@ class GenerationTesterMixin: logits_processor, logits_process_kwargs, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, @@ -597,6 +613,7 @@ class GenerationTesterMixin: do_sample=False, max_length=max_length, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -629,6 +646,7 @@ class GenerationTesterMixin: max_length=max_length, logits_processor=logits_processor, output_scores=output_scores, + output_logits=output_logits, output_attentions=output_attentions, 
output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, @@ -644,6 +662,7 @@ class GenerationTesterMixin: attention_mask, max_length, output_scores=False, + output_logits=False, output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, @@ -673,6 +692,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, + output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, **logits_process_kwargs, **model_kwargs, @@ -699,6 +719,7 @@ class GenerationTesterMixin: output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_scores=output_scores, + output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, **kwargs, **model_kwargs, @@ -729,6 +750,7 @@ class GenerationTesterMixin: attention_mask=attention_mask, max_length=max_length, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -769,6 +791,7 @@ class GenerationTesterMixin: attention_mask=attention_mask, max_length=max_length, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -853,6 +876,7 @@ class GenerationTesterMixin: logits_warper_kwargs=logits_warper_kwargs, process_kwargs=process_kwargs, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -964,6 +988,7 @@ class GenerationTesterMixin: logits_process_kwargs=logits_process_kwargs, logits_processor=logits_processor, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -1032,6 +1057,7 @@ class GenerationTesterMixin: logits_process_kwargs=logits_process_kwargs, logits_processor=logits_processor, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -1126,6 +1152,7 @@ class GenerationTesterMixin: logits_warper=logits_warper, logits_warper_kwargs=logits_warper_kwargs, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -1262,6 +1289,7 @@ class GenerationTesterMixin: logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -1421,6 +1449,7 @@ class GenerationTesterMixin: logits_processor=logits_processor, logits_process_kwargs=logits_process_kwargs, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -1493,6 +1522,7 @@ class GenerationTesterMixin: attention_mask=attention_mask, max_length=max_length, output_scores=True, + output_logits=True, output_hidden_states=True, output_attentions=True, return_dict_in_generate=True, @@ -1628,6 +1658,7 @@ class GenerationTesterMixin: "num_beams": 1, "do_sample": False, "output_scores": True, + "output_logits": True, "output_hidden_states": True, "output_attentions": True, "return_dict_in_generate": True, @@ -1690,6 +1721,7 @@ class GenerationTesterMixin: "num_beams": 1, "do_sample": False, "output_scores": True, + "output_logits": True, "output_hidden_states": True, "output_attentions": True, "return_dict_in_generate": True, @@ -1753,6 +1785,7 @@ class GenerationTesterMixin: "do_sample": True, 
"assistant_model": assistant_model, "output_scores": True, + "output_logits": True, "output_hidden_states": True, "output_attentions": True, "return_dict_in_generate": True, @@ -2105,6 +2138,7 @@ class GenerationTesterMixin: def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape num_sequences_in_output = batch_size * num_return_sequences + gen_len = ( output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length ) @@ -2112,6 +2146,9 @@ class GenerationTesterMixin: # scores self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config) + # unprocessed logits + self._check_logits(num_sequences_in_output, output.logits, config=config) + # Attentions if config.is_encoder_decoder: # encoder @@ -2191,6 +2228,14 @@ class GenerationTesterMixin: self.assertEqual(len(scores), length) self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) + def _check_logits(self, batch_size, scores, config): + self.assertIsInstance(scores, tuple) + self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores)) + # vocabulary difference equal to one (imagegptmodel?) or zero (all other models) + vocab_diff = config.vocab_size - scores[0].shape[-1] + self.assertTrue(vocab_diff in [0, 1]) + self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores)) + def _check_attentions_for_generate( self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 ): @@ -3536,3 +3581,60 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi model.generate(**inputs, **generation_kwargs) # update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5 self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5) + + def test_compare_unprocessed_logit_scores(self): + # Get unprocessed logit scores back from model generate function. 
+ # Assert that unprocessed logits from generate() are the same as those from a direct model forward pass + + # tell the model to generate text and return unprocessed/unwarped logit scores + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = "generate yes or no: " + input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device) + + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + with torch.no_grad(): + # Get logits for the next token from a forward pass + logits_fwd = model(input_ids).logits[:, -1, :][0] + + # Get logits for the next token from the generate() function + outputs = model.generate( + input_ids=input_ids, + return_dict_in_generate=True, + output_logits=True, + max_new_tokens=1, + do_sample=True, + ) + logits_gen = outputs.logits[0][0] + + # assert that unprocessed logits from generate() are the same as those from a direct model forward pass + self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist()) + + def test_return_unprocessed_logit_scores(self): + # tell the model to generate text and return unprocessed/unwarped logit scores + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + text = "generate yes or no: " + input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device) + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + + outputs = model.generate( + input_ids=input_ids, return_dict_in_generate=True, output_logits=True, max_new_tokens=3 + ) + + # perform a sanity check that the unprocessed logits make sense: + # preselect high-probability tokens and find the scores of the "y" and "n" tokens + probs_all = torch.nn.functional.softmax(outputs.logits[2][0], dim=-1) + indices = torch.argwhere(probs_all > 0.001) + indices = indices[:, -1] + tokens_max = tokenizer.batch_decode(indices, skip_special_tokens=True) + probs_max = probs_all[probs_all > 0.001] + + self.assertTrue(len(indices) >= 2) + next_token_dict = {str(t): p for t, p in zip(tokens_max, probs_max)} + self.assertTrue("n" in next_token_dict) + self.assertTrue("y" in next_token_dict) + y_prob = next_token_dict["y"] + n_prob = next_token_dict["n"] + + self.assertTrue(y_prob > 0.001 and n_prob > 0.001) + self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
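For reviewers, here is a minimal usage sketch of the new `output_logits` flag; it is not part of the diff itself. It reuses the tiny test checkpoint from the tests above, and the final check mirrors `test_compare_unprocessed_logit_scores`: because the returned logits are captured before any logits processors or warpers run, the first step's entry should match a plain forward pass over the prompt.

```python
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

# Same tiny checkpoint the new tests use; any causal LM checkpoint would work.
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tokenizer("generate yes or no: ", return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=3,
    do_sample=False,
    return_dict_in_generate=True,
    output_scores=True,  # processed scores (after logits processors / warpers)
    output_logits=True,  # unprocessed logits, the flag added by this diff
)

# `outputs.logits` is a tuple with one entry per generated token, each of shape
# (batch_size, vocab_size), mirroring the layout of `outputs.scores`. Generation
# may stop early at an EOS token, hence "up to" max_new_tokens entries.
assert 1 <= len(outputs.logits) <= 3

# The raw logits are stored before any processing, so the first entry should
# match a plain forward pass over the prompt.
with torch.no_grad():
    fwd_logits = model(inputs.input_ids).logits[:, -1, :]
torch.testing.assert_close(outputs.logits[0], fwd_logits)
```

One design note worth flagging: `GenerationConfig` initializes `output_logits` to `None` rather than `False`. Since the flag is only used in truthiness checks, `None` behaves like `False`, and each decoding method falls back to the config value only when the call-time argument is `None`.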