🚨🚨 Generate: standardize beam search behavior across frameworks (#21368)

Joao Gante 2023-02-03 10:24:02 +00:00 committed by GitHub
parent ea55bd86b9
commit f21af26279
10 changed files with 122 additions and 118 deletions


@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple, Union
@@ -130,8 +129,6 @@ class BeamSearchScorer(BeamScorer):
Args:
batch_size (`int`):
Batch size of `input_ids` for which standard beam search decoding is run in parallel.
max_length (`int`):
The maximum length of the sequence to be generated.
num_beams (`int`):
Number of beams for beam search.
device (`torch.device`):
@@ -142,14 +139,20 @@ class BeamSearchScorer(BeamScorer):
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very unlikely that better candidates will be
found; `"never"`, where the beam search procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformers.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def __init__(
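For orientation, a minimal usage sketch of the updated constructor (the concrete values are assumptions for illustration):

```python
import torch
from transformers import BeamSearchScorer

# `max_length` is now an explicit constructor argument, and `do_early_stopping`
# additionally accepts the string "never" (canonical beam search).
scorer = BeamSearchScorer(
    batch_size=2,
    num_beams=4,
    device=torch.device("cpu"),
    length_penalty=1.0,
    do_early_stopping="never",  # stop only when no candidate can possibly improve
    max_length=32,  # required whenever `do_early_stopping` is a string
)
```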
@@ -158,10 +161,10 @@ class BeamSearchScorer(BeamScorer):
num_beams: int,
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
do_early_stopping: Optional[Union[bool, str]] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1,
**kwargs,
max_length: Optional[int] = None,
):
self.num_beams = num_beams
self.device = device
@@ -177,6 +180,7 @@ class BeamSearchScorer(BeamScorer):
num_beams=self.num_beams,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=max_length,
)
for _ in range(batch_size)
]
@@ -194,13 +198,6 @@ class BeamSearchScorer(BeamScorer):
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
if "max_length" in kwargs:
warnings.warn(
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)
@property
def is_done(self) -> bool:
return self._done.all()
@@ -402,8 +399,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
Args:
batch_size (`int`):
Batch size of `input_ids` for which standard beam search decoding is run in parallel.
max_length (`int`):
The maximum length of the sequence to be generated.
num_beams (`int`):
Number of beams for beam search.
constraints (`List[Constraint]`):
@@ -417,14 +412,20 @@ class ConstrainedBeamSearchScorer(BeamScorer):
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
`length_penalty` < 0.0 encourages shorter sequences.
do_early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very unlikely that better candidates will be
found; `"never"`, where the beam search procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
[`~transformers.BeamSearchScorer.finalize`].
num_beam_groups (`int`):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
max_length (`int`, *optional*):
The maximum length of the sequence to be generated.
"""
def __init__(
@@ -434,10 +435,10 @@ class ConstrainedBeamSearchScorer(BeamScorer):
constraints: List[Constraint],
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
do_early_stopping: Optional[Union[bool, str]] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1,
**kwargs,
max_length: Optional[int] = None,
):
self.num_beams = num_beams
self.device = device
@@ -454,6 +455,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
num_beams=self.num_beams,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=max_length,
)
for _ in range(batch_size)
]
@@ -471,13 +473,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
)
if "max_length" in kwargs:
warnings.warn(
"Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)
@property
def is_done(self) -> bool:
return self._done.all()
@@ -865,16 +860,23 @@ class ConstrainedBeamSearchScorer(BeamScorer):
class BeamHypotheses:
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
"""
Initialize n-best list of hypotheses.
"""
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.max_length = max_length
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
if not isinstance(self.early_stopping, bool) and self.max_length is None:
raise ValueError(
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
" BeamScorer class instance at initialization time."
)
def __len__(self):
"""
Number of hypotheses in the list.
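With this guard, string-valued early stopping fails fast when `max_length` is missing. A hedged sketch (`BeamHypotheses` is internal, so the import path below is an assumption based on this file's location):

```python
from transformers.generation.beam_search import BeamHypotheses  # assumed internal path

try:
    # `early_stopping="never"` without `max_length` now raises at construction time
    BeamHypotheses(num_beams=4, length_penalty=1.0, early_stopping="never")
except ValueError as err:
    print(err)
```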
@@ -903,9 +905,26 @@ class BeamHypotheses:
if len(self) < self.num_beams:
return False
elif self.early_stopping:
# `True`: stop as soon as at least `num_beams` hypotheses are finished
if self.early_stopping is True:
return True
else:
cur_score = best_sum_logprobs / cur_len**self.length_penalty
ret = self.worst_score >= cur_score
# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
# when `length_penalty` is positive. See the discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
elif self.early_stopping is False:
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
# `"never"`: compute the best possible score, depending on the sign of `length_penalty`
else:
# `length_penalty` > 0.0 -> the maximum denominator is obtained from `max_length`, not from `cur_len` ->
# abs(`highest_attainable_score`) is minimized -> since the score is negative, this is how we obtain
# its maximum
if self.length_penalty > 0.0:
highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
else:
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
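To see why the `"never"` branch divides by `max_length` when `length_penalty` is positive, here is a self-contained numeric sketch (toy values assumed for illustration):

```python
# Scores are sums of log-probabilities, hence negative.
best_sum_logprobs, cur_len, max_length, length_penalty = -6.0, 10, 20, 2.0

# `early_stopping=False` heuristic: bound computed from the current length
heuristic_bound = best_sum_logprobs / cur_len**length_penalty  # -0.06

# `early_stopping="never"`, `length_penalty` > 0: the denominator grows with length,
# so the true upper bound divides by `max_length` instead
exact_bound = best_sum_logprobs / max_length**length_penalty  # -0.015

# The exact bound is higher (less negative): the heuristic may declare "done" too early.
assert exact_bound > heuristic_bound
```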


@@ -71,8 +71,12 @@ class GenerationConfig(PushToHubMixin):
`min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
min_new_tokens (`int`, *optional*):
The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam search. It accepts the following values:
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
heuristic is applied and the generation stops when it is very unlikely that better candidates will be
found; `"never"`, where the beam search procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
max_time (`float`, *optional*):
The maximum amount of time you allow the computation to run for, in seconds. Generation will still finish
the current pass after the allocated time has passed.
@@ -290,6 +294,9 @@ class GenerationConfig(PushToHubMixin):
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
# Validate the values of the attributes
self.validate()
def __eq__(self, other):
self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy()
@@ -302,6 +309,14 @@ class GenerationConfig(PushToHubMixin):
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def validate(self):
"""
Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
the values are invalid.
"""
if self.early_stopping not in {True, False, "never"}:
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
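A quick sketch of the new validation path (illustrative; the config also validates when constructed directly):

```python
from transformers import GenerationConfig

GenerationConfig(num_beams=4, early_stopping="never")  # accepted: canonical beam search

try:
    GenerationConfig(num_beams=4, early_stopping="forever")  # invalid value
except ValueError as err:
    print(err)  # `early_stopping` must be a boolean or 'never', but is forever.
```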
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],


@@ -19,7 +19,7 @@ import copy
import inspect
import warnings
from functools import partial
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
import numpy as np
@@ -275,6 +275,7 @@ class FlaxGenerationMixin:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# set init values
@@ -633,7 +634,7 @@ class FlaxGenerationMixin:
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
early_stopping: Optional[bool] = None,
early_stopping: Optional[Union[bool, str]] = None,
logits_processor: Optional[FlaxLogitsProcessorList] = None,
trace: bool = True,
params: Optional[Dict[str, jnp.ndarray]] = None,
@@ -733,14 +734,22 @@ class FlaxGenerationMixin:
not_max_length_yet = state.cur_len < max_length
# 2. can the new beams still improve?
best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
# early_stopping == False -> apply a heuristic: always compute the best score from `cur_len`. See the
# discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if early_stopping == "never" and length_penalty > 0.0:
best_running_score = state.running_scores[:, :1] / (max_length**length_penalty)
else:
best_running_score = state.running_scores[:, :1] / (state.cur_len**length_penalty)
worst_finished_score = jnp.where(
state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
)
improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
improvement_still_possible = jnp.any(best_running_score > worst_finished_score)
# 3. is there still a beam that has not finished?
still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
still_open_beam = ~(jnp.all(state.is_sent_finished) & (early_stopping is True))
return not_max_length_yet & still_open_beam & improvement_still_possible
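The `jnp.all` -> `jnp.any` switch also changes the batch semantics: the loop now continues while any batch item can still improve, rather than requiring all of them to. A small sketch with assumed scores:

```python
import jax.numpy as jnp

best_running_score = jnp.array([[-0.5], [-2.0]])    # per-batch best attainable running score
worst_finished_score = jnp.array([[-0.4], [-3.0]])  # per-batch worst finished score

# old: continue only if *every* batch item could still improve
print(jnp.all(worst_finished_score < best_running_score))  # False -> would have stopped
# new: continue while *any* batch item can still improve
print(jnp.any(best_running_score > worst_finished_score))  # True -> keeps going
```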
@@ -813,7 +822,7 @@ class FlaxGenerationMixin:
# 5. Get running sequences scores for next
# Determine the top k beam indices (from top 2*k beams) from log probs
# and gather top k beams (from top 2*k beams).
next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
next_topk_indices = lax.top_k(running_topk_log_probs, k=num_beams)[1]
next_running_sequences, next_running_scores = gather_beams(
[topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
)
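Dropping the `jnp.flip` works because `lax.top_k` already returns results in descending order, so beam index 0 now holds the best hypothesis (which is why the final gather below switches to `[:, 0]`). A tiny sketch:

```python
import jax.numpy as jnp
from jax import lax

scores = jnp.array([[0.1, 0.9, 0.5]])
values, indices = lax.top_k(scores, k=2)
print(values)   # [[0.9 0.5]] -- best first, no flip needed
print(indices)  # [[1 2]]
```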
@@ -824,10 +833,9 @@ class FlaxGenerationMixin:
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
beams_in_batch_are_full = (
jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
& early_stopping
)
beams_in_batch_are_full = jnp.broadcast_to(
state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
) & (early_stopping is True)
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
topk_log_probs += add_penalty * np.array(-1.0e7)
@@ -838,7 +846,7 @@ class FlaxGenerationMixin:
merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
topk_merged_indices = lax.top_k(merged_scores, k=num_beams)[1]
next_sequences, next_scores, next_is_sent_finished = gather_beams(
[merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
)
@@ -877,7 +885,7 @@ class FlaxGenerationMixin:
scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
# take best beam for each batch
sequences = sequences[:, -1]
scores = scores[:, -1]
sequences = sequences[:, 0]
scores = scores[:, 0]
return FlaxBeamSearchOutput(sequences=sequences, scores=scores)


@@ -611,6 +611,7 @@ class TFGenerationMixin:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
@@ -1808,7 +1809,7 @@ class TFGenerationMixin:
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
early_stopping: Optional[bool] = None,
early_stopping: Optional[Union[bool, str]] = None,
logits_processor: Optional[TFLogitsProcessorList] = None,
logits_warper: Optional[TFLogitsProcessorList] = None,
num_return_sequences: Optional[int] = None,
@@ -1838,8 +1839,12 @@ class TFGenerationMixin:
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
early_stopping (`bool` or `str`, *optional*, defaults to `False`):
Controls the stopping condition for beam-based methods, like beam search. It accepts the following
values: `True`, where the generation stops as soon as there are `num_beams` complete candidates;
`False`, where a heuristic is applied and the generation stops when it is very unlikely that better
candidates will be found; `"never"`, where the beam search procedure only stops when there cannot be
better candidates (canonical beam search algorithm).
logits_processor (`[TFLogitsProcessorList]`, *optional*):
An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -2009,16 +2014,24 @@ class TFGenerationMixin:
not_max_length_yet = cur_len < max_length
# 2. can the new beams still improve?
best_running_score = running_scores[:, :1] / (max_length**length_penalty)
# early_stopping == False -> apply a heuristic: always compute the best score from `cur_len`. See the
# discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
# early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
# length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
if early_stopping == "never" and length_penalty > 0.0:
best_running_score = running_scores[:, :1] / (max_length**length_penalty)
else:
best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
worst_finished_score = tf.where(
is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
)
improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)
# 3. is there still a beam that has not finished?
still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & (early_stopping is True))
return not_max_length_yet & (still_open_beam | improvement_still_possible)
return not_max_length_yet & still_open_beam & improvement_still_possible
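Replacing `|` with `&` fixes a case where the TF loop kept running even though every beam was finished and early stopping was on. A minimal truth-value sketch with assumed flags:

```python
import tensorflow as tf

not_max_length_yet = tf.constant(True)
still_open_beam = tf.constant(False)  # all beams finished and early_stopping is True
improvement_still_possible = tf.constant(True)

print(not_max_length_yet & (still_open_beam | improvement_still_possible))  # old: True, kept looping
print(not_max_length_yet & still_open_beam & improvement_still_possible)    # new: False, stops
```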
def beam_search_body_fn(
cur_len,
@@ -2140,12 +2153,9 @@ class TFGenerationMixin:
# - make sure no scores can be added anymore if beam is full
# - make sure still running sequences cannot be chosen as finalized beam
topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
beams_in_batch_are_full = (
tf.broadcast_to(
tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
)
& early_stopping
)
beams_in_batch_are_full = tf.broadcast_to(
tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
) & (early_stopping is True)
add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
topk_log_probs += tf.cast(add_penalty, tf.float32) * -1.0e9
@@ -2239,7 +2249,7 @@ class TFGenerationMixin:
sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
scores = tf.where(none_finished[:, None], scores, running_scores)
# Take best beams for each batch (the score is sorted in ascending order)
# Take best beams for each batch (the score is sorted in descending order)
sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
scores = flatten_beam_dim(scores[:, :num_return_sequences])


@@ -1190,6 +1190,7 @@ class GenerationMixin:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined
@@ -1458,6 +1459,7 @@ class GenerationMixin:
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1493,6 +1495,7 @@ class GenerationMixin:
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
max_length=generation_config.max_length,
)
# 13. interleave input_ids with `num_beams` additional sequences per batch
@@ -1536,12 +1539,12 @@ class GenerationMixin:
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
max_length=stopping_criteria.max_length,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
num_beam_groups=generation_config.num_beam_groups,
max_length=generation_config.max_length,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1629,6 +1632,7 @@ class GenerationMixin:
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
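End to end, the standardized option surfaces through `generate()` in all three frameworks. A hedged sketch using the tiny test checkpoint that appears in the tests below:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")

inputs = tokenizer("Justin Timberlake and Jessica Biel, welcome to parenthood.", return_tensors="pt")
# `early_stopping` now accepts True, False, or "never"; `max_length` is forwarded to the beam scorer
outputs = model.generate(**inputs, num_beams=4, early_stopping="never", max_length=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```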


@@ -1566,6 +1566,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
return self.beam_search(
input_ids,


@@ -2034,59 +2034,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
**model_kwargs,
)
def test_beam_search_warning_if_max_length_is_passed(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
batch_size = 1
num_beams = 3
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
input_ids = input_ids.expand(num_beams, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
# pretend decoder_input_ids correspond to first encoder input id
decoder_input_ids = input_ids[:, :1]
stopping_criteria_max_length = 18
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
with self.assertWarns(UserWarning):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
max_length=10,
)
generated_ids = bart_model.beam_search(
decoder_input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer,
**model_kwargs,
)
beam_scorer_no_max_len = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
)
generated_ids_no_max_len = bart_model.beam_search(
decoder_input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer_no_max_len,
**model_kwargs,
)
# BeamSearchScorer max_length should not influence "real" max_length
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
def test_custom_stopping_criteria_overload_error(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")


@@ -426,7 +426,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
)
input_ids = tokenizer(input_str, return_tensors="np").input_ids
sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences
sequences = model.generate(input_ids, num_beams=2, min_length=None, max_length=20).sequences
output_str = tokenizer.batch_decode(sequences)[0]


@@ -224,7 +224,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
expected_string = [
"Hello this is a long string of words. I'm going to try to explain what I mean.",
"Hello this is a long string of words. I'm going to start with the first one.\n",
"Hey, I'm not sure if I'm going to be able to do",
]


@@ -1076,7 +1076,7 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
expected_summaries = [
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
" magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
" magazine says . all 150 on board were killed in the crash .",
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
" court, Palestinians may be subject to counter-charges as well .",