Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Generate: contrastive search cosmetic tweaks (#19871)

This commit is contained in:
parent 0a77249178
commit d39f794eda
@@ -100,8 +100,9 @@ class GreedySearchDecoderOnlyOutput(ModelOutput):
 @dataclass
 class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
     """
-    Args:
     Base class for outputs of decoder-only generation models using contrastive search.
+
+    Args:
         sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
             if all batches finished early due to the `eos_token_id`.
@@ -110,7 +111,7 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
             Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
             at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
             each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
-        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`
+        decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`:
             is passed or when `config.output_hidden_states=True`):
             Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
             `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
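To illustrate the nesting documented above, one tuple element per generated token, each holding one element per decoder layer, here is a hypothetical access pattern (the checkpoint, prompt, and variable names are placeholders, not taken from this commit):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: hello", return_tensors="pt")
out = model.generate(
    **inputs,
    top_k=4,
    penalty_alpha=0.6,  # top_k > 1 with penalty_alpha > 0 selects contrastive search
    output_hidden_states=True,
    return_dict_in_generate=True,
)
# decoder_hidden_states[step][layer] -> (batch_size, generated_length, hidden_size)
first_step_last_layer = out.decoder_hidden_states[0][-1]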
@@ -124,8 +125,9 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutput):
 @dataclass
 class ContrastiveSearchDecoderOnlyOutput(ModelOutput):
     """
-    Args:
     Base class for outputs of decoder-only generation models using contrastive search.
+
+    Args:
         sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
             if all batches finished early due to the `eos_token_id`.
@@ -433,6 +435,8 @@ GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoder
 SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
 BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
 BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
+ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
+GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]
 
 
 class GenerationMixin:
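For context (not part of the commit): the two new aliases give callers a single annotation covering everything `generate()` can return. A minimal sketch of how such a union can be consumed, assuming the aliases are importable from `transformers.generation_utils` as in this file:

from typing import Union

import torch
from transformers.generation_utils import GenerateOutput

def sequences_of(result: Union[GenerateOutput, torch.LongTensor]) -> torch.LongTensor:
    # generate() returns a bare tensor unless return_dict_in_generate=True,
    # in which case every branch's output dataclass exposes `.sequences`
    if isinstance(result, torch.Tensor):
        return result
    return result.sequences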
@@ -1010,7 +1014,7 @@ class GenerationMixin:
         begin_suppress_tokens: Optional[List[int]] = None,
         forced_decoder_ids: Optional[List[List[int]]] = None,
         **model_kwargs,
-    ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
+    ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
 
         Generates sequences of token ids for models with a language modeling head. The method supports the following
@@ -1766,7 +1770,7 @@ class GenerationMixin:
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: Optional[bool] = False,
         **model_kwargs,
-    ) -> Union[GreedySearchOutput, torch.LongTensor]:
+    ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
         Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
         be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
@@ -1781,6 +1785,10 @@ class GenerationMixin:
             logits_processor (`LogitsProcessorList`, *optional*):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step.
             stopping_criteria (`StoppingCriteriaList`, *optional*):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
@@ -1817,7 +1825,6 @@ class GenerationMixin:
         >>> from transformers import (
         ...     AutoTokenizer,
         ...     AutoModelForCausalLM,
-        ...     MinLengthLogitsProcessor,
         ...     StoppingCriteriaList,
         ...     MaxLengthCriteria,
         ... )
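In the source, the docstring example continues past this import block; as a hedged sketch of how these imports are exercised (the checkpoint and prompt are placeholders, not quoted from the commit):

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.pad_token_id = model.config.eos_token_id  # silence pad warnings
inputs = tokenizer("DeepMind Company is", return_tensors="pt")
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
outputs = model.contrastive_search(
    **inputs, top_k=4, penalty_alpha=0.6, stopping_criteria=stopping_criteria
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))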
@@ -1859,7 +1866,6 @@ class GenerationMixin:
 
         this_peer_finished = False  # used by synced_gpus only
 
-        step_counter = 0
         while True:
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
model_kwargs["use_cache"] = True
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
# if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values; (2) last_hidden_states; (3) logit_for_next_step
|
||||
if step_counter == 0:
|
||||
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save the `encoder_outputs`
|
||||
# if the first step in the loop, encode all the prefix and obtain three parameters: (1) past_key_values;
|
||||
# (2) last_hidden_states; (3) logit_for_next_step
|
||||
if model_kwargs.get("past") is None:
|
||||
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
|
||||
# the `encoder_outputs`
|
||||
output = self(**model_inputs, output_hidden_states=True, output_attentions=True)
|
||||
|
||||
# past_key_values is activated for fast decoding
|
||||
# past_key_values is required for fast decoding
|
||||
if "past_key_values" not in output:
|
||||
raise ValueError(
|
||||
"self.__class__ cannot return `past_key_values` and can therefore **not** be used for"
|
||||
" contrastive search."
|
||||
f"{self.__class__.__name__} cannot return `past_key_values` and can therefore **not** be used "
|
||||
"for contrastive search."
|
||||
)
|
||||
past_key_values = output.past_key_values
|
||||
|
||||
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with previous tokens)
|
||||
# last decoder hidden states will be used to compute the degeneration penalty (cosine similarity with
|
||||
# previous tokens)
|
||||
if self.config.is_encoder_decoder:
|
||||
last_hidden_states = output.decoder_hidden_states[-1]
|
||||
else:
|
||||
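A practical consequence of the check above: any model used with contrastive search must be able to return a cache. An illustrative pre-flight test (not code from the diff; the checkpoint is a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
with torch.no_grad():
    out = model(**tokenizer("hello", return_tensors="pt"), use_cache=True)
if getattr(out, "past_key_values", None) is None:
    raise ValueError(f"{model.__class__.__name__} cannot return `past_key_values`")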
@@ -1897,7 +1906,8 @@ class GenerationMixin:
                 logit_for_next_step = output.logits[:, -1, :]
 
             # contrastive_search main logic start:
-            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by degeneration penalty
+            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
+            # degeneration penalty
             bsz, seqlen, embed_dim = last_hidden_states.size()
 
             # logits processor
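To make step (1), candidate token recall, concrete: the code below this comment takes the top-k tokens from the softmaxed next-token logits. An isolated sketch of that step (illustrative; variable names follow the surrounding diff):

import torch

def recall_candidates(logit_for_next_step: torch.Tensor, top_k: int):
    # logit_for_next_step: (batch_size, vocab_size) logits at the next position
    probs = torch.nn.functional.softmax(logit_for_next_step, dim=-1)
    top_k_probs, top_k_ids = torch.topk(probs, dim=-1, k=top_k)
    return top_k_probs, top_k_ids  # both of shape (batch_size, top_k)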
@@ -1949,12 +1959,6 @@ class GenerationMixin:
             )
             # compute the candidate tokens by the language model and collects their hidden_states
             output = self(output_hidden_states=True, **next_model_inputs)
 
-            if "past_key_values" not in output:
-                raise ValueError(
-                    "self.__class__ cannot return `past_key_values` and can therefore **not** be used for contrastive"
-                    " search."
-                )
-            past_key_values = output.past_key_values
 
             logits = output.logits[:, -1, :]
@@ -1969,13 +1973,16 @@ class GenerationMixin:
                 last_hidden_states.unsqueeze(1).expand(-1, top_k, -1, -1).reshape(bsz * top_k, seqlen, embed_dim)
             )
 
-            # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the model confidence
+            # compute the degeneratin penalty and re-rank the candidates based on the degeneration penalty and the
+            # model confidence
             # the scores and index of the selected tokens are returned
             selected_scores, selected_idx = ranking_fast(
                 context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k
             )
 
-            # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores (model confidence minus degeneration penalty); (6) decoder hidden_states
+            # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
+            # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
+            # (model confidence minus degeneration penalty); (6) decoder hidden_states
             next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
             next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
             next_hidden = next_hidden[range(bsz), selected_idx, :]
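For readers new to step (2): contrastive search scores each candidate as (1 - penalty_alpha) * model confidence - penalty_alpha * degeneration penalty, where the penalty is the candidate's maximum cosine similarity to all previous hidden states. A self-contained sketch of a `ranking_fast`-style scorer (an illustration of the formulation, not code copied from the diff):

import torch

def rank_candidates(context_hidden, next_hidden, next_top_k_probs, alpha, top_k):
    # context_hidden: (batch*top_k, seq_len, dim) prior hidden states, one copy per candidate
    # next_hidden:    (batch*top_k, 1, dim) hidden state of each candidate token
    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    # cosine similarity of each candidate to every previous token
    cosine = torch.bmm(norm_context, norm_next.transpose(1, 2)).squeeze(-1)
    degeneration_penalty, _ = torch.max(cosine, dim=-1)  # (batch*top_k,)
    scores = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
    scores = torch.stack(torch.split(scores, top_k))  # (batch, top_k)
    selected_scores, selected_idx = scores.max(dim=-1)
    return selected_scores, selected_idx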
@@ -2003,7 +2010,8 @@ class GenerationMixin:
 
             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(bsz), selected_idx, :]
             # contrastive_search main logic end::
-            # after running the above codes, we update following parameters: next_tokens, past_key_values, logit_for_next_step, selected_score, decoder_hidden_states_one_step
+            # after running the above codes, we update following parameters: next_tokens, past_key_values,
+            # logit_for_next_step, selected_score, decoder_hidden_states_one_step
 
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -2047,10 +2055,6 @@ class GenerationMixin:
                 else:
                     this_peer_finished = True
 
-            # prepare model inputs
-            model_kwargs["past_key_values"] = past_key_values
-            step_counter += 1
-
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return ContrastiveSearchEncoderDecoderOutput(
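Finally, the user-facing path these tweaks polish: contrastive search is also reachable through `generate()`, which dispatches to `contrastive_search()` when `penalty_alpha > 0` is combined with `top_k > 1`. A hedged end-to-end sketch (placeholder checkpoint and prompt):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Company is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64, top_k=4, penalty_alpha=0.6)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))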