Adding new parameter to generate: max_time. (#9846)

* [WIP] Adding new parameter to `generate`: `max_time`.

Limiting generation by token count is sometimes a bit clunky, because we don't
know how many tokens are good enough, or even how many tokens are in the
payload (for pipeline users, for instance). This leads to hard-to-understand
behavior.

This PR proposes a new argument, `max_time`, a float giving the number of
seconds that `generate` is allowed to run for.
Ideally, a combination such as `max_length=None`, `max_time=2` could be used to
generate as many tokens as possible within the time budget.
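
A rough sketch of the intended usage (the model, prompt, and the 2-second
budget below are only illustrative):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids
# Stop after roughly 2 seconds of wall-clock time. The decoding pass that is
# already running when the budget expires still completes, so the bound is soft.
output = model.generate(input_ids, do_sample=True, max_length=256, max_time=2.0)
print(tokenizer.decode(output[0], skip_special_tokens=True))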

NB: Another possible approach consists of passing a callback to `generate`,
  putting the caller in charge of the actual decision of when to stop
  generating tokens. That opens the question of which arguments should be
  passed to such a callback, and it is hard to imagine early-stopping use cases
  other than time that are not already covered by the existing parameters of
  `generate`.

* Revamp with StoppingCriteria

* Removing deprecated mentions.

* Forgot arguments to stopping criteria.

* Re-adding `max_length`; it's not used only as a stopping criterion.

* Default value for `stopping_criteria`.

* Address @patrickvonplaten comments.

- More docstrings
- Actual doc
- Include in global namespace
- Remove TF work.

* Put back `max_length` (deprecation different PR).

* Doc quality.

* Fixing old behavior without `stopping_criteria` but with `max_length`.

Making sure we don't break that in the future.

* Adding more tests for possible inconsistencies between `max_length` and
  `stopping_criteria`.

* Fixing the torch imports.
Nicolas Patry 2021-03-12 10:11:50 +01:00 committed by GitHub
parent ea46e3fa9c
commit 543d0549f8
7 changed files with 506 additions and 5 deletions

View File

@ -151,6 +151,23 @@ generation.
.. autoclass:: transformers.HammingDiversityLogitsProcessor
:members: __call__
StoppingCriteria
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A :class:`~transformers.StoppingCriteria` can be used to change when to stop generation (other than EOS token).
.. autoclass:: transformers.StoppingCriteria
:members: __call__
.. autoclass:: transformers.StoppingCriteriaList
:members: __call__
.. autoclass:: transformers.MaxLengthCriteria
:members: __call__
.. autoclass:: transformers.MaxTimeCriteria
:members: __call__
BeamSearch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
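
For reference, a minimal sketch of how the new criteria compose (the concrete
length and time budget are illustrative). A `StoppingCriteriaList` stops as
soon as any of its members returns `True`, and it can be passed to
`greedy_search`, `sample`, `beam_search`, etc. via the new `stopping_criteria`
argument:

from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
)

# Stop at 64 tokens or after roughly 1 second, whichever comes first.
criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(max_length=64),
        MaxTimeCriteria(max_time=1.0),
    ]
)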

View File

@ -380,6 +380,12 @@ if is_torch_available():
"TopKLogitsWarper",
"TopPLogitsWarper",
]
_import_structure["generation_stopping_criteria"] = [
"StoppingCriteria",
"StoppingCriteriaList",
"MaxLengthCriteria",
"MaxTimeCriteria",
]
_import_structure["generation_utils"] = ["top_k_top_p_filtering"]
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
# PyTorch models structure

View File

@ -0,0 +1,97 @@
import time
import warnings
from abc import ABC
from typing import Optional
import torch
from .file_utils import add_start_docstrings
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
kwargs:
Additional stopping criteria specific kwargs.
Return:
:obj:`bool`. :obj:`False` indicates we should continue, :obj:`True` indicates we should stop.
"""
class StoppingCriteria(ABC):
"""Abstract base class for all stopping criteria that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
raise NotImplementedError("StoppingCriteria needs to be subclassed")
class MaxLengthCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the number of generated tokens exceeds :obj:`max_length`. Keep
in mind that for decoder-only transformers, this count includes the initial prompt tokens.
Args:
max_length (:obj:`int`):
The maximum length that the output sequence can have in number of tokens.
"""
def __init__(self, max_length: int):
self.max_length = max_length
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return input_ids.shape[-1] > self.max_length
class MaxTimeCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
time will start being counted when this criterion is initialized. You can override this by passing an
:obj:`initial_timestamp`.
Args:
max_time (:obj:`float`):
The maximum allowed time in seconds for the generation.
initial_timestamp (:obj:`float`, `optional`, defaults to :obj:`time.time()`):
The start of the allowed generation time.
"""
def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
self.max_time = max_time
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return time.time() - self.initial_timestamp > self.max_time
class StoppingCriteriaList(list):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
return any(criteria(input_ids, scores) for criteria in self)
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
found = False
for stopping_criterium in stopping_criteria:
if isinstance(stopping_criterium, MaxLengthCriteria):
found = True
if stopping_criterium.max_length != max_length:
warnings.warn(
"You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
)
if not found:
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
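
To illustrate the extension point, a custom criterion could look like the
sketch below. The class is not part of this PR, purely a hypothetical example
following the abstract interface above:

import torch

from transformers.generation_stopping_criteria import StoppingCriteria


class MaxNewTokensCriteria(StoppingCriteria):
    """Hypothetical criterion: stop once `max_new_tokens` tokens have been
    generated on top of a prompt of length `start_length`."""

    def __init__(self, start_length: int, max_new_tokens: int):
        self.start_length = start_length
        self.max_new_tokens = max_new_tokens

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids grows by one token per generation step, so the difference
        # from the prompt length is the number of newly generated tokens.
        return input_ids.shape[-1] - self.start_length >= self.max_new_tokens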

View File

@ -37,6 +37,12 @@ from .generation_logits_process import (
TopKLogitsWarper,
TopPLogitsWarper,
)
from .generation_stopping_criteria import (
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
from .utils import logging
@ -627,6 +633,19 @@ class GenerationMixin:
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
return processors
def _get_stopping_criteria(
self,
max_length: Optional[int],
max_time: Optional[float],
) -> StoppingCriteriaList:
stopping_criteria = StoppingCriteriaList()
if max_length is not None:
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
if max_time is not None:
stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
return stopping_criteria
@torch.no_grad()
def generate(
self,
@ -648,6 +667,7 @@ class GenerationMixin:
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
@ -718,6 +738,9 @@ class GenerationMixin:
add_prefix_space=True).input_ids`.
num_return_sequences(:obj:`int`, `optional`, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
max_time(:obj:`float`, `optional`, defaults to None):
The maximum amount of time you allow the computation to run for, in seconds. Generation will still
finish the current pass after the allocated time has passed.
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for
tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same
@ -936,6 +959,11 @@ class GenerationMixin:
diversity_penalty=diversity_penalty,
)
stopping_criteria = self._get_stopping_criteria(
max_length=max_length,
max_time=max_time,
)
if is_greedy_gen_mode:
if num_return_sequences > 1:
raise ValueError(
@ -946,6 +974,7 @@ class GenerationMixin:
return self.greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@ -973,6 +1002,7 @@ class GenerationMixin:
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@ -1007,6 +1037,7 @@ class GenerationMixin:
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@ -1045,6 +1076,7 @@ class GenerationMixin:
beam_scorer,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@ -1083,6 +1115,7 @@ class GenerationMixin:
input_ids,
diverse_beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
@ -1095,6 +1128,7 @@ class GenerationMixin:
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
@ -1118,6 +1152,9 @@ class GenerationMixin:
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
@ -1134,7 +1171,6 @@ class GenerationMixin:
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the
model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
@ -1177,7 +1213,9 @@ class GenerationMixin:
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
@ -1267,6 +1305,9 @@ class GenerationMixin:
if unfinished_sequences.max() == 0:
break
if stopping_criteria(input_ids, scores):
break
# increase cur_len
cur_len = cur_len + 1
@ -1295,6 +1336,7 @@ class GenerationMixin:
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
@ -1317,6 +1359,9 @@ class GenerationMixin:
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
logits_warper (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
@ -1387,8 +1432,10 @@ class GenerationMixin:
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
@ -1477,6 +1524,9 @@ class GenerationMixin:
if unfinished_sequences.max() == 0:
break
if stopping_criteria(input_ids, scores):
break
# update model kwargs
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
@ -1508,6 +1558,7 @@ class GenerationMixin:
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
@ -1533,6 +1584,9 @@ class GenerationMixin:
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
@ -1609,10 +1663,11 @@ class GenerationMixin:
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
@ -1727,6 +1782,9 @@ class GenerationMixin:
if beam_scorer.is_done:
break
if stopping_criteria(input_ids, scores):
break
sequence_outputs = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
)
@ -1761,6 +1819,7 @@ class GenerationMixin:
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
@ -1787,6 +1846,9 @@ class GenerationMixin:
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
logits_warper (:obj:`LogitsProcessorList`, `optional`):
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
@ -1874,9 +1936,9 @@ class GenerationMixin:
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@ -1990,6 +2052,9 @@ class GenerationMixin:
if beam_scorer.is_done:
break
if stopping_criteria(input_ids, scores):
break
sequence_outputs = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
)
@ -2024,6 +2089,7 @@ class GenerationMixin:
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
@ -2049,6 +2115,9 @@ class GenerationMixin:
An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from
:class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling
head applied at each generation step.
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
max_length (:obj:`int`, `optional`, defaults to 20):
The maximum length of the sequence to be generated.
pad_token_id (:obj:`int`, `optional`):
@ -2128,10 +2197,11 @@ class GenerationMixin:
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
"""
# init values
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
max_length = max_length if max_length is not None else self.config.max_length
validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
@ -2291,6 +2361,9 @@ class GenerationMixin:
if beam_scorer.is_done:
break
if stopping_criteria(input_ids, scores):
break
sequence_outputs = beam_scorer.finalize(
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
)

View File

@ -0,0 +1,79 @@
import time
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_modeling_common import ids_tensor
if is_torch_available():
import torch
from transformers.generation_stopping_criteria import (
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
@require_torch
class StoppingCriteriaTestCase(unittest.TestCase):
def _get_tensors(self, length):
batch_size = 3
vocab_size = 250
input_ids = ids_tensor((batch_size, length), vocab_size)
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
return input_ids, scores
def test_list_criteria(self):
input_ids, scores = self._get_tensors(5)
criteria = StoppingCriteriaList(
[
MaxLengthCriteria(max_length=10),
MaxTimeCriteria(max_time=0.1),
]
)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(11)
self.assertTrue(criteria(input_ids, scores))
def test_max_length_criteria(self):
criteria = MaxLengthCriteria(max_length=10)
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(11)
self.assertTrue(criteria(input_ids, scores))
def test_max_time_criteria(self):
input_ids, scores = self._get_tensors(5)
criteria = MaxTimeCriteria(max_time=0.1)
self.assertFalse(criteria(input_ids, scores))
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
self.assertTrue(criteria(input_ids, scores))
def test_validate_stopping_criteria(self):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
with self.assertWarns(UserWarning):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
stopping_criteria = StoppingCriteriaList()
validate_stopping_criteria(stopping_criteria, 11)
self.assertEqual(len(stopping_criteria), 1)

View File

@ -38,6 +38,7 @@ if is_torch_available():
TopKLogitsWarper,
TopPLogitsWarper,
)
from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
from transformers.generation_utils import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@ -1320,3 +1321,189 @@ class GenerationIntegrationTests(unittest.TestCase):
"Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
],
)
def test_max_length_backward_compat_greedy(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
bart_model.greedy_search(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
def test_max_length_backward_compat_sample(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
bart_model.sample(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
def test_max_length_backward_compat_beam_search(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
max_length = 20
num_beams = 2
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=torch_device,
)
_ = bart_model.beam_search(
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
)
def test_max_length_backward_compat_group_beam_search(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
max_length = 20
num_beams = 6
num_beam_groups = 3
num_return_sequences = num_beams * batch_size
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=torch_device,
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
)
bart_model.group_beam_search(
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
)
def test_max_length_warning_if_different(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
batch_size = 1
max_length = 20
num_beams = 6
num_beam_groups = 3
num_return_sequences = num_beams * batch_size
stopping_criteria_max_length = 18
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
# Greedy
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
with self.assertWarns(UserWarning):
bart_model.greedy_search(
input_ids,
max_length=max_length,
pad_token_id=bart_model.config.pad_token_id,
stopping_criteria=stopping_criteria,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
# Sample
with self.assertWarns(UserWarning):
bart_model.sample(
input_ids,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=bart_model.config.pad_token_id,
eos_token_id=bart_model.config.eos_token_id,
**model_kwargs,
)
# Beam
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=torch_device,
)
with self.assertWarns(UserWarning):
bart_model.beam_search(
input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
max_length=max_length,
beam_scorer=beam_scorer,
**model_kwargs,
)
# Grouped beam search
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=torch_device,
num_beam_hyps_to_keep=num_return_sequences,
num_beam_groups=num_beam_groups,
)
with self.assertWarns(UserWarning):
bart_model.group_beam_search(
input_ids,
diverse_beam_scorer,
stopping_criteria=stopping_criteria,
num_beams=num_beams,
max_length=max_length,
**model_kwargs,
)

View File

@ -14,6 +14,7 @@
# limitations under the License.
import datetime
import unittest
from transformers import is_torch_available
@ -649,3 +650,44 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertTrue(
all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
) # token_type_ids should change output
@slow
def test_gpt2_sample_max_time(self):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(torch_device)
torch.manual_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
input_ids = tokenized.input_ids.to(torch_device)
MAX_TIME = 0.5
start = datetime.datetime.now()
model.generate(input_ids, do_sample=True, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, num_beams=2, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
start = datetime.datetime.now()
model.generate(input_ids, do_sample=True, num_beams=2, max_time=MAX_TIME, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))