Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

🚨🚨 Generate: standardize beam search behavior across frameworks (#21368)

parent ea55bd86b9 · commit f21af26279
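
At a high level, this commit makes `early_stopping` accept the same three values in all frameworks (PyTorch, TensorFlow, Flax): `True`, `False` (heuristic), and `"never"` (canonical beam search). As a quick orientation, a minimal sketch of driving the new flag through `generate()` (the checkpoint name is illustrative; any seq2seq model works):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")

# True    -> stop as soon as `num_beams` complete candidates exist
# False   -> heuristic: stop when better candidates are very unlikely
# "never" -> stop only when no better candidate can exist
outputs = model.generate(**inputs, num_beams=4, early_stopping="never", max_length=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))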
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import List, Optional, Tuple, Union
@@ -130,8 +129,6 @@ class BeamSearchScorer(BeamScorer):
    Args:
        batch_size (`int`):
            Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
-        max_length (`int`):
-            The maximum length of the sequence to be generated.
        num_beams (`int`):
            Number of beams for beam search.
        device (`torch.device`):
@@ -142,14 +139,20 @@ class BeamSearchScorer(BeamScorer):
            the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
            likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
            `length_penalty` < 0.0 encourages shorter sequences.
-        do_early_stopping (`bool`, *optional*, defaults to `False`):
-            Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+            Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
+            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+            beam search algorithm).
        num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
            The number of beam hypotheses that shall be returned upon calling
            [`~transformers.BeamSearchScorer.finalize`].
        num_beam_groups (`int`):
            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
            See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+        max_length (`int`, *optional*):
+            The maximum length of the sequence to be generated.
    """

    def __init__(
@@ -158,10 +161,10 @@ class BeamSearchScorer(BeamScorer):
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
-        do_early_stopping: Optional[bool] = False,
+        do_early_stopping: Optional[Union[bool, str]] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
-        **kwargs,
+        max_length: Optional[int] = None,
    ):
        self.num_beams = num_beams
        self.device = device
@@ -177,6 +180,7 @@ class BeamSearchScorer(BeamScorer):
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
+                max_length=max_length,
            )
            for _ in range(batch_size)
        ]
@@ -194,13 +198,6 @@ class BeamSearchScorer(BeamScorer):
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

-        if "max_length" in kwargs:
-            warnings.warn(
-                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
-                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
-                ", or `group_beam_search(...)`."
-            )
-
    @property
    def is_done(self) -> bool:
        return self._done.all()
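
Taken together, the hunks above change the public surface of `BeamSearchScorer`: `max_length` is now an explicit constructor argument (consumed by `BeamHypotheses`) instead of a deprecated `**kwargs` entry. A minimal construction sketch under the new signature; all values are illustrative:

import torch
from transformers import BeamSearchScorer

scorer = BeamSearchScorer(
    batch_size=2,
    num_beams=4,
    device=torch.device("cpu"),
    length_penalty=1.0,
    do_early_stopping="never",  # now `bool` or the string "never"
    max_length=32,  # required whenever `do_early_stopping` is a string
)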
@@ -402,8 +399,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
    Args:
        batch_size (`int`):
            Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
-        max_length (`int`):
-            The maximum length of the sequence to be generated.
        num_beams (`int`):
            Number of beams for beam search.
        constraints (`List[Constraint]`):
@@ -417,14 +412,20 @@ class ConstrainedBeamSearchScorer(BeamScorer):
            the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
            likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
            `length_penalty` < 0.0 encourages shorter sequences.
-        do_early_stopping (`bool`, *optional*, defaults to `False`):
-            Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+            Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
+            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+            beam search algorithm).
        num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
            The number of beam hypotheses that shall be returned upon calling
            [`~transformers.BeamSearchScorer.finalize`].
        num_beam_groups (`int`):
            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
            See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+        max_length (`int`, *optional*):
+            The maximum length of the sequence to be generated.
    """

    def __init__(
@@ -434,10 +435,10 @@ class ConstrainedBeamSearchScorer(BeamScorer):
        constraints: List[Constraint],
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
-        do_early_stopping: Optional[bool] = False,
+        do_early_stopping: Optional[Union[bool, str]] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
-        **kwargs,
+        max_length: Optional[int] = None,
    ):
        self.num_beams = num_beams
        self.device = device
@@ -454,6 +455,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
+                max_length=max_length,
            )
            for _ in range(batch_size)
        ]
@@ -471,13 +473,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

-        if "max_length" in kwargs:
-            warnings.warn(
-                "Passing `max_length` to ConstrainedBeamSearchScorer is deprecated and has no effect. "
-                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
-                ", or `group_beam_search(...)`."
-            )
-
    @property
    def is_done(self) -> bool:
        return self._done.all()
@@ -865,16 +860,23 @@ class ConstrainedBeamSearchScorer(BeamScorer):


class BeamHypotheses:
-    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
        """
        Initialize n-best list of hypotheses.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
+        self.max_length = max_length
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9
+
+        if not isinstance(self.early_stopping, bool) and self.max_length is None:
+            raise ValueError(
+                "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
+                " BeamScorer class instance at initialization time."
+            )

    def __len__(self):
        """
        Number of hypotheses in the list.

@@ -903,9 +905,26 @@ class BeamHypotheses:

        if len(self) < self.num_beams:
            return False
-        elif self.early_stopping:
+
+        # `True`: stop as soon as at least `num_beams` hypotheses are finished
+        if self.early_stopping is True:
            return True
-        else:
-            cur_score = best_sum_logprobs / cur_len**self.length_penalty
-            ret = self.worst_score >= cur_score
+        # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
+        # when `length_penalty` is positive. See the discussion below for more details.
+        # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+        elif self.early_stopping is False:
+            highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            ret = self.worst_score >= highest_attainable_score
            return ret
+        # `"never"`: compute the best possible score, depending on the sign of `length_penalty`
+        else:
+            # `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min
+            # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
+            # its max this way
+            if self.length_penalty > 0.0:
+                highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+            # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
+            else:
+                highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            ret = self.worst_score >= highest_attainable_score
+            return ret

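The three branches above reduce to a small scalar rule. A condensed restatement of `is_done`, assuming at least `num_beams` hypotheses are already finished (the function name and free-standing form are for illustration only):

def is_done(best_sum_logprobs, worst_score, cur_len, max_length, length_penalty, early_stopping):
    # `True`: stop immediately once enough hypotheses are finished
    if early_stopping is True:
        return True
    # `False`: heuristic, assume the best attainable score comes from `cur_len`
    if early_stopping is False:
        return worst_score >= best_sum_logprobs / cur_len**length_penalty
    # `"never"`: pick the denominator that bounds the attainable score,
    # based on the sign of `length_penalty`
    ref_len = max_length if length_penalty > 0.0 else cur_len
    return worst_score >= best_sum_logprobs / ref_len**length_penalty

Worked example: with `length_penalty=1.0`, `max_length=20`, `cur_len=5`, `best_sum_logprobs=-6.0`, and `worst_score=-0.9`, the `False` heuristic compares -0.9 against -6.0/5 = -1.2 and declares the search done, while `"never"` compares against -6.0/20 = -0.3 and keeps searching. This behavior gap is what the commit title flags as breaking (🚨🚨).
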
@@ -71,8 +71,12 @@ class GenerationConfig(PushToHubMixin):
            `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.
        min_new_tokens (`int`, *optional*):
            The minimum number of tokens to generate, ignoring the number of tokens in the prompt.
-        early_stopping (`bool`, *optional*, defaults to `False`):
-            Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+        early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+            Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
+            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
+            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
+            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
+            beam search algorithm).
        max_time (`float`, *optional*):
            The maximum amount of time you allow the computation to run for in seconds. Generation will still finish
            the current pass after the allocated time has been passed.
@@ -290,6 +294,9 @@ class GenerationConfig(PushToHubMixin):
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

+        # Validate the values of the attributes
+        self.validate()
+
    def __eq__(self, other):
        self_dict = self.__dict__.copy()
        other_dict = other.__dict__.copy()
@@ -302,6 +309,14 @@ class GenerationConfig(PushToHubMixin):
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

+    def validate(self):
+        """
+        Validates the values of the attributes of the GenerationConfig instance, and raises a `ValueError` if any of
+        the values are invalid.
+        """
+        if self.early_stopping not in {True, False, "never"}:
+            raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
+
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],

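A short sketch of what the new `validate()` rejects (the invalid value is deliberate):

from transformers import GenerationConfig

config = GenerationConfig(num_beams=4)
config.early_stopping = "always"  # not one of {True, False, "never"}

try:
    config.validate()
except ValueError as err:
    print(err)  # `early_stopping` must be a boolean or 'never', but is always.
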
@@ -19,7 +19,7 @@ import copy
import inspect
import warnings
from functools import partial
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

import numpy as np
@@ -275,6 +275,7 @@ class FlaxGenerationMixin:

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # set init values
@@ -633,7 +634,7 @@ class FlaxGenerationMixin:
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
-        early_stopping: Optional[bool] = None,
+        early_stopping: Optional[Union[bool, str]] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
@@ -733,14 +734,22 @@ class FlaxGenerationMixin:
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
-            best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
+            # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
+            # below for more details.
+            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
+            # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
+            if early_stopping == "never" and length_penalty > 0.0:
+                best_running_score = state.running_scores[:, :1] / (max_length**length_penalty)
+            else:
+                best_running_score = state.running_scores[:, :1] / (state.cur_len**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
-            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
+            improvement_still_possible = jnp.any(best_running_score > worst_finished_score)

            # 3. is there still a beam that has not finished?
-            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
+            still_open_beam = ~(jnp.all(state.is_sent_finished) & (early_stopping is True))

            return not_max_length_yet & still_open_beam & improvement_still_possible

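The continuation condition is three scalar checks ANDed together. A condensed restatement of the Flax logic under the same variable names (shapes and the surrounding `gather_beams` machinery elided):

import jax.numpy as jnp

def beam_search_should_continue(
    cur_len, max_length, running_scores, scores, is_sent_finished, length_penalty, early_stopping
):
    # 1. the hard length limit has not been reached
    not_max_length_yet = cur_len < max_length

    # 2. the best running beam could still beat the worst finished beam;
    # `running_scores[:, :1]` is the best beam because scores are kept in descending order
    if early_stopping == "never" and length_penalty > 0.0:
        best_running_score = running_scores[:, :1] / (max_length**length_penalty)
    else:
        best_running_score = running_scores[:, :1] / (cur_len**length_penalty)
    worst_finished_score = jnp.where(
        is_sent_finished, jnp.min(scores, axis=1, keepdims=True), -1.0e7
    )
    improvement_still_possible = jnp.any(best_running_score > worst_finished_score)

    # 3. with `early_stopping=True`, stop as soon as every beam has finished
    still_open_beam = ~(jnp.all(is_sent_finished) & (early_stopping is True))

    return not_max_length_yet & still_open_beam & improvement_still_possible

The descending order also explains the two hunks below: once beams are sorted best-first, the `jnp.flip(...)` calls become unnecessary and index `0` (not `-1`) selects the best beam.
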
@@ -813,7 +822,7 @@ class FlaxGenerationMixin:
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
-            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
+            next_topk_indices = lax.top_k(running_topk_log_probs, k=num_beams)[1]
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
            )
@@ -824,10 +833,9 @@ class FlaxGenerationMixin:
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
-            beams_in_batch_are_full = (
-                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
-                & early_stopping
-            )
+            beams_in_batch_are_full = jnp.broadcast_to(
+                state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape
+            ) & (early_stopping is True)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

@@ -838,7 +846,7 @@ class FlaxGenerationMixin:
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
-            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
+            topk_merged_indices = lax.top_k(merged_scores, k=num_beams)[1]
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )
@@ -877,7 +885,7 @@ class FlaxGenerationMixin:
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # take best beam for each batch
-        sequences = sequences[:, -1]
-        scores = scores[:, -1]
+        sequences = sequences[:, 0]
+        scores = scores[:, 0]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)

@@ -611,6 +611,7 @@ class TFGenerationMixin:

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
@@ -1808,7 +1809,7 @@ class TFGenerationMixin:
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
-        early_stopping: Optional[bool] = None,
+        early_stopping: Optional[Union[bool, str]] = None,
        logits_processor: Optional[TFLogitsProcessorList] = None,
        logits_warper: Optional[TFLogitsProcessorList] = None,
        num_return_sequences: Optional[int] = None,
@@ -1838,8 +1839,12 @@ class TFGenerationMixin:
                to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
                the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
                while `length_penalty` < 0.0 encourages shorter sequences.
-            early_stopping (`bool`, *optional*, defaults to `False`):
-                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
+            early_stopping (`bool` or `str`, *optional*, defaults to `False`):
+                Controls the stopping condition for beam-based methods, like beam-search. It accepts the following
+                values: `True`, where the generation stops as soon as there are `num_beams` complete candidates;
+                `False`, where a heuristic is applied and the generation stops when it is very unlikely to find
+                better candidates; `"never"`, where the beam search procedure only stops when there cannot be better
+                candidates (canonical beam search algorithm).
            logits_processor (`[TFLogitsProcessorList]`, *optional*):
                An instance of [`TFLogitsProcessorList`]. List of instances of class derived from [`TFLogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -2009,16 +2014,24 @@ class TFGenerationMixin:
            not_max_length_yet = cur_len < max_length

            # 2. can the new beams still improve?
-            best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            # early_stopping == False -> apply heuristic = always get the best score from `cur_len`. See the discussion
+            # below for more details.
+            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # early_stopping == "never" -> compute the best score from max_length or cur_len, depending on the sign of
+            # length_penalty. Positive length_penalty favors longer sequences, thus we use max_length there.
+            if early_stopping == "never" and length_penalty > 0.0:
+                best_running_score = running_scores[:, :1] / (max_length**length_penalty)
+            else:
+                best_running_score = running_scores[:, :1] / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
            worst_finished_score = tf.where(
                is_sent_finished, tf.math.reduce_min(scores, axis=1, keepdims=True), -1.0e9
            )
-            improvement_still_possible = tf.math.reduce_all(worst_finished_score < best_running_score)
+            improvement_still_possible = tf.math.reduce_any(best_running_score > worst_finished_score)

            # 3. is there still a beam that has not finished?
-            still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & early_stopping)
+            still_open_beam = ~(tf.math.reduce_all(is_sent_finished) & (early_stopping is True))

-            return not_max_length_yet & (still_open_beam | improvement_still_possible)
+            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(
            cur_len,
@@ -2140,12 +2153,9 @@ class TFGenerationMixin:
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
-            beams_in_batch_are_full = (
-                tf.broadcast_to(
-                    tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
-                )
-                & early_stopping
-            )
+            beams_in_batch_are_full = tf.broadcast_to(
+                tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
+            ) & (early_stopping is True)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += tf.cast(add_penalty, tf.float32) * -1.0e9

@@ -2239,7 +2249,7 @@ class TFGenerationMixin:
        sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
        scores = tf.where(none_finished[:, None], scores, running_scores)

-        # Take best beams for each batch (the score is sorted in ascending order)
+        # Take best beams for each batch (the score is sorted in descending order)
        sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
        scores = flatten_beam_dim(scores[:, :num_return_sequences])

@@ -1190,6 +1190,7 @@ class GenerationMixin:

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Set generation parameters if not already defined
@@ -1458,6 +1459,7 @@ class GenerationMixin:
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(

@@ -1493,6 +1495,7 @@ class GenerationMixin:
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
+                max_length=generation_config.max_length,
            )

            # 13. interleave input_ids with `num_beams` additional sequences per batch
@@ -1536,12 +1539,12 @@ class GenerationMixin:
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
-                max_length=stopping_criteria.max_length,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                num_beam_groups=generation_config.num_beam_groups,
+                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -1629,6 +1632,7 @@ class GenerationMixin:
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(

@@ -1566,6 +1566,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                max_length=generation_config.max_length,
            )
            return self.beam_search(
                input_ids,

@@ -2034,59 +2034,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
            **model_kwargs,
        )

-    def test_beam_search_warning_if_max_length_is_passed(self):
-        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
-        bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
-        bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
-            torch_device
-        )
-
-        batch_size = 1
-        num_beams = 3
-
-        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-        input_ids = input_ids.expand(num_beams, -1)
-        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
-
-        # pretend decoder_input_ids correspond to first encoder input id
-        decoder_input_ids = input_ids[:, :1]
-
-        stopping_criteria_max_length = 18
-        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
-
-        with self.assertWarns(UserWarning):
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=num_beams,
-                device=torch_device,
-                max_length=10,
-            )
-
-        generated_ids = bart_model.beam_search(
-            decoder_input_ids,
-            num_beams=num_beams,
-            stopping_criteria=stopping_criteria,
-            beam_scorer=beam_scorer,
-            **model_kwargs,
-        )
-
-        beam_scorer_no_max_len = BeamSearchScorer(
-            batch_size=batch_size,
-            num_beams=num_beams,
-            device=torch_device,
-        )
-
-        generated_ids_no_max_len = bart_model.beam_search(
-            decoder_input_ids,
-            num_beams=num_beams,
-            stopping_criteria=stopping_criteria,
-            beam_scorer=beam_scorer_no_max_len,
-            **model_kwargs,
-        )
-
-        # BeamSearchScorer max_length should not influence "real" max_length
-        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
-
    def test_custom_stopping_criteria_overload_error(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")

@@ -426,7 +426,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
        )

        input_ids = tokenizer(input_str, return_tensors="np").input_ids
-        sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences
+        sequences = model.generate(input_ids, num_beams=2, min_length=None, max_length=20).sequences

        output_str = tokenizer.batch_decode(sequences)[0]

@@ -224,7 +224,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
        output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

        expected_string = [
-            "Hello this is a long string of words. I'm going to try to explain what I mean.",
+            "Hello this is a long string of words. I'm going to start with the first one.\n",
            "Hey, I'm not sure if I'm going to be able to do",
        ]

@@ -1076,7 +1076,7 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
        expected_summaries = [
            'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
            " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
-            " magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
+            " magazine says . all 150 on board were killed in the crash .",
            "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
            " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
            " court, Palestinians may be subject to counter-charges as well .",