From de2f722172089473a0d1ff0c037cd6b29460493f Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 13 May 2024 15:48:20 +0100
Subject: [PATCH] Generate: remove near-duplicate sample/greedy copy (#30773)

---
 src/transformers/generation/utils.py         | 500 +++---------
 .../models/musicgen/modeling_musicgen.py     |   4 +-
 .../modeling_musicgen_melody.py              |   4 +-
 src/transformers/models/rag/modeling_rag.py  |   2 +-
 4 files changed, 92 insertions(+), 418 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 9135bb20484..1c90fdd3075 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1683,17 +1683,6 @@ class GenerationMixin:
                 streamer=streamer,
                 **model_kwargs,
             )
-        if generation_mode == GenerationMode.GREEDY_SEARCH:
-            # 11. run greedy search
-            result = self._greedy_search(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                streamer=streamer,
-                **model_kwargs,
-            )
 
         elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
             if not model_kwargs["use_cache"]:
@@ -1709,9 +1698,11 @@ class GenerationMixin:
                 **model_kwargs,
             )
 
-        elif generation_mode == GenerationMode.SAMPLE:
+        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )
 
             # 12. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
                 expand_size=generation_config.num_return_sequences,
                 is_encoder_decoder=self.config.is_encoder_decoder,
                 **model_kwargs,
             )
 
-            # 13. run sample
+            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
             )
 
-        elif generation_mode == GenerationMode.BEAM_SEARCH:
-            # 11. prepare beam search scorer
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=generation_config.num_beams,
-                device=inputs_tensor.device,
-                length_penalty=generation_config.length_penalty,
-                do_early_stopping=generation_config.early_stopping,
-                num_beam_hyps_to_keep=generation_config.num_return_sequences,
-                max_length=generation_config.max_length,
-            )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
-            result = self._beam_search(
-                input_ids,
-                beam_scorer,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                synced_gpus=synced_gpus,
-                **model_kwargs,
-            )
-
-        elif generation_mode == GenerationMode.BEAM_SAMPLE:
+        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
             # 11. prepare logits warper
-            logits_warper = self._get_logits_warper(generation_config)
+            prepared_logits_warper = (
+                self._get_logits_warper(generation_config) if generation_config.do_sample else None
+            )
 
             # 12. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=generation_config.num_beams,
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 max_length=generation_config.max_length,
             )
@@ -1786,11 +1750,11 @@ class GenerationMixin:
             )
 
             # 14. run beam sample
-            result = self._beam_sample(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
-                logits_warper=logits_warper,
+                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -2284,162 +2248,32 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
-        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            streamer (`BaseStreamer`, *optional*):
-                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
-                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            model_kwargs:
-                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
-                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._sample()` instead, passing the same arguments.
""" - # init values - pad_token_id = generation_config.pad_token_id - output_attentions = generation_config.output_attentions - output_hidden_states = generation_config.output_hidden_states - output_scores = generation_config.output_scores - output_logits = generation_config.output_logits - return_dict_in_generate = generation_config.return_dict_in_generate - has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) - # init attention / hidden states / scores tuples - raw_logits = () if (return_dict_in_generate and output_logits) else None - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # keep track of which sequences are already finished - batch_size = input_ids.shape[0] - this_peer_finished = False - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_tokens_scores = logits_processor(input_ids, next_token_logits) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_tokens_scores,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # argmax - next_tokens = torch.argmax(next_tokens_scores, dim=-1) - - # finished sentences should have their next token be a padding token - if has_eos_stopping_criteria: - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - - unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) - 
-            this_peer_finished = unfinished_sequences.max() == 0
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
+        logger.warning_once(
+            "Calling `._greedy_search()` directly is deprecated and will be removed in v4.42. Use `._sample()` "
+            "instead, passing the same arguments."
+        )
+        return self._sample(
+            input_ids=input_ids,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            streamer=streamer,
+            **model_kwargs,
+        )
 
     def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        logits_warper: LogitsProcessorList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -2455,10 +2289,6 @@ class GenerationMixin:
             stopping_criteria (`StoppingCriteriaList`):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
             generation_config ([`~generation.GenerationConfig`]):
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
@@ -2466,6 +2296,11 @@ class GenerationMixin:
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
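For reference, the token-selection step that the following hunks unify can be sketched in isolation as below. This is illustrative only: the free-function form and the names are assumptions, not the actual `GenerationMixin` code, and in the real method the warpers also receive `input_ids`. It shows why a separate greedy path is redundant once `do_sample` gates the warper and the multinomial draw:

    import torch
    import torch.nn.functional as F

    def select_next_tokens(next_token_logits, do_sample, logits_warper=None):
        # `next_token_logits` has shape (batch_size, vocab_size).
        scores = next_token_logits
        if do_sample:
            # Warpers (temperature, top-k, top-p, ...) only apply when sampling.
            if logits_warper is not None:
                scores = logits_warper(scores)
            probs = F.softmax(scores, dim=-1)
            return torch.multinomial(probs, num_samples=1).squeeze(1)
        # With do_sample=False the same code path degenerates to greedy search.
        return torch.argmax(scores, dim=-1)

    # e.g. select_next_tokens(torch.randn(2, 10), do_sample=False) returns the argmax ids.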
@@ -2485,6 +2320,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )
 
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
@@ -2525,7 +2366,8 @@ class GenerationMixin:
             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -2547,9 +2389,12 @@ class GenerationMixin:
                         else (outputs.hidden_states,)
                     )
 
-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
 
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -2622,6 +2467,7 @@ class GenerationMixin:
             past_key_values.reorder_cache(beam_idx)
         return past_key_values
 
+    # TODO (joao, v4.42): remove default for `logits_warper`
     def _beam_search(
         self,
         input_ids: torch.LongTensor,
@@ -2630,6 +2476,7 @@ class GenerationMixin:
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
+        logits_warper: Optional[LogitsProcessorList] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
@@ -2652,6 +2499,11 @@
                 The generation configuration to be used as parametrization of the decoding method.
             synced_gpus (`bool`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+                `generation_config`)
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
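The beam hunks that follow apply the same pattern: when `do_sample=True`, candidates are drawn with `torch.multinomial` and then re-sorted so the beam scorer sees them in descending-score order, exactly as `topk` would return them. A standalone sketch of that selection step, with assumed names and shapes rather than the library's own code:

    import torch
    import torch.nn.functional as F

    def pick_beam_candidates(next_token_scores, num_beams, do_sample, n_eos_tokens=1):
        # `next_token_scores` has shape (batch_size, num_beams * vocab_size) and
        # already includes the running beam scores.
        n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
        if do_sample:
            # Beam sample: draw candidates, then sort by score so downstream
            # processing matches the deterministic top-k ordering.
            probs = F.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
            next_token_scores, order = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, order)
        else:
            # Beam search: deterministic top-k over all beams' vocab entries.
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
            )
        return next_token_scores, next_tokens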
@@ -2672,6 +2524,12 @@ class GenerationMixin:
         output_logits = generation_config.output_logits
         return_dict_in_generate = generation_config.return_dict_in_generate
         sequential = generation_config.low_memory
+        do_sample = generation_config.do_sample
+        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+            raise ValueError(
+                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+                f"{logits_warper})."
+            )
 
         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams
@@ -2768,6 +2626,8 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)
 
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
+            if do_sample:
+                next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
@@ -2795,11 +2655,20 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+            # Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
+            # non eos token per beam.
             n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
-            next_token_scores, next_tokens = torch.topk(
-                next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
-            )
+            n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
+                next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+                next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+                next_tokens = torch.gather(next_tokens, -1, _indices)
+            else:
+                next_token_scores, next_tokens = torch.topk(
+                    next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
+                )
 
             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
             next_tokens = next_tokens % vocab_size
@@ -2897,219 +2766,24 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
-        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-        Parameters:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                The sequence used as a prompt for the generation.
-            beam_scorer (`BeamScorer`):
-                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
-                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
-            logits_processor (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-                used to modify the prediction scores of the language modeling head applied at each generation step.
-            stopping_criteria (`StoppingCriteriaList`):
-                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
-                used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
-            generation_config ([`~generation.GenerationConfig`]):
-                The generation configuration to be used as parametrization of the decoding method.
-            synced_gpus (`bool`):
-                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
-                an encoder-decoder model the kwargs should include `encoder_outputs`.
-
-        Return:
-            [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
-            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
-            [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
-            `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
-            `model.config.is_encoder_decoder=True`.
+        Deprecated. Use `._beam_search()` instead, passing the same arguments.
         """
-        # init values
-        pad_token_id = generation_config.pad_token_id
-        eos_token_id = generation_config.eos_token_id
-        output_attentions = generation_config.output_attentions
-        output_hidden_states = generation_config.output_hidden_states
-        output_scores = generation_config.output_scores
-        output_logits = generation_config.output_logits
-        return_dict_in_generate = generation_config.return_dict_in_generate
-        batch_size = len(beam_scorer._beam_hyps)
-        num_beams = beam_scorer.num_beams
-
-        batch_beam_size, cur_len = input_ids.shape
-        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-        beam_indices = (
-            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
+        logger.warning_once(
+            "Calling `._beam_sample()` directly is deprecated and will be removed in v4.42. Use `._beam_search()` "
+            "instead, passing the same arguments."
        )
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
-            )
-
-        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
-        beam_scores = beam_scores.view((batch_size * num_beams,))
-
-        this_peer_finished = False
-
-        decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            next_token_scores = nn.functional.log_softmax(
-                next_token_logits, dim=-1
-            )  # (batch_size * num_beams, vocab_size)
-
-            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
-            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
-                next_token_scores_processed
-            )
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores_processed,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-                if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
-
-                if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
-
-            # reshape for beam search
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
-
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-
-            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
-            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-
-            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, _indices)
-
-            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
-            next_tokens = next_tokens % vocab_size
-
-            # stateless
-            beam_outputs = beam_scorer.process(
-                input_ids,
-                next_token_scores,
-                next_tokens,
-                next_indices,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                beam_indices=beam_indices,
-                decoder_prompt_len=decoder_prompt_len,
-            )
-            beam_scores = beam_outputs["next_beam_scores"]
-            beam_next_tokens = beam_outputs["next_beam_tokens"]
-            beam_idx = beam_outputs["next_beam_indices"]
-
-            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-            if model_kwargs.get("past_key_values", None) is not None:
-                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
-                    model_kwargs["past_key_values"], beam_idx
-                )
-
-            if return_dict_in_generate and output_scores:
-                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
-
-            # increase cur_len
-            cur_len = cur_len + 1
-
-            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                this_peer_finished = True
-
-        sequence_outputs = beam_scorer.finalize(
-            input_ids,
-            beam_scores,
-            next_tokens,
-            next_indices,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            max_length=stopping_criteria.max_length,
-            beam_indices=beam_indices,
-            decoder_prompt_len=decoder_prompt_len,
+        return self._beam_search(
+            input_ids=input_ids,
+            beam_scorer=beam_scorer,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            logits_warper=logits_warper,
+            generation_config=generation_config,
+            synced_gpus=synced_gpus,
+            **model_kwargs,
        )
 
-        if return_dict_in_generate:
-            if not output_scores:
-                sequence_outputs["sequence_scores"] = None
-
-            if self.config.is_encoder_decoder:
-                return GenerateBeamEncoderDecoderOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateBeamDecoderOnlyOutput(
-                    sequences=sequence_outputs["sequences"],
-                    sequences_scores=sequence_outputs["sequence_scores"],
-                    scores=scores,
-                    logits=raw_logits,
-                    beam_indices=sequence_outputs["beam_indices"],
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return sequence_outputs["sequences"]
-
     def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 9d1cf6e568f..8e8b1fe2842 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -1739,7 +1739,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
         )
 
         # 11. run greedy search
-        outputs = self._greedy_search(
+        outputs = self._sample(
             input_ids,
             logits_processor=logits_processor,
             stopping_criteria=stopping_criteria,
@@ -2832,7 +2832,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
         )
 
         # 11. run greedy search
-        outputs = self._greedy_search(
+        outputs = self._sample(
             input_ids,
             logits_processor=logits_processor,
             stopping_criteria=stopping_criteria,
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index 63fc638f164..9865a4b9179 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1676,7 +1676,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
         )
 
         # 11. run greedy search
-        outputs = self._greedy_search(
+        outputs = self._sample(
             input_ids,
             logits_processor=logits_processor,
             stopping_criteria=stopping_criteria,
@@ -2691,7 +2691,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
         )
 
         # 11. run greedy search
-        outputs = self._greedy_search(
+        outputs = self._sample(
             input_ids,
             logits_processor=logits_processor,
             stopping_criteria=stopping_criteria,
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 7eac28ca77e..3590369d5b9 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1550,7 +1550,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                 " greedy search."
             )
-            return self._greedy_search(
+            return self._sample(
                 input_ids,
                 logits_processor=pre_processor,
                 stopping_criteria=prepared_stopping_criteria,
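None of the above changes the public API: greedy search is still reached through `generate(..., do_sample=False)`, which now dispatches to `_sample` internally; only the private helpers and their internal callers (musicgen, musicgen_melody, rag) moved. A quick smoke test of both paths (this assumes the `transformers` and `torch` packages and the public `gpt2` checkpoint are available; any causal LM would do):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tok("The quick brown fox", return_tensors="pt")

    # Greedy decoding: no sampling, so no logits warper is prepared internally.
    greedy = model.generate(**inputs, do_sample=False, max_new_tokens=8)

    # Sampling: same `_sample` code path, now with warpers and multinomial draws.
    sampled = model.generate(**inputs, do_sample=True, top_k=50, max_new_tokens=8)

    print(tok.decode(greedy[0], skip_special_tokens=True))
    print(tok.decode(sampled[0], skip_special_tokens=True))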