Adding new parameter to `generate`: `max_time`. (#9846)
* [WIP] Adding new parameter to `generate`: `max_time`.

  Generation by number of tokens is sometimes a bit clunky because we don't know how many tokens are good enough, or even how many tokens are in the payload (for pipeline users, for instance). This leads to hard-to-understand behavior.

  This PR proposes a new argument, `max_time`, a float giving the number of seconds `generate` is allowed to run. Ideally, combinations such as `max_tokens=None`, `max_time=2` could be used to generate as many tokens as possible within the time budget.

  NB: Another possible approach consists of passing a callback to `generate`, putting the caller in charge of the actual decision of when to stop generating tokens. It opens the door to "which args should we pass" to this callback. It's hard to imagine other use cases for this early-stopping behavior than time (that are not already covered by the parameters of `generate`).

* Revamp with StoppingCriteria

* Removing deprecated mentions.

* Forgot arguments to stopping criteria.

* Re-adding max_length; it's not just used as a stopping criterion.

* Default value for `stopping_criteria`.

* Address @patrickvonplaten comments.
  - More docstrings
  - Actual doc
  - Include in global namespace
  - Remove TF work.

* Put back `max_length` (deprecation in a different PR).

* Doc quality.

* Fixing old behavior without `stopping_criteria` but with `max_length`. Making sure we don't break that in the future.

* Adding more tests for possible inconsistencies between `max_length` and `stopping_criteria`.

* Fixing the torch imports.
parent ea46e3fa9c
commit 543d0549f8
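A quick usage sketch of what this change enables (not part of the diff; the GPT-2 checkpoint, prompt, and budget below are illustrative assumptions): `generate` can now be bounded by wall-clock time in addition to token count.

# Illustrative only: cap decoding at roughly two seconds of wall-clock time.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids
# A generous max_length combined with a small max_time means:
# generate as many tokens as possible within the time budget.
outputs = model.generate(input_ids, do_sample=True, max_length=256, max_time=2.0)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))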
docs/source/internal/generation_utils.rst
@@ -151,6 +151,23 @@ generation.

.. autoclass:: transformers.HammingDiversityLogitsProcessor
    :members: __call__

StoppingCriteria
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

A :class:`~transformers.StoppingCriteria` can be used to change when to stop generation (other than EOS token).

.. autoclass:: transformers.StoppingCriteria
    :members: __call__

.. autoclass:: transformers.StoppingCriteriaList
    :members: __call__

.. autoclass:: transformers.MaxLengthCriteria
    :members: __call__

.. autoclass:: transformers.MaxTimeCriteria
    :members: __call__

BeamSearch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
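As a rough sketch of how the criteria documented above compose (checkpoint, prompt, and values are assumptions, not taken from this PR), a `StoppingCriteriaList` can also be built by hand and handed to one of the search methods; inside `generate`, the same list is assembled automatically from `max_length` and `max_time`:

# Sketch only: pass an explicit StoppingCriteriaList to greedy_search.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids

criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20), MaxTimeCriteria(max_time=0.5)])
# greedy_search stops as soon as any criterion in the list returns True.
outputs = model.greedy_search(
    input_ids,
    stopping_criteria=criteria,
    max_length=20,  # matches the MaxLengthCriteria, so validate_stopping_criteria does not warn
    pad_token_id=model.config.eos_token_id,  # GPT-2 has no pad token; reuse EOS for padding
)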
src/transformers/__init__.py
@@ -380,6 +380,12 @@ if is_torch_available():
        "TopKLogitsWarper",
        "TopPLogitsWarper",
    ]
    _import_structure["generation_stopping_criteria"] = [
        "StoppingCriteria",
        "StoppingCriteriaList",
        "MaxLengthCriteria",
        "MaxTimeCriteria",
    ]
    _import_structure["generation_utils"] = ["top_k_top_p_filtering"]
    _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
    # PyTorch models structure
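This `_import_structure` entry is what re-exports the new classes at the top level of the package, so they can be imported directly (a minimal check, assuming a PyTorch install):

# Assumes torch is installed; these names resolve through the lazy import structure above.
from transformers import MaxLengthCriteria, MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList

print(StoppingCriteriaList([MaxLengthCriteria(max_length=20), MaxTimeCriteria(max_time=2.0)]))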
src/transformers/generation_stopping_criteria.py (new file, 97 lines added)
@@ -0,0 +1,97 @@
import time
import warnings
from abc import ABC
from typing import Optional

import torch

from .file_utils import add_start_docstrings


LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
            or scores for each vocabulary token after SoftMax.
        kwargs:
            Additional stopping criteria specific kwargs.

    Return:
        :obj:`bool`. :obj:`False` indicates we should continue, :obj:`True` indicates we should stop.

"""


class StoppingCriteria(ABC):
    """Abstract base class for all stopping criteria that can be applied during generation."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        raise NotImplementedError("StoppingCriteria needs to be subclassed")


class MaxLengthCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generated number of tokens exceeds :obj:`max_length`.
    Keep in mind that for decoder-only transformers, this will include the initial prompt tokens.

    Args:
        max_length (:obj:`int`):
            The maximum length that the output sequence can have in number of tokens.
    """

    def __init__(self, max_length: int):
        self.max_length = max_length

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids.shape[-1] > self.max_length


class MaxTimeCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generation exceeds some amount of time. By default,
    the time will start being counted when you initialize this class. You can override this by passing an
    :obj:`initial_timestamp`.

    Args:
        max_time (:obj:`float`):
            The maximum allowed time in seconds for the generation.
        initial_timestamp (:obj:`float`, `optional`, defaults to :obj:`time.time()`):
            The start of the allowed generation time.
    """

    def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
        self.max_time = max_time
        self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return time.time() - self.initial_timestamp > self.max_time


class StoppingCriteriaList(list):
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(criteria(input_ids, scores) for criteria in self)


def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
    found = False
    for stopping_criterium in stopping_criteria:
        if isinstance(stopping_criterium, MaxLengthCriteria):
            found = True
            if stopping_criterium.max_length != max_length:
                warnings.warn(
                    "You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
                )
    if not found:
        stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
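Because the whole interface is `__call__(input_ids, scores) -> bool`, criteria beyond the two shipped here can be plugged in without touching `generate`; this is the extensibility the callback discussion in the commit message alludes to. A minimal sketch (the subclass and the token id are hypothetical, not part of this file):

# Hypothetical subclass, shown only to illustrate the StoppingCriteria interface defined above.
import torch

from transformers.generation_stopping_criteria import MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList


class StopOnTokenCriteria(StoppingCriteria):
    """Stop once every sequence in the batch has just produced `stop_token_id`."""

    def __init__(self, stop_token_id: int):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Like the built-in criteria, return a single bool for the whole batch.
        return bool((input_ids[:, -1] == self.stop_token_id).all())


# StoppingCriteriaList ORs its members: generation stops when any criterion fires.
criteria = StoppingCriteriaList(
    [StopOnTokenCriteria(stop_token_id=50256), MaxTimeCriteria(max_time=2.0)]  # 50256: GPT-2's end-of-text id, purely illustrative
)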
src/transformers/generation_utils.py
@@ -37,6 +37,12 @@ from .generation_logits_process import (
    TopKLogitsWarper,
    TopPLogitsWarper,
)
from .generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from .utils import logging


@@ -627,6 +633,19 @@ class GenerationMixin:
        processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
        return processors

    def _get_stopping_criteria(
        self,
        max_length: Optional[int],
        max_time: Optional[float],
    ) -> StoppingCriteriaList:

        stopping_criteria = StoppingCriteriaList()
        if max_length is not None:
            stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
        if max_time is not None:
            stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
        return stopping_criteria

    @torch.no_grad()
    def generate(
        self,
@@ -648,6 +667,7 @@ class GenerationMixin:
        no_repeat_ngram_size: Optional[int] = None,
        encoder_no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        max_time: Optional[float] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        num_beam_groups: Optional[int] = None,
@@ -718,6 +738,9 @@ class GenerationMixin:
                add_prefix_space=True).input_ids`.
            num_return_sequences(:obj:`int`, `optional`, defaults to 1):
                The number of independently computed returned sequences for each element in the batch.
            max_time(:obj:`float`, `optional`, defaults to None):
                The maximum amount of time you allow the computation to run for in seconds. Generation will still
                finish the current pass after the allocated time has passed.
            attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
                tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same
@@ -936,6 +959,11 @@ class GenerationMixin:
            diversity_penalty=diversity_penalty,
        )

        stopping_criteria = self._get_stopping_criteria(
            max_length=max_length,
            max_time=max_time,
        )

        if is_greedy_gen_mode:
            if num_return_sequences > 1:
                raise ValueError(
@@ -946,6 +974,7 @@ class GenerationMixin:
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
@@ -973,6 +1002,7 @@ class GenerationMixin:
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
@@ -1007,6 +1037,7 @@ class GenerationMixin:
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
@@ -1045,6 +1076,7 @@ class GenerationMixin:
                beam_scorer,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                stopping_criteria=stopping_criteria,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
@@ -1083,6 +1115,7 @@ class GenerationMixin:
                input_ids,
                diverse_beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
@@ -1095,6 +1128,7 @@ class GenerationMixin:
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
@@ -1118,6 +1152,9 @@ class GenerationMixin:
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
                An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            pad_token_id (:obj:`int`, `optional`):
@@ -1134,7 +1171,6 @@ class GenerationMixin:
                Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
            return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
                Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.

            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the
                model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
@@ -1177,7 +1213,9 @@ class GenerationMixin:
        """
        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        max_length = max_length if max_length is not None else self.config.max_length
        validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -1267,6 +1305,9 @@ class GenerationMixin:
            if unfinished_sequences.max() == 0:
                break

            if stopping_criteria(input_ids, scores):
                break

            # increase cur_len
            cur_len = cur_len + 1
@@ -1295,6 +1336,7 @@ class GenerationMixin:
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
@@ -1317,6 +1359,9 @@ class GenerationMixin:
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
                An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
            logits_warper (:obj:`LogitsProcessorList`, `optional`):
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
@@ -1387,8 +1432,10 @@ class GenerationMixin:

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        max_length = max_length if max_length is not None else self.config.max_length
        validate_stopping_criteria(stopping_criteria, max_length)
        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -1477,6 +1524,9 @@ class GenerationMixin:
            if unfinished_sequences.max() == 0:
                break

            if stopping_criteria(input_ids, scores):
                break

            # update model kwargs
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
@@ -1508,6 +1558,7 @@ class GenerationMixin:
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
@@ -1533,6 +1584,9 @@ class GenerationMixin:
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
                An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            pad_token_id (:obj:`int`, `optional`):
@@ -1609,10 +1663,11 @@ class GenerationMixin:

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
        """

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        max_length = max_length if max_length is not None else self.config.max_length
        validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -1727,6 +1782,9 @@ class GenerationMixin:
            if beam_scorer.is_done:
                break

            if stopping_criteria(input_ids, scores):
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )
@@ -1761,6 +1819,7 @@ class GenerationMixin:
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
@@ -1787,6 +1846,9 @@ class GenerationMixin:
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
                An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
            logits_warper (:obj:`LogitsProcessorList`, `optional`):
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
@@ -1874,9 +1936,9 @@ class GenerationMixin:

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
        """

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@@ -1990,6 +2052,9 @@ class GenerationMixin:
            if beam_scorer.is_done:
                break

            if stopping_criteria(input_ids, scores):
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )
@@ -2024,6 +2089,7 @@ class GenerationMixin:
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
@@ -2049,6 +2115,9 @@ class GenerationMixin:
                An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
                :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
                head applied at each generation step.
            stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
                An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
            max_length (:obj:`int`, `optional`, defaults to 20):
                The maximum length of the sequence to be generated.
            pad_token_id (:obj:`int`, `optional`):
@@ -2128,10 +2197,11 @@ class GenerationMixin:

        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
        """

        # init values
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        max_length = max_length if max_length is not None else self.config.max_length
        validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2291,6 +2361,9 @@ class GenerationMixin:
            if beam_scorer.is_done:
                break

            if stopping_criteria(input_ids, scores):
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
        )
tests/test_generation_stopping_criteria.py (new file, 79 lines added)
@@ -0,0 +1,79 @@
import time
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device

from .test_modeling_common import ids_tensor


if is_torch_available():
    import torch

    from transformers.generation_stopping_criteria import (
        MaxLengthCriteria,
        MaxTimeCriteria,
        StoppingCriteriaList,
        validate_stopping_criteria,
    )


@require_torch
class StoppingCriteriaTestCase(unittest.TestCase):
    def _get_tensors(self, length):
        batch_size = 3
        vocab_size = 250

        input_ids = ids_tensor((batch_size, length), vocab_size)
        scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
        return input_ids, scores

    def test_list_criteria(self):
        input_ids, scores = self._get_tensors(5)

        criteria = StoppingCriteriaList(
            [
                MaxLengthCriteria(max_length=10),
                MaxTimeCriteria(max_time=0.1),
            ]
        )

        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(10)
        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(11)
        self.assertTrue(criteria(input_ids, scores))

    def test_max_length_criteria(self):
        criteria = MaxLengthCriteria(max_length=10)

        input_ids, scores = self._get_tensors(5)
        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(10)
        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(11)
        self.assertTrue(criteria(input_ids, scores))

    def test_max_time_criteria(self):
        input_ids, scores = self._get_tensors(5)

        criteria = MaxTimeCriteria(max_time=0.1)
        self.assertFalse(criteria(input_ids, scores))

        criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
        self.assertTrue(criteria(input_ids, scores))

    def test_validate_stopping_criteria(self):
        validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)

        with self.assertWarns(UserWarning):
            validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)

        stopping_criteria = StoppingCriteriaList()
        validate_stopping_criteria(stopping_criteria, 11)

        self.assertEqual(len(stopping_criteria), 1)
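The `initial_timestamp` override exercised in `test_max_time_criteria` above is also what lets a caller charge time spent before decoding against the same budget; a rough sketch of that pattern (the request-handling framing is an assumption, not something this PR implements):

# Sketch: start the clock before tokenization so preprocessing counts against the time budget.
import time

from transformers.generation_stopping_criteria import MaxTimeCriteria, StoppingCriteriaList

request_received_at = time.time()  # e.g. when a request arrived, before any preprocessing
# ... tokenize the payload, move tensors to the device, etc. ...
criteria = StoppingCriteriaList([MaxTimeCriteria(max_time=2.0, initial_timestamp=request_received_at)])
# Passing this list to greedy_search/sample/beam_search bounds decoding by the remaining budget.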
tests/test_generation_utils.py
@@ -38,6 +38,7 @@ if is_torch_available():
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
    from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
    from transformers.generation_utils import (
        BeamSampleDecoderOnlyOutput,
        BeamSampleEncoderDecoderOutput,
@@ -1320,3 +1321,189 @@ class GenerationIntegrationTests(unittest.TestCase):
                "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
            ],
        )

    def test_max_length_backward_compat_greedy(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        max_length = 20
        input_ids = input_ids.expand(2, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )

        bart_model.greedy_search(
            input_ids,
            max_length=max_length,
            pad_token_id=bart_model.config.pad_token_id,
            eos_token_id=bart_model.config.eos_token_id,
            **model_kwargs,
        )

    def test_max_length_backward_compat_sample(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        max_length = 20
        input_ids = input_ids.expand(2, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )
        bart_model.sample(
            input_ids,
            max_length=max_length,
            pad_token_id=bart_model.config.pad_token_id,
            eos_token_id=bart_model.config.eos_token_id,
            **model_kwargs,
        )

    def test_max_length_backward_compat_beam_search(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        batch_size = 1
        max_length = 20
        num_beams = 2

        input_ids = input_ids.expand(2, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )

        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=torch_device,
        )
        _ = bart_model.beam_search(
            input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
        )

    def test_max_length_backward_compat_group_beam_search(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        batch_size = 1
        max_length = 20
        num_beams = 6
        num_beam_groups = 3
        num_return_sequences = num_beams * batch_size

        input_ids = input_ids.expand(6, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )

        diverse_beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=torch_device,
            num_beam_hyps_to_keep=num_return_sequences,
            num_beam_groups=num_beam_groups,
        )
        bart_model.group_beam_search(
            input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
        )

    def test_max_length_warning_if_different(self):
        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)

        batch_size = 1

        max_length = 20
        num_beams = 6
        num_beam_groups = 3
        num_return_sequences = num_beams * batch_size
        stopping_criteria_max_length = 18
        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])

        # Greedy
        input_ids = input_ids.expand(6, -1)
        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
        input_ids = bart_model._prepare_decoder_input_ids_for_generation(
            input_ids,
            decoder_start_token_id=bart_model.config.decoder_start_token_id,
            bos_token_id=bart_model.config.bos_token_id,
        )

        with self.assertWarns(UserWarning):
            bart_model.greedy_search(
                input_ids,
                max_length=max_length,
                pad_token_id=bart_model.config.pad_token_id,
                stopping_criteria=stopping_criteria,
                eos_token_id=bart_model.config.eos_token_id,
                **model_kwargs,
            )

        # Sample
        with self.assertWarns(UserWarning):
            bart_model.sample(
                input_ids,
                max_length=max_length,
                stopping_criteria=stopping_criteria,
                pad_token_id=bart_model.config.pad_token_id,
                eos_token_id=bart_model.config.eos_token_id,
                **model_kwargs,
            )

        # Beam
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=torch_device,
        )
        with self.assertWarns(UserWarning):
            bart_model.beam_search(
                input_ids,
                num_beams=num_beams,
                stopping_criteria=stopping_criteria,
                max_length=max_length,
                beam_scorer=beam_scorer,
                **model_kwargs,
            )

        # Grouped beam search
        diverse_beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            device=torch_device,
            num_beam_hyps_to_keep=num_return_sequences,
            num_beam_groups=num_beam_groups,
        )
        with self.assertWarns(UserWarning):
            bart_model.group_beam_search(
                input_ids,
                diverse_beam_scorer,
                stopping_criteria=stopping_criteria,
                num_beams=num_beams,
                max_length=max_length,
                **model_kwargs,
            )
tests/test_modeling_gpt2.py
@@ -14,6 +14,7 @@
# limitations under the License.


import datetime
import unittest

from transformers import is_torch_available
@@ -649,3 +650,44 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
        self.assertTrue(
            all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
        )  # token_type_ids should change output

    @slow
    def test_gpt2_sample_max_time(self):
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        model = GPT2LMHeadModel.from_pretrained("gpt2")
        model.to(torch_device)

        torch.manual_seed(0)
        tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
        input_ids = tokenized.input_ids.to(torch_device)

        MAX_TIME = 0.5

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

        start = datetime.datetime.now()
        model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
        duration = datetime.datetime.now() - start
        self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))