mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00

Generate: unify LogitsWarper and LogitsProcessor (#32626)

This commit is contained in:
parent 5fd7ca7bc9
commit 70d5df6107
@@ -158,9 +158,6 @@ generation.
 [[autodoc]] LogitsProcessorList
     - __call__
 
-[[autodoc]] LogitsWarper
-    - __call__
-
 [[autodoc]] MinLengthLogitsProcessor
     - __call__

@@ -421,4 +418,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] WatermarkDetector
     - __call__

@@ -157,9 +157,6 @@ generation_output[:2]
 [[autodoc]] LogitsProcessorList
     - __call__
 
-[[autodoc]] LogitsWarper
-    - __call__
-
 [[autodoc]] MinLengthLogitsProcessor
     - __call__

@@ -151,9 +151,6 @@ generation_output[:2]
 [[autodoc]] LogitsProcessorList
     - __call__
 
-[[autodoc]] LogitsWarper
-    - __call__
-
 [[autodoc]] MinLengthLogitsProcessor
     - __call__
@@ -190,9 +190,9 @@ class GenerationConfig(PushToHubMixin):
            triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
            can allow different forms of each word.
        renormalize_logits (`bool`, *optional*, defaults to `False`):
-            Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
+            Whether to renormalize the logits after applying all the logits processors (including the custom
            ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
-            are normalized but some logit processors or warpers break the normalization.
+            are normalized but some logit processors break the normalization.
        constraints (`List[Constraint]`, *optional*):
            Custom constraints that can be added to the generation to ensure that the output will contain the use of
            certain tokens as defined by `Constraint` objects, in the most sensible way possible.
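Illustrative sketch, not part of this commit: the `renormalize_logits` flag documented above is set on `GenerationConfig` (or passed directly to `generate`); when enabled, `LogitNormalization` is appended as the last processor. The values below are examples only.

from transformers import GenerationConfig

# Illustrative only: sampling with top-p filtering, asking for log-softmax renormalization
# after all logits processors have run.
generation_config = GenerationConfig(do_sample=True, top_p=0.9, renormalize_logits=True)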
@@ -55,6 +55,12 @@ class LogitsProcessor:
 class LogitsWarper:
     """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
 
+    def __init__(self):
+        logger.warning_once(
+            "`LogitsWarper` is deprecated and will be removed in v4.48. Your class should inherit `LogitsProcessor` "
+            "instead, which has the same properties and interface."
+        )
+
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         raise NotImplementedError(
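Hedged migration sketch (not from the diff): a custom warper that previously subclassed `LogitsWarper` should now subclass `LogitsProcessor`; subclassing `LogitsWarper` keeps working until v4.48 but triggers the deprecation warning added above. The class name and scaling logic below are illustrative only.

import torch
from transformers import LogitsProcessor


class MyTemperatureWarper(LogitsProcessor):  # previously: class MyTemperatureWarper(LogitsWarper)
    """Illustrative custom processor: divides the scores by a fixed temperature."""

    def __init__(self, temperature: float):
        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Same __call__ signature as before; only the base class changes.
        return scores / self.temperature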
@@ -64,9 +70,9 @@ class LogitsWarper:
 class LogitsProcessorList(list):
     """
-    This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
-    `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
-    [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
+    This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor.
+    This class inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] to the
+    inputs.
     """
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
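A minimal usage sketch (not part of the diff): former warpers such as `TopKLogitsWarper` now behave as regular `LogitsProcessor` entries, so they can be mixed freely inside a `LogitsProcessorList`. Shapes and values are illustrative.

import torch
from transformers import LogitsProcessorList, MinLengthLogitsProcessor, TopKLogitsWarper

processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=5, eos_token_id=0),  # classic processor
        TopKLogitsWarper(top_k=3),  # former "warper", now just another LogitsProcessor
    ]
)

input_ids = torch.tensor([[1, 2, 3]])
scores = torch.randn(1, 10)  # (batch_size, vocab_size), illustrative
filtered_scores = processors(input_ids, scores)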
@@ -233,9 +239,9 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
         return scores_processed
 
 
-class TemperatureLogitsWarper(LogitsWarper):
+class TemperatureLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
+    [`LogitsProcessor`] for temperature (exponential scaling output probability distribution), which effectively means
     that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
     [`TopKLogitsWarper`].

@@ -408,10 +414,10 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
         return scores_processed
 
 
-class TopPLogitsWarper(LogitsWarper):
+class TopPLogitsWarper(LogitsProcessor):
     """
-    [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often
-    used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
+    [`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
+    Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
 
     Args:
         top_p (`float`):

@@ -475,10 +481,10 @@ class TopPLogitsWarper(LogitsWarper):
         return scores_processed
 
 
-class TopKLogitsWarper(LogitsWarper):
+class TopKLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
-    with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
+    [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. Often used
+    together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
 
     Args:
         top_k (`int`):

@@ -528,9 +534,9 @@ class TopKLogitsWarper(LogitsWarper):
         return scores_processed
 
 
-class MinPLogitsWarper(LogitsWarper):
+class MinPLogitsWarper(LogitsProcessor):
     """
-    [`LogitsWarper`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
+    [`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
     probability of the most likely token. As a result, the filter becomes more agressive in the presence of
     high-probability tokens, which is a sign of a confident output that we shouldn't deviate from.

@@ -605,11 +611,11 @@ class MinPLogitsWarper(LogitsWarper):
         return scores_processed
 
 
-class TypicalLogitsWarper(LogitsWarper):
+class TypicalLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose
-    log probability is close to the entropy of the token probability distribution. This means that the most likely
-    tokens may be discarded in the process.
+    [`LogitsProcessor`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens
+    whose log probability is close to the entropy of the token probability distribution. This means that the most
+    likely tokens may be discarded in the process.
 
     See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.

@@ -693,9 +699,9 @@ class TypicalLogitsWarper(LogitsWarper):
         return scores_processed
 
 
-class EpsilonLogitsWarper(LogitsWarper):
+class EpsilonLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
+    [`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
     largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
     Desmoothing](https://arxiv.org/abs/2210.15191) for more information.

@@ -762,15 +768,15 @@ class EpsilonLogitsWarper(LogitsWarper):
         return scores_processed
 
 
-class EtaLogitsWarper(LogitsWarper):
+class EtaLogitsWarper(LogitsProcessor):
     r"""
-    [`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
+    [`LogitsProcessor`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
     cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
     the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
     min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
     samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
     Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
-    must be set to `True` for this `LogitsWarper` to work.
+    must be set to `True` for this `LogitsProcessor` to work.
 
 
     Args:

@@ -1708,9 +1714,9 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
         return scores_processed
 
 
-class LogitNormalization(LogitsProcessor, LogitsWarper):
+class LogitNormalization(LogitsProcessor):
     r"""
-    [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
+    [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
     the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
     this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
     the scores are normalized when comparing the hypotheses.
@@ -735,61 +735,6 @@ class GenerationMixin:
             )
         return candidate_generator
 
-    def _get_logits_warper(
-        self,
-        generation_config: GenerationConfig,
-        device: str,
-    ) -> LogitsProcessorList:
-        """
-        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
-        used for multinomial sampling.
-        """
-
-        # instantiate warpers list
-        warpers = LogitsProcessorList()
-
-        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
-        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
-        if generation_config.num_beams > 1:
-            if isinstance(generation_config._eos_token_tensor, list):
-                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
-            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
-                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
-            else:
-                min_tokens_to_keep = 2
-        else:
-            min_tokens_to_keep = 1
-
-        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
-        # all samplers can be found in `generation_utils_samplers.py`
-        if generation_config.temperature is not None and generation_config.temperature != 1.0:
-            warpers.append(TemperatureLogitsWarper(generation_config.temperature))
-        if generation_config.top_k is not None and generation_config.top_k != 0:
-            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
-        if generation_config.top_p is not None and generation_config.top_p < 1.0:
-            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
-        if generation_config.min_p is not None:
-            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
-            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
-        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
-            warpers.append(
-                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
-            )
-        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
-            warpers.append(
-                EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
-            )
-        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
-            warpers.append(
-                EtaLogitsWarper(
-                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
-                )
-            )
-        # `LogitNormalization` should always be the last logit processor, when present
-        if generation_config.renormalize_logits is True:
-            warpers.append(LogitNormalization())
-        return warpers
-
     def _get_logits_processor(
         self,
         generation_config: GenerationConfig,
@@ -960,7 +905,58 @@ class GenerationMixin:
                     context_width=generation_config.watermarking_config.context_width,
                 )
             )
 
         # TODO (joao): find a strategy to specify the order of the processors
         processors = self._merge_criteria_processor_list(processors, logits_processor)
 
+        # Processors previously known as `LogitsWarpers`, only applied with sampling strategies
+        if generation_config.do_sample:
+            # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+            # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
+            if generation_config.num_beams > 1:
+                if isinstance(generation_config._eos_token_tensor, list):
+                    min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+                elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+                    min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
+                else:
+                    min_tokens_to_keep = 2
+            else:
+                min_tokens_to_keep = 1
+
+            # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+            # all samplers can be found in `generation_utils_samplers.py`
+            if generation_config.temperature is not None and generation_config.temperature != 1.0:
+                processors.append(TemperatureLogitsWarper(generation_config.temperature))
+            if generation_config.top_k is not None and generation_config.top_k != 0:
+                processors.append(
+                    TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
+                )
+            if generation_config.top_p is not None and generation_config.top_p < 1.0:
+                processors.append(
+                    TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
+                )
+            if generation_config.min_p is not None:
+                # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+                processors.append(
+                    MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
+                )
+            if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+                processors.append(
+                    TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+                )
+            if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+                processors.append(
+                    EpsilonLogitsWarper(
+                        epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
+                    )
+                )
+            if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+                processors.append(
+                    EtaLogitsWarper(
+                        epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
+                    )
+                )
+
         # `LogitNormalization` should always be the last logit processor, when present
         if generation_config.renormalize_logits is True:
             processors.append(LogitNormalization())
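Illustrative end-user sketch (not part of the commit): with the warpers folded into `_get_logits_processor`, the sampling parameters are turned into processors only when `do_sample=True`, and a plain `generate` call behaves as before. The checkpoint name below is just an example.

from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" is only an example checkpoint
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# top_k / top_p / temperature become TopKLogitsWarper / TopPLogitsWarper / TemperatureLogitsWarper
# entries inside the single processor list built by `_get_logits_processor`
outputs = model.generate(
    **inputs,
    do_sample=True,
    top_k=50,
    top_p=0.9,
    temperature=0.8,
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))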
@@ -1940,22 +1936,11 @@ class GenerationMixin:
                 model_kwargs=model_kwargs,
             )
 
-            # 12. prepare logits warper (if `do_sample` is `True`)
-            prepared_logits_warper = (
-                self._get_logits_warper(
-                    generation_config,
-                    device=input_ids.device,
-                )
-                if generation_config.do_sample
-                else None
-            )
-
-            # 13. run assisted generate
+            # 12. run assisted generate
             result = self._assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
                 logits_processor=prepared_logits_processor,
-                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,

@@ -1968,16 +1953,10 @@ class GenerationMixin:
                 raise ValueError(
                     f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
                 )
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
             result = self._dola_decoding(
                 input_ids,
                 dola_layers=generation_config.dola_layers,
                 logits_processor=prepared_logits_processor,
-                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,

@@ -2005,14 +1984,7 @@ class GenerationMixin:
             )
 
         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. prepare logits warper
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
-
-            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
                 expand_size=generation_config.num_return_sequences,

@@ -2020,11 +1992,10 @@ class GenerationMixin:
                 **model_kwargs,
             )
 
-            # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
-                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,

@@ -2033,14 +2004,7 @@ class GenerationMixin:
             )
 
         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 11. prepare logits warper
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
-
-            # 12. prepare beam search scorer
+            # 11. prepare beam search scorer
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
                 num_beams=generation_config.num_beams,

@@ -2051,7 +2015,7 @@ class GenerationMixin:
                 max_length=generation_config.max_length,
             )
 
-            # 13. interleave input_ids with `num_beams` additional sequences per batch
+            # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
                 expand_size=generation_config.num_beams,

@@ -2059,12 +2023,11 @@ class GenerationMixin:
                 **model_kwargs,
             )
 
-            # 14. run beam sample
+            # 13. run beam sample
             result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
-                logits_warper=prepared_logits_warper,
                 stopping_criteria=prepared_stopping_criteria,
                 generation_config=generation_config,
                 synced_gpus=synced_gpus,
@@ -2287,7 +2250,6 @@ class GenerationMixin:
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: "BaseStreamer",
-        logits_warper: Optional[LogitsProcessorList],
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""

@@ -2316,10 +2278,6 @@ class GenerationMixin:
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            logits_warper (`LogitsProcessorList`, *optional*):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

@@ -2344,11 +2302,6 @@ class GenerationMixin:
         return_dict_in_generate = generation_config.return_dict_in_generate
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
         do_sample = generation_config.do_sample
-        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
-            raise ValueError(
-                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
-                f"{logits_warper})."
-            )
 
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None

@@ -2436,8 +2389,7 @@ class GenerationMixin:
             )
             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            if do_sample:  # sample
-                next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
                 if output_scores:
@@ -2893,7 +2845,6 @@ class GenerationMixin:
         generation_config: GenerationConfig,
         synced_gpus: bool,
         streamer: Optional["BaseStreamer"],
-        logits_warper: Optional[LogitsProcessorList],
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""

@@ -2916,11 +2867,6 @@ class GenerationMixin:
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
-            logits_warper (`LogitsProcessorList`, *optional*):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
-                `generation_config`)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

@@ -2942,11 +2888,6 @@ class GenerationMixin:
         max_length = generation_config.max_length
         has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
         do_sample = generation_config.do_sample
-        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
-            raise ValueError(
-                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
-                f"{logits_warper})."
-            )
 
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None

@@ -2990,8 +2931,6 @@ class GenerationMixin:
 
             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
-            if do_sample:
-                next_token_scores = logits_warper(input_ids, next_token_scores)
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
@@ -3105,7 +3044,6 @@ class GenerationMixin:
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,
-        logits_warper: Optional[LogitsProcessorList],
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""

@@ -3128,11 +3066,6 @@ class GenerationMixin:
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
-            logits_warper (`LogitsProcessorList`, *optional*):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
-                `generation_config`)
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

@@ -3154,11 +3087,6 @@ class GenerationMixin:
         return_dict_in_generate = generation_config.return_dict_in_generate
         sequential = generation_config.low_memory
         do_sample = generation_config.do_sample
-        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
-            raise ValueError(
-                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
-                f"{logits_warper})."
-            )
 
         batch_size = len(beam_scorer._beam_hyps)
         num_beams = beam_scorer.num_beams

@@ -3249,8 +3177,6 @@ class GenerationMixin:
             )  # (batch_size * num_beams, vocab_size)
 
             next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-            if do_sample:
-                next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
             next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
                 next_token_scores_processed
             )
@@ -3698,10 +3624,6 @@ class GenerationMixin:
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):

@@ -3915,7 +3837,6 @@ class GenerationMixin:
         input_ids: torch.LongTensor,
         candidate_generator: CandidateGenerator,
         logits_processor: LogitsProcessorList,
-        logits_warper: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool,

@@ -3937,10 +3858,6 @@ class GenerationMixin:
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
-            logits_warper (`LogitsProcessorList`):
-                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
-                to warp the prediction score distribution of the language modeling head applied before multinomial
-                sampling at each generation step. Only used if sampling is active.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.

@@ -3963,7 +3880,7 @@ class GenerationMixin:
            `model.config.is_encoder_decoder=True`.
        """
        # init values
-        do_sample = logits_warper is not None
+        do_sample = generation_config.do_sample
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores

@@ -4047,9 +3964,6 @@ class GenerationMixin:
            if len(logits_processor) > 0:
                for i in range(candidate_length + 1):
                    new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
-            if do_sample and len(logits_warper) > 0:
-                for i in range(candidate_length + 1):
-                    new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
 
            # 3. Select the accepted tokens. There are two possible cases:
            # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
@@ -56,9 +56,9 @@ class BarkSemanticGenerationConfig(GenerationConfig):
        eos_token_id (`int`, *optional*, defaults to 10_000):
            The id of the *end-of-sequence* token.
        renormalize_logits (`bool`, *optional*, defaults to `True`):
-            Whether to renormalize the logits after applying all the logits processors or warpers (including the
+            Whether to renormalize the logits after applying all the logits processors (including the
            custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
-            score logits are normalized but some logit processors or warpers break the normalization.
+            score logits are normalized but some logit processors break the normalization.
        max_new_tokens (`int`, *optional*, defaults to 768):
            The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
        output_scores (`bool`, *optional*, defaults to `False`):

@@ -143,9 +143,9 @@ class BarkCoarseGenerationConfig(GenerationConfig):
 
    Args:
        renormalize_logits (`bool`, *optional*, defaults to `True`):
-            Whether to renormalize the logits after applying all the logits processors or warpers (including the
+            Whether to renormalize the logits after applying all the logits processors (including the
            custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
-            score logits are normalized but some logit processors or warpers break the normalization.
+            score logits are normalized but some logit processors break the normalization.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
@@ -1609,13 +1609,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
            )
 
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. prepare logits warper
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
-
            # expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,

@@ -1623,11 +1616,10 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
                **model_kwargs,
            )
 
-            # 12. run sample
+            # 11. run sample
            outputs = self._sample(
                input_ids,
                logits_processor=logits_processor,
-                logits_warper=prepared_logits_warper,
                stopping_criteria=stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,

@@ -2649,13 +2641,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
            )
 
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. prepare logits warper
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
-
            # expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,

@@ -2664,11 +2649,10 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
                **model_kwargs,
            )
 
-            # 12. run sample
+            # 11. run sample
            outputs = self._sample(
                input_ids,
                logits_processor=logits_processor,
-                logits_warper=prepared_logits_warper,
                stopping_criteria=stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,

@@ -1531,13 +1531,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
            )
 
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. prepare logits warper
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
-
            # expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,

@@ -1545,11 +1538,10 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
                **model_kwargs,
            )
 
-            # 12. run sample
+            # 11. run sample
            outputs = self._sample(
                input_ids,
                logits_processor=logits_processor,
-                logits_warper=prepared_logits_warper,
                stopping_criteria=stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,

@@ -2490,13 +2482,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
            )
 
        if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. prepare logits warper
-            prepared_logits_warper = (
-                self._get_logits_warper(generation_config, device=input_ids.device)
-                if generation_config.do_sample
-                else None
-            )
-
            # expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,

@@ -2505,11 +2490,10 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
                **model_kwargs,
            )
 
-            # 12. run sample
+            # 11. run sample
            outputs = self._sample(
                input_ids,
                logits_processor=logits_processor,
-                logits_warper=prepared_logits_warper,
                stopping_criteria=stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
@@ -1558,7 +1558,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
                generation_config=generation_config,
                synced_gpus=False,
                streamer=None,
-                logits_warper=None,
                **model_kwargs,
            )
        elif generation_config.num_beams > 1:

@@ -1580,7 +1579,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=False,
-                logits_warper=None,
                **model_kwargs,
            )
        else:
@@ -118,26 +118,24 @@ class GenerationTesterMixin:
 
         return config, input_ids, attention_mask
 
-    @staticmethod
-    def _get_logits_processor_and_warper_kwargs(
-        input_length,
-        forced_bos_token_id=None,
-        forced_eos_token_id=None,
-    ):
-        process_kwargs = {
+    def _get_logits_processor_kwargs(self, do_sample=False):
+        logits_processor_kwargs = {
             "bad_words_ids": [[1, 0]],
             "repetition_penalty": 1.2,
             "remove_invalid_values": True,
         }
-        # NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
-        if forced_bos_token_id is None and forced_eos_token_id is None:
-            process_kwargs["no_repeat_ngram_size"] = 2
+        if do_sample:
+            logits_processor_kwargs.update(
+                {
+                    "top_k": 10,
+                    "top_p": 0.7,
+                    "temperature": 0.7,
+                }
+            )
 
-        warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
-        return process_kwargs, warp_kwargs
+        return logits_processor_kwargs
 
-    @staticmethod
-    def _get_beam_kwargs(num_return_sequences=1):
+    def _get_beam_kwargs(self, num_return_sequences=1):
         beam_kwargs = {
             "early_stopping": False,
             "length_penalty": 2.0,
@@ -146,8 +144,7 @@ class GenerationTesterMixin:
         }
         return beam_kwargs
 
-    @staticmethod
-    def _get_diverse_beam_kwargs(num_return_sequences=1):
+    def _get_diverse_beam_kwargs(self, num_return_sequences=1):
         beam_kwargs = {
             "early_stopping": False,
             "length_penalty": 2.0,

@@ -158,8 +155,7 @@ class GenerationTesterMixin:
         }
         return beam_kwargs
 
-    @staticmethod
-    def _get_constrained_beam_kwargs(num_return_sequences=1):
+    def _get_constrained_beam_kwargs(self, num_return_sequences=1):
         beam_kwargs = {
             "early_stopping": False,
             "length_penalty": 2.0,
@@ -199,12 +195,7 @@ class GenerationTesterMixin:
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
-        logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-            input_ids.shape[-1],
-            forced_bos_token_id=model.config.forced_bos_token_id,
-            forced_eos_token_id=model.config.forced_eos_token_id,
-        )
-
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,

@@ -216,7 +207,7 @@ class GenerationTesterMixin:
             output_scores=output_scores,
             output_logits=output_logits,
             return_dict_in_generate=return_dict_in_generate,
-            **logits_process_kwargs,
+            **logits_processor_kwargs,
             **model_kwargs,
         )
 

@@ -228,8 +219,6 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         num_return_sequences,
-        logits_warper_kwargs,
-        process_kwargs,
         output_scores=False,
         output_logits=False,
         output_attentions=False,

@@ -237,6 +226,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
     ):
         torch.manual_seed(0)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,

@@ -249,8 +239,7 @@ class GenerationTesterMixin:
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
-            **logits_warper_kwargs,
-            **process_kwargs,
+            **logits_processor_kwargs,
             **model_kwargs,
         )
 
@@ -262,13 +251,13 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         beam_kwargs,
-        logits_process_kwargs,
         output_scores=False,
         output_logits=False,
         output_attentions=False,
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,

@@ -280,7 +269,7 @@ class GenerationTesterMixin:
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
             **beam_kwargs,
-            **logits_process_kwargs,
+            **logits_processor_kwargs,
             **model_kwargs,
         )
 

@@ -292,7 +281,6 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         beam_kwargs,
-        logits_warper_kwargs,
         output_scores=False,
         output_logits=False,
         output_attentions=False,

@@ -300,6 +288,7 @@ class GenerationTesterMixin:
         return_dict_in_generate=False,
     ):
         torch.manual_seed(0)
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,

@@ -311,7 +300,7 @@ class GenerationTesterMixin:
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
             **beam_kwargs,
-            **logits_warper_kwargs,
+            **logits_processor_kwargs,
             **model_kwargs,
         )
 

@@ -323,13 +312,13 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         beam_kwargs,
-        logits_process_kwargs,
         output_scores=False,
         output_logits=False,
         output_attentions=False,
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,

@@ -341,7 +330,7 @@ class GenerationTesterMixin:
             output_hidden_states=output_hidden_states,
             return_dict_in_generate=return_dict_in_generate,
             **beam_kwargs,
-            **logits_process_kwargs,
+            **logits_processor_kwargs,
             **model_kwargs,
         )
 
@@ -354,13 +343,13 @@ class GenerationTesterMixin:
         attention_mask,
         constraints,
         beam_kwargs,
-        logits_process_kwargs,
         output_scores=False,
         output_logits=False,
         output_attentions=False,
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,

@@ -373,7 +362,7 @@ class GenerationTesterMixin:
             return_dict_in_generate=return_dict_in_generate,
             constraints=constraints,
             **beam_kwargs,
-            **logits_process_kwargs,
+            **logits_processor_kwargs,
             **model_kwargs,
         )
 

@@ -395,12 +384,7 @@ class GenerationTesterMixin:
             "top_k": 5,
         }
 
-        logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-            input_ids.shape[-1],
-            forced_bos_token_id=model.config.forced_bos_token_id,
-            forced_eos_token_id=model.config.forced_eos_token_id,
-        )
-
+        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
         model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
         output_generate = model.generate(
             input_ids,

@@ -412,7 +396,7 @@ class GenerationTesterMixin:
             output_scores=output_scores,
             output_logits=output_logits,
             return_dict_in_generate=return_dict_in_generate,
-            **logits_process_kwargs,
+            **logits_processor_kwargs,
             **model_kwargs,
             **contrastive_search_kwargs,
         )
@@ -495,19 +479,11 @@ class GenerationTesterMixin:
             config, input_ids, attention_mask = self._get_input_ids_and_config()
 
             model = model_class(config).to(torch_device).eval()
-            process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                forced_bos_token_id=model.config.forced_bos_token_id,
-                forced_eos_token_id=model.config.forced_eos_token_id,
-            )
-
             output_generate = self._sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 num_return_sequences=1,
-                logits_warper_kwargs=logits_warper_kwargs,
-                process_kwargs=process_kwargs,
             )
 
             if model.config.is_encoder_decoder:

@@ -521,20 +497,11 @@ class GenerationTesterMixin:
 
             config.use_cache = False
             model = model_class(config).to(torch_device).eval()
-
-            process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                forced_bos_token_id=model.config.forced_bos_token_id,
-                forced_eos_token_id=model.config.forced_eos_token_id,
-            )
-
             output_generate = self._sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 num_return_sequences=2,
-                logits_warper_kwargs=logits_warper_kwargs,
-                process_kwargs=process_kwargs,
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,

@@ -561,19 +528,12 @@ class GenerationTesterMixin:
 
             model = model_class(config).to(torch_device).eval()
 
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
             beam_kwargs = self._get_beam_kwargs()
 
             output_generate = self._beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
             )
 
             if model.config.is_encoder_decoder:

@@ -589,18 +549,12 @@ class GenerationTesterMixin:
             config.use_cache = False
 
             model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
             beam_kwargs = self._get_beam_kwargs()
             output_generate = self._beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,

@@ -633,12 +587,6 @@ class GenerationTesterMixin:
                 self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
 
             model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
-
             beam_kwargs = self._get_beam_kwargs()
 
             config.use_cache = True

@@ -649,7 +597,6 @@ class GenerationTesterMixin:
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
@@ -693,17 +640,13 @@ class GenerationTesterMixin:
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask = self._get_input_ids_and_config()
 
-            _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
-
             model = model_class(config).to(torch_device).eval()
             beam_kwargs = self._get_beam_kwargs()
 
             output_generate = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 beam_kwargs=beam_kwargs,
-                logits_warper_kwargs=logits_warper_kwargs,
             )
 
             if model.config.is_encoder_decoder:

@@ -711,7 +654,13 @@ class GenerationTesterMixin:
             else:
                 self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
 
-            if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
+            prepare_inputs_for_generation_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
+            # `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
+            # code is up to date with our most recent standards
+            if (
+                "inputs_embeds" in prepare_inputs_for_generation_args
+                and "cache_positions" in prepare_inputs_for_generation_args
+            ):
                 input_embeds = model.get_input_embeddings()(input_ids)
                 beam_kwargs.update({"inputs_embeds": input_embeds})
                 output_generate2 = self._beam_sample_generate(

@@ -719,7 +668,6 @@ class GenerationTesterMixin:
                     input_ids=None,
                     attention_mask=attention_mask,
                     beam_kwargs=beam_kwargs,
-                    logits_warper_kwargs=logits_warper_kwargs,
                 )
 
                 torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)

@@ -732,7 +680,6 @@ class GenerationTesterMixin:
             config.use_cache = False
 
             model = model_class(config).to(torch_device).eval()
-            _, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
             beam_kwargs = self._get_beam_kwargs()
 
             output_generate = self._beam_sample_generate(

@@ -740,7 +687,6 @@ class GenerationTesterMixin:
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 beam_kwargs=beam_kwargs,
-                logits_warper_kwargs=logits_warper_kwargs,
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
@@ -788,12 +734,6 @@ class GenerationTesterMixin:
             config, input_ids, attention_mask = self._get_input_ids_and_config()
 
             model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
-
             # check `generate()` and `group_beam_search()` are equal
             beam_kwargs = self._get_diverse_beam_kwargs()
             output_generate = self._group_beam_search_generate(

@@ -801,7 +741,6 @@ class GenerationTesterMixin:
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
             )
             if model.config.is_encoder_decoder:
                 self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)

@@ -816,7 +755,6 @@ class GenerationTesterMixin:
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
             )
             if model.config.is_encoder_decoder:
                 self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)

@@ -829,19 +767,12 @@ class GenerationTesterMixin:
             config.use_cache = False
 
             model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
-
             beam_kwargs = self._get_diverse_beam_kwargs()
             output_generate = self._group_beam_search_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,

@@ -871,12 +802,6 @@ class GenerationTesterMixin:
 
             model = model_class(config).to(torch_device).eval()
 
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
-
             # Sample constraints
             min_id = 3
             max_id = config.vocab_size

@@ -893,7 +818,6 @@ class GenerationTesterMixin:
                 attention_mask=attention_mask,
                 constraints=constraints,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
             )
 
             if model.config.is_encoder_decoder:

@@ -919,7 +843,6 @@ class GenerationTesterMixin:
                 attention_mask=attention_mask,
                 constraints=constraints,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
             )
 
             if model.config.is_encoder_decoder:

@@ -938,11 +861,6 @@ class GenerationTesterMixin:
             config.use_cache = False
 
             model = model_class(config).to(torch_device).eval()
-            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                config.forced_bos_token_id,
-                config.forced_eos_token_id,
-            )
 
             # Sample constraints
             min_id = 3

@@ -959,7 +877,6 @@ class GenerationTesterMixin:
                 attention_mask=attention_mask,
                 constraints=constraints,
                 beam_kwargs=beam_kwargs,
-                logits_process_kwargs=logits_process_kwargs,
                 output_scores=True,
                 output_logits=True,
                 output_hidden_states=True,
@@ -414,10 +414,6 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
 
-    @unittest.skip(reason="The `input_embeds` when fed don't produce the same results.")
-    def test_beam_sample_generate(self):
-        pass
-
 
 @require_torch
 class BioGptModelIntegrationTest(unittest.TestCase):

@@ -433,6 +433,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
        dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
        check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
 
+    @unittest.skip("The `input_embeds` when fed don't produce the same results.")
+    def test_beam_sample_generate(self):
+        pass
+
 
 @require_torch
 class MambaIntegrationTests(unittest.TestCase):

@@ -283,6 +283,12 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
        dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
        check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
 
+    @unittest.skip(
+        reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)"
+    )
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        pass
+
 
 @require_torch
 @slow
@ -293,15 +293,9 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask

@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs

def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
@ -1483,15 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,

return output_generate

@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs

def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
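In the Musicgen hunks above, the static `_get_logits_processor_and_warper_kwargs` helper collapses into a single `_get_logits_processor_kwargs(do_sample=...)` hook. The idea behind the merged hook can be sketched in isolation — an assumption-level illustration with made-up parameter values, not the mixin's real implementation:

```python
def get_logits_processor_kwargs(do_sample: bool = False) -> dict:
    """Toy stand-in for the unified test hook: one dict for all logits-shaping kwargs."""
    kwargs = {"repetition_penalty": 1.2}  # applied to greedy and sampling runs alike
    if do_sample:
        # settings that used to live in the separate "warper" dict
        kwargs.update({"temperature": 0.7, "top_k": 50})
    return kwargs


print(get_logits_processor_kwargs(do_sample=False))  # {'repetition_penalty': 1.2}
print(get_logits_processor_kwargs(do_sample=True))   # adds temperature and top_k
```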
@ -296,15 +296,9 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
return config, input_ids, attention_mask

@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs

def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
@ -1467,15 +1461,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester

return output_generate

@staticmethod
def _get_logits_processor_and_warper_kwargs(
input_length,
forced_bos_token_id=None,
forced_eos_token_id=None,
):
process_kwargs = {}
warper_kwargs = {}
return process_kwargs, warper_kwargs
def _get_logits_processor_kwargs(self, do_sample=False):
logits_processor_kwargs = {}
return logits_processor_kwargs

def test_greedy_generate_dict_outputs(self):
for model_class in self.greedy_sample_model_classes:
@ -413,6 +413,10 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
def test_initialization(self):
pass

@unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)")
def test_inputs_embeds_matches_input_ids_with_generate(self):
pass


@require_torch_gpu
@slow
@ -68,14 +68,7 @@ if is_torch_available():
set_seed,
)
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateEncoderDecoderOutput,
PhrasalConstraint,
)
from transformers.generation.logits_process import LogitsProcessor
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
@ -419,6 +412,30 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi

return False

def _get_logits_processor_kwargs(self, do_sample=False):
# Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
logits_processor_kwargs["temperature"] = 0.0
return logits_processor_kwargs

def _get_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs

def _get_diverse_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs

def _get_constrained_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs

def setUp(self):
self.model_tester = WhisperModelTester(self)
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
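The overrides added above encode two Whisper-specific constraints: beam search is only valid with `temperature` set to `0.0`, and all beams are returned only when `num_return_sequences` equals `num_beams`. A hedged end-to-end sketch of the same constraints from the user side (the `openai/whisper-tiny` checkpoint and the silent dummy waveform are placeholders for illustration):

```python
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16_000, dtype=np.float32)  # one second of silence, just to have an input
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

generated = model.generate(
    inputs.input_features,
    num_beams=4,
    num_return_sequences=4,  # matches num_beams, mirroring the test override above
    temperature=0.0,         # beam search with Whisper assumes temperature == 0.0
    max_new_tokens=20,
)
print(processor.batch_decode(generated, skip_special_tokens=True))
```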
@ -1551,241 +1568,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_longform_generate_multi_batch_cond_prev(self):
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)

def test_beam_sample_generate_dict_output(self):
# We overwrite test_beam_sample_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
config, input_ids, attention_mask = self._get_input_ids_and_config()

# disable cache
config.use_cache = False

model = WhisperForConditionalGeneration(config).to(torch_device).eval()
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
beam_kwargs = self._get_beam_kwargs()

# With Whisper, we can only perform a beam search if the temperature is set to 0.
logits_warper_kwargs["temperature"] = 0
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]

output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)

self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"])

def test_beam_search_generate_dict_output(self):
# We overwrite test_beam_search_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()

# disable cache
config.use_cache = False

model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()

# With Whisper, we can only perform a beam search if the temperature is set to 0.
logits_process_kwargs["temperature"] = 0
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]

output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)

def test_beam_search_generate_dict_outputs_use_cache(self):
# We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()

if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")

model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)

beam_kwargs = self._get_beam_kwargs()

# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]

config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)

if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self._check_outputs(
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
)

def test_group_beam_search_generate_dict_output(self):
# We overwrite test_group_beam_search_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False

model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)

beam_kwargs = self._get_diverse_beam_kwargs()

# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]

output_generate = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)

def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()

# disable cache
config.use_cache = False

model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)

# Sample constraints
min_id = 3
max_id = model.config.vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]

beam_kwargs = self._get_constrained_beam_kwargs()
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)

if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
)

@is_flaky()  # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
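The large removal above deletes Whisper's hand-rolled copies of the beam-sample, beam-search, group-beam and constrained-beam tests; with the kwargs hooks added earlier in this diff, the shared `GenerationTesterMixin` versions cover them. For reference, the user-level operation exercised by the deleted constrained test looks roughly like this — a sketch with an illustrative `t5-small` checkpoint and forced phrase, not code from this commit:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.generation import PhrasalConstraint

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Force the phrase "Buch" to appear somewhere in the generated sequence.
force_tokens = tokenizer("Buch", add_special_tokens=False).input_ids
constraints = [PhrasalConstraint(force_tokens)]

inputs = tokenizer("translate English to German: This is my favorite book.", return_tensors="pt")
output = model.generate(**inputs, constraints=constraints, num_beams=4, max_new_tokens=30)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```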
@ -70,6 +70,7 @@ OBJECTS_TO_IGNORE = [
# Deprecated
"InputExample",
"InputFeatures",
"LogitsWarper",
# Signature is *args/**kwargs
"TFSequenceSummary",
"TFBertTokenizer",
@ -932,6 +932,7 @@ DEPRECATED_OBJECTS = [
"LineByLineTextDataset",
"LineByLineWithRefDataset",
"LineByLineWithSOPTextDataset",
"LogitsWarper",
"NerPipeline",
"PretrainedBartModel",
"PretrainedFSMTModel",
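Both check scripts above now list `LogitsWarper` among the deprecated objects. A custom warper can migrate by inheriting `LogitsProcessor` directly, since the `__call__` contract is unchanged — a minimal sketch, with the class name and the fixed temperature value invented for illustration:

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class FixedTemperatureProcessor(LogitsProcessor):
    """Toy replacement for a former LogitsWarper subclass: rescales scores by 1/temperature."""

    def __init__(self, temperature: float = 2.0):
        self.temperature = temperature

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores / self.temperature


# Usage is the same as before, e.g.:
#     model.generate(**inputs, do_sample=True,
#                    logits_processor=LogitsProcessorList([FixedTemperatureProcessor(1.5)]))
```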