From f21af2627996be9350fb62febb15ca411a0e02f5 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 3 Feb 2023 10:24:02 +0000
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=20Generate:=20standard?=
 =?UTF-8?q?ize=20beam=20search=20behavior=20across=20frameworks=20(#21368)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/transformers/generation/beam_search.py    | 83 ++++++++++++-------
 .../generation/configuration_utils.py         | 19 ++++-
 src/transformers/generation/flax_utils.py     | 34 +++++---
 src/transformers/generation/tf_utils.py       | 38 +++++----
 src/transformers/generation/utils.py          |  6 +-
 src/transformers/models/rag/modeling_rag.py   |  1 +
 tests/generation/test_utils.py                | 53 ------------
 tests/models/bart/test_modeling_flax_bart.py  |  2 +-
 tests/models/gpt2/test_modeling_flax_gpt2.py  |  2 +-
 tests/models/t5/test_modeling_flax_t5.py      |  2 +-
 10 files changed, 122 insertions(+), 118 deletions(-)

diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py
index 6e4f9cb936e..8b720e6a773 100644
--- a/src/transformers/generation/beam_search.py
+++ b/src/transformers/generation/beam_search.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
 from abc import ABC, abstractmethod
 from collections import UserDict
 from typing import List, Optional, Tuple, Union
@@ -130,8 +129,6 @@ class BeamSearchScorer(BeamScorer):
     Args:
         batch_size (`int`):
             Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
-        max_length (`int`):
-            The maximum length of the sequence to be generated.
         num_beams (`int`):
             Number of beams for beam search.
         device (`torch.device`):
@@ -142,14 +139,20 @@ class BeamSearchScorer(BeamScorer):
             the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
             likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
             `length_penalty` < 0.0 encourages shorter sequences.
-        do_early_stopping (`bool`, *optional*, defaults to `False`):
-            Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+            Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
+            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+            beam search algorithm).
         num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
             The number of beam hypotheses that shall be returned upon calling
             [`~transformer.BeamSearchScorer.finalize`].
         num_beam_groups (`int`):
             Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
             See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+        max_length (`int`, *optional*):
+            The maximum length of the sequence to be generated.
""" def __init__( @@ -158,10 +161,10 @@ class BeamSearchScorer(BeamScorer): num_beams: int, device: torch.device, length_penalty: Optional[float] = 1.0, - do_early_stopping: Optional[bool] = False, + do_early_stopping: Optional[Union[bool, str]] = False, num_beam_hyps_to_keep: Optional[int] = 1, num_beam_groups: Optional[int] = 1, - **kwargs, + max_length: Optional[int] = None, ): self.num_beams = num_beams self.device = device @@ -177,6 +180,7 @@ class BeamSearchScorer(BeamScorer): num_beams=self.num_beams, length_penalty=self.length_penalty, early_stopping=self.do_early_stopping, + max_length=max_length, ) for _ in range(batch_size) ] @@ -194,13 +198,6 @@ class BeamSearchScorer(BeamScorer): f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." ) - if "max_length" in kwargs: - warnings.warn( - "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. " - "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" - ", or `group_beam_search(...)`." - ) - @property def is_done(self) -> bool: return self._done.all() @@ -402,8 +399,6 @@ class ConstrainedBeamSearchScorer(BeamScorer): Args: batch_size (`int`): Batch Size of `input_ids` for which standard beam search decoding is run in parallel. - max_length (`int`): - The maximum length of the sequence to be generated. num_beams (`int`): Number of beams for beam search. constraints (`List[Constraint]`): @@ -417,14 +412,20 @@ class ConstrainedBeamSearchScorer(BeamScorer): the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences. - do_early_stopping (`bool`, *optional*, defaults to `False`): - Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. + do_early_stopping (`bool` or `str`, *optional*, defaults to `False`): + Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: + `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very unlikely to find better candidates; + `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical + beam search algorithm). num_beam_hyps_to_keep (`int`, *optional*, defaults to 1): The number of beam hypotheses that shall be returned upon calling [`~transformer.BeamSearchScorer.finalize`]. num_beam_groups (`int`): Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. + max_length (`int`, *optional*): + The maximum length of the sequence to be generated. 
""" def __init__( @@ -434,10 +435,10 @@ class ConstrainedBeamSearchScorer(BeamScorer): constraints: List[Constraint], device: torch.device, length_penalty: Optional[float] = 1.0, - do_early_stopping: Optional[bool] = False, + do_early_stopping: Optional[Union[bool, str]] = False, num_beam_hyps_to_keep: Optional[int] = 1, num_beam_groups: Optional[int] = 1, - **kwargs, + max_length: Optional[int] = None, ): self.num_beams = num_beams self.device = device @@ -454,6 +455,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): num_beams=self.num_beams, length_penalty=self.length_penalty, early_stopping=self.do_early_stopping, + max_length=max_length, ) for _ in range(batch_size) ] @@ -471,13 +473,6 @@ class ConstrainedBeamSearchScorer(BeamScorer): f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." ) - if "max_length" in kwargs: - warnings.warn( - "Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. " - "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`" - ", or `group_beam_search(...)`." - ) - @property def is_done(self) -> bool: return self._done.all() @@ -865,16 +860,23 @@ class ConstrainedBeamSearchScorer(BeamScorer): class BeamHypotheses: - def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool): + def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None): """ Initialize n-best list of hypotheses. """ self.length_penalty = length_penalty self.early_stopping = early_stopping + self.max_length = max_length self.num_beams = num_beams self.beams = [] self.worst_score = 1e9 + if not isinstance(self.early_stopping, bool) and self.max_length is None: + raise ValueError( + "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the" + " BeamScorer class instance at initialization time." + ) + def __len__(self): """ Number of hypotheses in the list. @@ -903,9 +905,26 @@ class BeamHypotheses: if len(self) < self.num_beams: return False - elif self.early_stopping: + + # `True`: stop as soon as at least `num_beams` hypotheses are finished + if self.early_stopping is True: return True - else: - cur_score = best_sum_logprobs / cur_len**self.length_penalty - ret = self.worst_score >= cur_score + # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate + # when `length_penalty` is positive. See the discussion below for more details. 
+        # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+        elif self.early_stopping is False:
+            highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            ret = self.worst_score >= highest_attainable_score
+            return ret
+        # `"never"`: compute the best possible score, depending on the sign of `length_penalty`
+        else:
+            # `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min
+            # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
+            # its max this way
+            if self.length_penalty > 0.0:
+                highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+            # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
+            else:
+                highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            ret = self.worst_score >= highest_attainable_score
             return ret
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index a869d49ccf0..e01b0dc40aa 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -71,8 +71,12 @@ class GenerationConfig(PushToHubMixin):
             `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
         min_new_tokens (`int`, *optional*):
             The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
-        early_stopping (`bool`, *optional*, defaults to `False`):
-            Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+        early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+            Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
+            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+            beam search algorithm).
         max_time(`float`, *optional*):
             The maximum amount of time you allow the computation to run for in seconds. generation will still finish
             the current pass after allocated time has been passed.
@@ -290,6 +294,9 @@ class GenerationConfig(PushToHubMixin):
                 logger.error(f"Can't set {key} with value {value} for {self}")
                 raise err
 
+        # Validate the values of the attributes
+        self.validate()
+
     def __eq__(self, other):
         self_dict = self.__dict__.copy()
         other_dict = other.__dict__.copy()
@@ -302,6 +309,14 @@ class GenerationConfig(PushToHubMixin):
     def __repr__(self):
         return f"{self.__class__.__name__} {self.to_json_string()}"
 
+    def validate(self):
+        """
+        Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
+        the values are invalid.
+ """ + if self.early_stopping not in {True, False, "never"}: + raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.") + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index a327621c3c0..1cfc07b9786 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -19,7 +19,7 @@ import copy import inspect import warnings from functools import partial -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import numpy as np @@ -275,6 +275,7 @@ class FlaxGenerationMixin: generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) # set init values @@ -633,7 +634,7 @@ class FlaxGenerationMixin: pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, length_penalty: Optional[float] = None, - early_stopping: Optional[bool] = None, + early_stopping: Optional[Union[bool, str]] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None, trace: bool = True, params: Optional[Dict[str, jnp.ndarray]] = None, @@ -733,14 +734,22 @@ class FlaxGenerationMixin: not_max_length_yet = state.cur_len < max_length # 2. can the new beams still improve? - best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty) + # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion + # below for more details. + # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565 + # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of + # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there. + if early_stopping == "never" and length_penalty > 0.0: + best_running_score = state.running_scores[:, :1] / (max_length**length_penalty) + else: + best_running_score = state.running_scores[:, :1] / (state.cur_len**length_penalty) worst_finished_score = jnp.where( state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7) ) - improvement_still_possible = jnp.all(worst_finished_score < best_running_score) + improvement_still_possible = jnp.any(best_running_score > worst_finished_score) # 3. is there still a beam that has not finished? - still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping) + still_open_beam = ~(jnp.all(state.is_sent_finished) & (early_stopping is True)) return not_max_length_yet & still_open_beam & improvement_still_possible @@ -813,7 +822,7 @@ class FlaxGenerationMixin: # 5. Get running sequences scores for next # Determine the top k beam indices (from top 2*k beams) from log probs # and gather top k beams (from top 2*k beams). 
-            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
+            next_topk_indices = lax.top_k(running_topk_log_probs, k=num_beams)[1]
             next_running_sequences, next_running_scores = gather_beams(
                 [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
             )
@@ -824,10 +833,9 @@ class FlaxGenerationMixin:
             # - make sure no scores can be added anymore if beam is full
             # - make sure still running sequences cannot be chosen as finalized beam
             topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
-            beams_in_batch_are_full = (
-                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
-                & early_stopping
-            )
+            beams_in_batch_are_full = jnp.broadcast_to(
+                state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
+            ) & (early_stopping is True)
             add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
             topk_log_probs += add_penalty * np.array(-1.0e7)
 
@@ -838,7 +846,7 @@ class FlaxGenerationMixin:
             merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
             merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
             merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
-            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
+            topk_merged_indices = lax.top_k(merged_scores, k=num_beams)[1]
             next_sequences, next_scores, next_is_sent_finished = gather_beams(
                 [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
             )
@@ -877,7 +885,7 @@ class FlaxGenerationMixin:
         scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
 
         # take best beam for each batch
-        sequences = sequences[:, -1]
-        scores = scores[:, -1]
+        sequences = sequences[:, 0]
+        scores = scores[:, 0]
 
         return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
diff --git a/src/transformers/generation/tf_utils.py b/src/transformers/generation/tf_utils.py
index deeefdcec43..c06e6132ec7 100644
--- a/src/transformers/generation/tf_utils.py
+++ b/src/transformers/generation/tf_utils.py
@@ -611,6 +611,7 @@ class TFGenerationMixin:
 
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
@@ -1808,7 +1809,7 @@ class TFGenerationMixin:
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         length_penalty: Optional[float] = None,
-        early_stopping: Optional[bool] = None,
+        early_stopping: Optional[Union[bool, str]] = None,
         logits_processor: Optional[TFLogitsProcessorList] = None,
         logits_warper: Optional[TFLogitsProcessorList] = None,
         num_return_sequences: Optional[int] = None,
@@ -1838,8 +1839,12 @@ class TFGenerationMixin:
                 to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
                 the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
                 while `length_penalty` < 0.0 encourages shorter sequences.
-            early_stopping (`bool`, *optional*, defaults to `False`):
-                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+            early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+                Controls the stopping condition for beam-based methods, like beam-search. It accepts the following
+                values: `True`, where the generation stops as soon as there are `num_beams` complete candidates;
+                `False`, where a heuristic is applied and the generation stops when it is very unlikely to find better
+                candidates; `"never"`, where the beam search procedure only stops when there cannot be better
+                candidates (canonical beam search algorithm).
             logits_processor (`[TFLogitsProcessorList]`, *optional*):
                 An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -2009,16 +2014,24 @@ class TFGenerationMixin:
             not_max_length_yet = cur_len < max_length
 
             # 2. can the new beams still improve?
-            best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
+            # below for more details.
+            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
+            # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
+            if early_stopping == "never" and length_penalty > 0.0:
+                best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            else:
+                best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
             worst_finished_score = tf.where(
                 is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
             )
-            improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
+            improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)
 
             # 3. is there still a beam that has not finished?
-            still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
+            still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & (early_stopping is True))
 
-            return not_max_length_yet & (still_open_beam | improvement_still_possible)
+            return not_max_length_yet & still_open_beam & improvement_still_possible
 
         def beam_search_body_fn(
             cur_len,
@@ -2140,12 +2153,9 @@ class TFGenerationMixin:
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
-            beams_in_batch_are_full = (
-                tf.broadcast_to(
-                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
-                )
-                & early_stopping
-            )
+            beams_in_batch_are_full = tf.broadcast_to(
+                tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
+            ) & (early_stopping is True)
             add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
             topk_log_probs += tf.cast(add_penalty, tf.float32) * -1.0e9
 
@@ -2239,7 +2249,7 @@ class TFGenerationMixin:
         sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
         scores = tf.where(none_finished[:, None], scores, running_scores)
 
-        # Take best beams for each batch (the score is sorted in ascending order)
+        # Take best beams for each batch (the score is sorted in descending order)
         sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
         scores = flatten_beam_dim(scores[:, :num_return_sequences])
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index ef9053eccd8..1d9e53168e0 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1190,6 +1190,7 @@ class GenerationMixin:
 
         generation_config = copy.deepcopy(generation_config)
         model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Set generation parameters if not already defined
@@ -1458,6 +1459,7 @@ class GenerationMixin:
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                max_length=generation_config.max_length,
             )
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1493,6 +1495,7 @@ class GenerationMixin:
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
+                max_length=generation_config.max_length,
             )
 
             # 13. interleave input_ids with `num_beams` additional sequences per batch
@@ -1536,12 +1539,12 @@ class GenerationMixin:
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=generation_config.num_beams,
-                max_length=stopping_criteria.max_length,
                 device=inputs_tensor.device,
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 num_beam_groups=generation_config.num_beam_groups,
+                max_length=generation_config.max_length,
             )
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1629,6 +1632,7 @@ class GenerationMixin:
                 length_penalty=generation_config.length_penalty,
                 do_early_stopping=generation_config.early_stopping,
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                max_length=generation_config.max_length,
             )
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index c4186638337..52aca10d5c0 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1566,6 +1566,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             length_penalty=generation_config.length_penalty,
             do_early_stopping=generation_config.early_stopping,
             num_beam_hyps_to_keep=generation_config.num_return_sequences,
+            max_length=generation_config.max_length,
         )
         return self.beam_search(
             input_ids,
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 7a57a06afe5..a97f036e129 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2034,59 +2034,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
             **model_kwargs,
         )
 
-    def test_beam_search_warning_if_max_length_is_passed(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
-            torch_device
-        )
-
-        batch_size = 1
-        num_beams = 3
-
-        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-        input_ids = input_ids.expand(num_beams, -1)
-        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-
-        # pretend decoder_input_ids correspond to first encoder input id
-        decoder_input_ids = input_ids[:, :1]
-
-        stopping_criteria_max_length = 18
-        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
-
-        with self.assertWarns(UserWarning):
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=num_beams,
-                device=torch_device,
-                max_length=10,
-            )
-
-        generated_ids = bart_model.beam_search(
-            decoder_input_ids,
-            num_beams=num_beams,
-            stopping_criteria=stopping_criteria,
-            beam_scorer=beam_scorer,
-            **model_kwargs,
-        )
-
-        beam_scorer_no_max_len = BeamSearchScorer(
-            batch_size=batch_size,
-            num_beams=num_beams,
-            device=torch_device,
-        )
-
-        generated_ids_no_max_len = bart_model.beam_search(
-            decoder_input_ids,
-            num_beams=num_beams,
-            stopping_criteria=stopping_criteria,
-            beam_scorer=beam_scorer_no_max_len,
-            **model_kwargs,
-        )
-
-        # BeamSearchScorer max_length should not influence "real" max_length
-        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
-
     def test_custom_stopping_criteria_overload_error(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
diff --git a/tests/models/bart/test_modeling_flax_bart.py b/tests/models/bart/test_modeling_flax_bart.py
index 1289ae9ed48..7a1d1c5e8b4 100644
--- a/tests/models/bart/test_modeling_flax_bart.py
+++ b/tests/models/bart/test_modeling_flax_bart.py
@@ -426,7 +426,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
         )
 
         input_ids = tokenizer(input_str, return_tensors="np").input_ids
-        sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences
+        sequences = model.generate(input_ids, num_beams=2, min_length=None, max_length=20).sequences
 
         output_str = tokenizer.batch_decode(sequences)[0]
 
diff --git a/tests/models/gpt2/test_modeling_flax_gpt2.py b/tests/models/gpt2/test_modeling_flax_gpt2.py
index cb3f3321291..23ff3f11b0a 100644
--- a/tests/models/gpt2/test_modeling_flax_gpt2.py
+++ b/tests/models/gpt2/test_modeling_flax_gpt2.py
@@ -224,7 +224,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
         output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
 
         expected_string = [
-            "Hello this is a long string of words. I'm going to try to explain what I mean.",
+            "Hello this is a long string of words. I'm going to start with the first one.\n",
            "Hey, I'm not sure if I'm going to be able to do",
         ]
 
diff --git a/tests/models/t5/test_modeling_flax_t5.py b/tests/models/t5/test_modeling_flax_t5.py
index f4bd54e97af..10e6622bb7d 100644
--- a/tests/models/t5/test_modeling_flax_t5.py
+++ b/tests/models/t5/test_modeling_flax_t5.py
@@ -1076,7 +1076,7 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
         expected_summaries = [
             'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
             " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
-            " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
+            " magazine says . all 150 on board were killed in the crash .",
             "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
             " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
             " court, Palestinians may be subject to counter-charges as well .",
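
--
Note: the standardized `early_stopping` semantics introduced by this patch are exercised through the public
`generate()` API. A minimal sketch of the three accepted values follows; the checkpoint and prompt are arbitrary
placeholders (any beam-search-capable model behaves the same way), not something this patch prescribes:

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
    inputs = tokenizer("Justin Timberlake and Jessica Biel, welcome to parenthood.", return_tensors="pt")

    # After this patch, `early_stopping` accepts the same three values in PyTorch, TF, and Flax:
    #   True    -> stop as soon as `num_beams` complete candidates exist
    #   False   -> stop when the heuristic says better candidates are very unlikely (default)
    #   "never" -> only stop once no running beam can still beat a finished one (canonical beam search)
    out = model.generate(**inputs, num_beams=4, length_penalty=1.0, early_stopping="never", max_length=32)
    print(tokenizer.batch_decode(out, skip_special_tokens=True))

With `early_stopping="never"` and `length_penalty > 0.0`, the upper bound on a running beam's final score is computed
with `max_length` as the denominator (the largest attainable denominator), which is why `max_length` must now be
forwarded to the beam scorer at initialization time.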