Generate: unify LogitsWarper and LogitsProcessor (#32626)

Joao Gante 2024-08-16 11:20:41 +01:00 committed by GitHub
parent 5fd7ca7bc9
commit 70d5df6107
20 changed files with 186 additions and 623 deletions

View File

@ -158,9 +158,6 @@ generation.
[[autodoc]] LogitsProcessorList
- __call__
[[autodoc]] LogitsWarper
- __call__
[[autodoc]] MinLengthLogitsProcessor
- __call__
@ -421,4 +418,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] WatermarkDetector
- __call__

View File

@ -157,9 +157,6 @@ generation_output[:2]
[[autodoc]] LogitsProcessorList
- __call__
[[autodoc]] LogitsWarper
- __call__
[[autodoc]] MinLengthLogitsProcessor
- __call__

View File

@ -151,9 +151,6 @@ generation_output[:2]
[[autodoc]] LogitsProcessorList
- __call__
[[autodoc]] LogitsWarper
- __call__
[[autodoc]] MinLengthLogitsProcessor
- __call__

View File

@ -190,9 +190,9 @@ class GenerationConfig(PushToHubMixin):
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
can allow different forms of each word.
renormalize_logits (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
Whether to renormalize the logits after applying all the logits processors (including the custom
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
are normalized but some logit processors or warpers break the normalization.
are normalized but some logit processors break the normalization.
constraints (`List[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use of
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
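For reference, a minimal usage sketch of the `renormalize_logits` flag documented above (illustrative, not part of this commit's diff); all parameter names come from the public `GenerationConfig` API:

```python
from transformers import GenerationConfig

# Sampling setup where top-p filtering breaks normalization; renormalize_logits=True
# re-applies log-softmax after every logits processor has run.
generation_config = GenerationConfig(
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
    renormalize_logits=True,
)
```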

View File

@ -55,6 +55,12 @@ class LogitsProcessor:
class LogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
def __init__(self):
logger.warning_once(
"`LogitsWarper` is deprecated and will be removed in v4.48. Your class should inherit `LogitsProcessor` "
"instead, which has the same properties and interface."
)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
raise NotImplementedError(
@ -64,9 +70,9 @@ class LogitsWarper:
class LogitsProcessorList(list):
"""
This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
`scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
[`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor.
This class inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] to the
inputs.
"""
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
@ -233,9 +239,9 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
return scores_processed
class TemperatureLogitsWarper(LogitsWarper):
class TemperatureLogitsWarper(LogitsProcessor):
r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
[`LogitsProcessor`] for temperature (exponential scaling output probability distribution), which effectively means
that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
[`TopKLogitsWarper`].
@ -408,10 +414,10 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
return scores_processed
class TopPLogitsWarper(LogitsWarper):
class TopPLogitsWarper(LogitsProcessor):
"""
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often
used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
[`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
Args:
top_p (`float`):
@ -475,10 +481,10 @@ class TopPLogitsWarper(LogitsWarper):
return scores_processed
class TopKLogitsWarper(LogitsWarper):
class TopKLogitsWarper(LogitsProcessor):
r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
[`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. Often used
together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
Args:
top_k (`int`):
@ -528,9 +534,9 @@ class TopKLogitsWarper(LogitsWarper):
return scores_processed
class MinPLogitsWarper(LogitsWarper):
class MinPLogitsWarper(LogitsProcessor):
"""
[`LogitsWarper`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
[`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
probability of the most likely token. As a result, the filter becomes more aggressive in the presence of
high-probability tokens, which is a sign of a confident output that we shouldn't deviate from.
@ -605,11 +611,11 @@ class MinPLogitsWarper(LogitsWarper):
return scores_processed
class TypicalLogitsWarper(LogitsWarper):
class TypicalLogitsWarper(LogitsProcessor):
r"""
[`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose
log probability is close to the entropy of the token probability distribution. This means that the most likely
tokens may be discarded in the process.
[`LogitsProcessor`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens
whose log probability is close to the entropy of the token probability distribution. This means that the most
likely tokens may be discarded in the process.
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
@ -693,9 +699,9 @@ class TypicalLogitsWarper(LogitsWarper):
return scores_processed
class EpsilonLogitsWarper(LogitsWarper):
class EpsilonLogitsWarper(LogitsProcessor):
r"""
[`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
[`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
@ -762,15 +768,15 @@ class EpsilonLogitsWarper(LogitsWarper):
return scores_processed
class EtaLogitsWarper(LogitsWarper):
class EtaLogitsWarper(LogitsProcessor):
r"""
[`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
[`LogitsProcessor`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
must be set to `True` for this `LogitsWarper` to work.
must be set to `True` for this `LogitsProcessor` to work.
Args:
@ -1708,9 +1714,9 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
return scores_processed
class LogitNormalization(LogitsProcessor, LogitsWarper):
class LogitNormalization(LogitsProcessor):
r"""
[`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
[`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
the scores are normalized when comparing the hypotheses.
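Since `LogitsWarper` now only emits a deprecation warning and forwards to the same interface, downstream code that defined a custom warper can simply switch its base class. A minimal migration sketch (the `ScaleLogitsProcessor` class below is hypothetical, not part of this commit):

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class ScaleLogitsProcessor(LogitsProcessor):  # previously: class ScaleLogitsProcessor(LogitsWarper)
    """Divides the scores by a constant factor, equivalent to temperature scaling."""

    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores / self.factor


# Works anywhere a processor list is expected, e.g. `model.generate(..., logits_processor=processors)`.
processors = LogitsProcessorList([ScaleLogitsProcessor(0.7)])
next_token_scores = processors(torch.tensor([[0, 1, 2]]), torch.randn(1, 32))
```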

View File

@ -735,61 +735,6 @@ class GenerationMixin:
)
return candidate_generator
def _get_logits_warper(
self,
generation_config: GenerationConfig,
device: str,
) -> LogitsProcessorList:
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
used for multinomial sampling.
"""
# instantiate warpers list
warpers = LogitsProcessorList()
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config._eos_token_tensor, list):
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.min_p is not None:
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
warpers.append(
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
warpers.append(
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
warpers.append(
EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
warpers.append(LogitNormalization())
return warpers
def _get_logits_processor(
self,
generation_config: GenerationConfig,
@ -960,7 +905,58 @@ class GenerationMixin:
context_width=generation_config.watermarking_config.context_width,
)
)
# TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor)
# Processors previously known as `LogitsWarpers`, only applied with sampling strategies
if generation_config.do_sample:
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config._eos_token_tensor, list):
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
min_tokens_to_keep = 1
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if generation_config.temperature is not None and generation_config.temperature != 1.0:
processors.append(TemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
processors.append(
TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.top_p is not None and generation_config.top_p < 1.0:
processors.append(
TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.min_p is not None:
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
processors.append(
MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
processors.append(
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
)
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
processors.append(
EpsilonLogitsWarper(
epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
)
)
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
processors.append(
EtaLogitsWarper(
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
)
)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
processors.append(LogitNormalization())
@ -1940,22 +1936,11 @@ class GenerationMixin:
model_kwargs=model_kwargs,
)
# 12. prepare logits warper (if `do_sample` is `True`)
prepared_logits_warper = (
self._get_logits_warper(
generation_config,
device=input_ids.device,
)
if generation_config.do_sample
else None
)
# 13. run assisted generate
# 12. run assisted generate
result = self._assisted_decoding(
input_ids,
candidate_generator=candidate_generator,
logits_processor=prepared_logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
@ -1968,16 +1953,10 @@ class GenerationMixin:
raise ValueError(
f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
)
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
result = self._dola_decoding(
input_ids,
dola_layers=generation_config.dola_layers,
logits_processor=prepared_logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
@ -2005,14 +1984,7 @@ class GenerationMixin:
)
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences,
@ -2020,11 +1992,10 @@ class GenerationMixin:
**model_kwargs,
)
# 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
result = self._sample(
input_ids,
logits_processor=prepared_logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
@ -2033,14 +2004,7 @@ class GenerationMixin:
)
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# 12. prepare beam search scorer
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
@ -2051,7 +2015,7 @@ class GenerationMixin:
max_length=generation_config.max_length,
)
# 13. interleave input_ids with `num_beams` additional sequences per batch
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams,
@ -2059,12 +2023,11 @@ class GenerationMixin:
**model_kwargs,
)
# 14. run beam sample
# 13. run beam sample
result = self._beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
@ -2287,7 +2250,6 @@ class GenerationMixin:
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: "BaseStreamer",
logits_warper: Optional[LogitsProcessorList],
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
@ -2316,10 +2278,6 @@ class GenerationMixin:
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@ -2344,11 +2302,6 @@ class GenerationMixin:
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
raise ValueError(
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
f"{logits_warper})."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
@ -2436,8 +2389,7 @@ class GenerationMixin:
)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
if do_sample: # sample
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
@ -2893,7 +2845,6 @@ class GenerationMixin:
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
logits_warper: Optional[LogitsProcessorList],
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
@ -2916,11 +2867,6 @@ class GenerationMixin:
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
`generation_config`)
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@ -2942,11 +2888,6 @@ class GenerationMixin:
max_length = generation_config.max_length
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
raise ValueError(
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
f"{logits_warper})."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
@ -2990,8 +2931,6 @@ class GenerationMixin:
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
if do_sample:
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
@ -3105,7 +3044,6 @@ class GenerationMixin:
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
logits_warper: Optional[LogitsProcessorList],
**model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
r"""
@ -3128,11 +3066,6 @@ class GenerationMixin:
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
`generation_config`)
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@ -3154,11 +3087,6 @@ class GenerationMixin:
return_dict_in_generate = generation_config.return_dict_in_generate
sequential = generation_config.low_memory
do_sample = generation_config.do_sample
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
raise ValueError(
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
f"{logits_warper})."
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
@ -3249,8 +3177,6 @@ class GenerationMixin:
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
if do_sample:
next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
next_token_scores_processed
)
@ -3698,10 +3624,6 @@ class GenerationMixin:
stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
logits_warper (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
@ -3915,7 +3837,6 @@ class GenerationMixin:
input_ids: torch.LongTensor,
candidate_generator: CandidateGenerator,
logits_processor: LogitsProcessorList,
logits_warper: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
@ -3937,10 +3858,6 @@ class GenerationMixin:
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
logits_warper (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step. Only used if sampling is active.
stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
@ -3963,7 +3880,7 @@ class GenerationMixin:
`model.config.is_encoder_decoder=True`.
"""
# init values
do_sample = logits_warper is not None
do_sample = generation_config.do_sample
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@ -4047,9 +3964,6 @@ class GenerationMixin:
if len(logits_processor) > 0:
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
if do_sample and len(logits_warper) > 0:
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
# 3. Select the accepted tokens. There are two possible cases:
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
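To illustrate the merge above: with sampling enabled, the former warpers are appended to the same `LogitsProcessorList` that `_get_logits_processor` returns, after the custom processors and with `LogitNormalization` kept last when `renormalize_logits=True`. A rough, hand-built equivalent (a sketch, not the method's output) for `temperature=0.7`, `top_k=50`, `top_p=0.9` and a single beam, i.e. `min_tokens_to_keep=1`:

```python
import torch
from transformers.generation import (
    LogitNormalization,
    LogitsProcessor,
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

processors = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(top_k=50, min_tokens_to_keep=1),
        TopPLogitsWarper(top_p=0.9, min_tokens_to_keep=1),
        LogitNormalization(),  # always last, only added when renormalize_logits=True
    ]
)

# After this commit the former warpers are plain processors and are applied like any other.
assert all(isinstance(p, LogitsProcessor) for p in processors)
next_token_scores = processors(torch.tensor([[0, 1, 2]]), torch.randn(1, 32))
```

With `num_beams > 1`, the same code bumps `min_tokens_to_keep` to the number of EOS tokens plus one, so beam methods can still explore non-EOS continuations after filtering.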

View File

@ -56,9 +56,9 @@ class BarkSemanticGenerationConfig(GenerationConfig):
eos_token_id (`int`, *optional*, defaults to 10_000):
The id of the *end-of-sequence* token.
renormalize_logits (`bool`, *optional*, defaults to `True`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the
Whether to renormalize the logits after applying all the logits processors (including the
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
score logits are normalized but some logit processors or warpers break the normalization.
score logits are normalized but some logit processors break the normalization.
max_new_tokens (`int`, *optional*, defaults to 768):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
output_scores (`bool`, *optional*, defaults to `False`):
@ -143,9 +143,9 @@ class BarkCoarseGenerationConfig(GenerationConfig):
Args:
renormalize_logits (`bool`, *optional*, defaults to `True`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the
Whether to renormalize the logits after applying all the logits processors (including the
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
score logits are normalized but some logit processors or warpers break the normalization.
score logits are normalized but some logit processors break the normalization.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):

View File

@ -1609,13 +1609,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
)
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
@ -1623,11 +1616,10 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
**model_kwargs,
)
# 12. run sample
# 11. run sample
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
@ -2649,13 +2641,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
)
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
@ -2664,11 +2649,10 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
**model_kwargs,
)
# 12. run sample
# 11. run sample
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,

View File

@ -1531,13 +1531,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
)
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
@ -1545,11 +1538,10 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
**model_kwargs,
)
# 12. run sample
# 11. run sample
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
@ -2490,13 +2482,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
)
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
# expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
@ -2505,11 +2490,10 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
**model_kwargs,
)
# 12. run sample
# 11. run sample
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,

View File

@ -1558,7 +1558,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
generation_config=generation_config,
synced_gpus=False,
streamer=None,
logits_warper=None,
**model_kwargs,
)
elif generation_config.num_beams > 1:
@ -1580,7 +1579,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=False,
logits_warper=None,
**model_kwargs,
)
else:

View File

@ -118,26 +118,24 @@ class GenerationTesterMixin:
return config, input_ids, attention_mask
@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {
"bad_words_ids": [[1, 0]],
"repetition_penalty": 1.2,
"remove_invalid_values": True,
}
# NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
if forced_bos_token_id is None and forced_eos_token_id is None:
process_kwargs["no_repeat_ngram_size"] = 2
if do_sample:
logits_processor_kwargs.update(
{
"top_k": 10,
"top_p": 0.7,
"temperature": 0.7,
}
)
warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
return process_kwargs, warp_kwargs
return logits_processor_kwargs
@staticmethod
def _get_beam_kwargs(num_return_sequences=1):
def _get_beam_kwargs(self, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
@ -146,8 +144,7 @@ class GenerationTesterMixin:
}
return beam_kwargs
@staticmethod
def _get_diverse_beam_kwargs(num_return_sequences=1):
def _get_diverse_beam_kwargs(self, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
@ -158,8 +155,7 @@ class GenerationTesterMixin:
}
return beam_kwargs
@staticmethod
def _get_constrained_beam_kwargs(num_return_sequences=1):
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
@ -199,12 +195,7 @@ class GenerationTesterMixin:
output_hidden_states=False,
return_dict_in_generate=False,
):
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
)
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
@ -216,7 +207,7 @@ class GenerationTesterMixin:
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
**logits_process_kwargs,
**logits_processor_kwargs,
**model_kwargs,
)
@ -228,8 +219,6 @@ class GenerationTesterMixin:
input_ids,
attention_mask,
num_return_sequences,
logits_warper_kwargs,
process_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
@ -237,6 +226,7 @@ class GenerationTesterMixin:
return_dict_in_generate=False,
):
torch.manual_seed(0)
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
@ -249,8 +239,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**logits_warper_kwargs,
**process_kwargs,
**logits_processor_kwargs,
**model_kwargs,
)
@ -262,13 +251,13 @@ class GenerationTesterMixin:
input_ids,
attention_mask,
beam_kwargs,
logits_process_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
@ -280,7 +269,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**beam_kwargs,
**logits_process_kwargs,
**logits_processor_kwargs,
**model_kwargs,
)
@ -292,7 +281,6 @@ class GenerationTesterMixin:
input_ids,
attention_mask,
beam_kwargs,
logits_warper_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
@ -300,6 +288,7 @@ class GenerationTesterMixin:
return_dict_in_generate=False,
):
torch.manual_seed(0)
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
@ -311,7 +300,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**beam_kwargs,
**logits_warper_kwargs,
**logits_processor_kwargs,
**model_kwargs,
)
@ -323,13 +312,13 @@ class GenerationTesterMixin:
input_ids,
attention_mask,
beam_kwargs,
logits_process_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
@ -341,7 +330,7 @@ class GenerationTesterMixin:
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
**beam_kwargs,
**logits_process_kwargs,
**logits_processor_kwargs,
**model_kwargs,
)
@ -354,13 +343,13 @@ class GenerationTesterMixin:
attention_mask,
constraints,
beam_kwargs,
logits_process_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
):
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
@ -373,7 +362,7 @@ class GenerationTesterMixin:
return_dict_in_generate=return_dict_in_generate,
constraints=constraints,
**beam_kwargs,
**logits_process_kwargs,
**logits_processor_kwargs,
**model_kwargs,
)
@ -395,12 +384,7 @@ class GenerationTesterMixin:
"top_k": 5,
}
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
)
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_generate = model.generate(
input_ids,
@ -412,7 +396,7 @@ class GenerationTesterMixin:
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
**logits_process_kwargs,
**logits_processor_kwargs,
**model_kwargs,
**contrastive_search_kwargs,
)
@ -495,19 +479,11 @@ class GenerationTesterMixin:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
)
output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
num_return_sequences=1,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
)
if model.config.is_encoder_decoder:
@ -521,20 +497,11 @@ class GenerationTesterMixin:
config.use_cache = False
model = model_class(config).to(torch_device).eval()
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
forced_bos_token_id=model.config.forced_bos_token_id,
forced_eos_token_id=model.config.forced_eos_token_id,
)
output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
num_return_sequences=2,
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
@ -561,19 +528,12 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
)
if model.config.is_encoder_decoder:
@ -589,18 +549,12 @@ class GenerationTesterMixin:
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
@ -633,12 +587,6 @@ class GenerationTesterMixin:
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
config.use_cache = True
@ -649,7 +597,6 @@ class GenerationTesterMixin:
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
@ -693,17 +640,13 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
model = model_class(config).to(torch_device).eval()
beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
)
if model.config.is_encoder_decoder:
@ -711,7 +654,13 @@ class GenerationTesterMixin:
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
prepare_inputs_for_generation_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
# `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
# code is up to date with our most recent standards
if (
"inputs_embeds" in prepare_inputs_for_generation_args
and "cache_positions" in prepare_inputs_for_generation_args
):
input_embeds = model.get_input_embeddings()(input_ids)
beam_kwargs.update({"inputs_embeds": input_embeds})
output_generate2 = self._beam_sample_generate(
@ -719,7 +668,6 @@ class GenerationTesterMixin:
input_ids=None,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
)
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
@ -732,7 +680,6 @@ class GenerationTesterMixin:
config.use_cache = False
model = model_class(config).to(torch_device).eval()
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
beam_kwargs = self._get_beam_kwargs()
output_generate = self._beam_sample_generate(
@ -740,7 +687,6 @@ class GenerationTesterMixin:
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
@ -788,12 +734,6 @@ class GenerationTesterMixin:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
# check `generate()` and `group_beam_search()` are equal
beam_kwargs = self._get_diverse_beam_kwargs()
output_generate = self._group_beam_search_generate(
@ -801,7 +741,6 @@ class GenerationTesterMixin:
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
@ -816,7 +755,6 @@ class GenerationTesterMixin:
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
@ -829,19 +767,12 @@ class GenerationTesterMixin:
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_diverse_beam_kwargs()
output_generate = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
@ -871,12 +802,6 @@ class GenerationTesterMixin:
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
# Sample constraints
min_id = 3
max_id = config.vocab_size
@ -893,7 +818,6 @@ class GenerationTesterMixin:
attention_mask=attention_mask,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
)
if model.config.is_encoder_decoder:
@ -919,7 +843,6 @@ class GenerationTesterMixin:
attention_mask=attention_mask,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
)
if model.config.is_encoder_decoder:
@ -938,11 +861,6 @@ class GenerationTesterMixin:
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
# Sample constraints
min_id = 3
@ -959,7 +877,6 @@ class GenerationTesterMixin:
attention_mask=attention_mask,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,

View File

@ -414,10 +414,6 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip(reason="The `input_embeds` when fed don't produce the same results.")
def test_beam_sample_generate(self):
pass
@require_torch
class BioGptModelIntegrationTest(unittest.TestCase):

View File

@ -433,6 +433,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@unittest.skip("The `input_embeds` when fed don't produce the same results.")
def test_beam_sample_generate(self):
pass
@require_torch
class MambaIntegrationTests(unittest.TestCase):

View File

@ -283,6 +283,12 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@unittest.skip(
reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)"
)
def test_inputs_embeds_matches_input_ids_with_generate(self):
pass
@require_torch
@slow

View File

@ -293,15 +293,9 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask
@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
@ -1483,15 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return output_generate
@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:

View File

@ -296,15 +296,9 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask
@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
@ -1467,15 +1461,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
return output_generate
@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs
def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:

View File

@ -413,6 +413,10 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_initialization(self):
pass
@unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)")
def test_inputs_embeds_matches_input_ids_with_generate(self):
pass
@require_torch_gpu
@slow

View File

@ -68,14 +68,7 @@ if is_torch_available():
set_seed,
)
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateEncoderDecoderOutput,
PhrasalConstraint,
)
from transformers.generation.logits_process import LogitsProcessor
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
@ -419,6 +412,30 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
return False
def _get_logits_processor_kwargs(self, do_sample=False):
# Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
logits_processor_kwargs["temperature"] = 0.0
return logits_processor_kwargs
def _get_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs
def _get_diverse_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs
def setUp(self):
self.model_tester = WhisperModelTester(self)
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
@ -1551,241 +1568,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_longform_generate_multi_batch_cond_prev(self):
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
def test_beam_sample_generate_dict_output(self):
# We overwrite test_beam_sample_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = WhisperForConditionalGeneration(config).to(torch_device).eval()
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
beam_kwargs = self._get_beam_kwargs()
# With Whisper, we can only perform a beam search if the temperature is set to 0.
logits_warper_kwargs["temperature"] = 0
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"])
def test_beam_search_generate_dict_output(self):
# We overwrite test_beam_search_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
# With Whisper, we can only perform a beam search if the temperature is set to 0.
logits_process_kwargs["temperature"] = 0
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
def test_beam_search_generate_dict_outputs_use_cache(self):
# We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self._check_outputs(
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
)
def test_group_beam_search_generate_dict_output(self):
# We overwrite test_group_beam_search_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_diverse_beam_kwargs()
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
# Sample constraints
min_id = 3
max_id = model.config.vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs = self._get_constrained_beam_kwargs()
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
)
@is_flaky()  # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -70,6 +70,7 @@ OBJECTS_TO_IGNORE = [
# Deprecated
"InputExample",
"InputFeatures",
"LogitsWarper",
# Signature is *args/**kwargs
"TFSequenceSummary",
"TFBertTokenizer",

View File

@ -932,6 +932,7 @@ DEPRECATED_OBJECTS = [
"LineByLineTextDataset",
"LineByLineWithRefDataset",
"LineByLineWithSOPTextDataset",
"LogitsWarper",
"NerPipeline",
"PretrainedBartModel",
"PretrainedFSMTModel",