transformers/tests/test_generation_stopping_criteria.py
Nicolas Patry 543d0549f8
Adding new parameter to generate: max_time. (#9846)
* [WIP] Adding new parameter to `generate`: `max_time`.

Limiting generation by number of tokens is sometimes clunky: we don't
know how many tokens are good enough, or even how many tokens are in
the payload (for pipeline users, for instance). This leads to behavior
that is hard to understand.

This PR proposes a new argument `max_time`, a float giving the number
of seconds `generate` is allowed to run.
Ideally, a combination such as `max_tokens=None, max_time=2` could be
used to generate as many tokens as possible within the time budget.

NB: Another possible approach is to pass a callback to `generate`,
  putting the caller in charge of deciding when to stop generating
  tokens. That opens the question of which args to pass to this
  callback, and it's hard to imagine early-stopping use cases other
  than time that are not already covered by the parameters of
  `generate`.

* Revamp with StoppingCriteria

* Removing deprecated mentions.

* Forgot arguments to stopping criteria.

* Re-adding `max_length`: it's not just used as a stopping criterion.

* Default value for `stopping_criteria`.

* Address @patrickvonplaten comments.

- More docstrings
- Actual doc
- Include in global namespace
- Remove TF work.

* Put back `max_length` (deprecation different PR).

* Doc quality.

* Fixing old behavior without `stopping_criteria` but with `max_length`.

Making sure we don't break that in the future.

* Adding more tests for possible inconsistencies between
  `max_length` and `stopping_criteria`.

* Fixing the torch imports.
2021-03-12 10:11:50 +01:00
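The message above describes criteria that `generate` consults after each decoding step, composed into a `StoppingCriteriaList`. What follows is a minimal, torch-free sketch of that idea, not the transformers implementation: the `*Sketch` classes and `generate_sketch` are hypothetical names, `input_ids` is a plain list of token lists, the length check assumes `>=` semantics, and the time check assumes strict `>`.

```python
import time
from typing import Callable, List, Optional, Sequence


class MaxLengthSketch:
    """Hypothetical: stop once the sequence length reaches `max_length`."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, input_ids: Sequence[Sequence[int]], scores=None) -> bool:
        # All rows in the batch share one length, so checking the first is enough.
        return len(input_ids[0]) >= self.max_length


class MaxTimeSketch:
    """Hypothetical: stop once more than `max_time` seconds have elapsed."""

    def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
        self.max_time = max_time
        # The clock starts at construction unless a timestamp is supplied.
        self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp

    def __call__(self, input_ids: Sequence[Sequence[int]], scores=None) -> bool:
        return time.time() - self.initial_timestamp > self.max_time


class CriteriaListSketch(list):
    """Hypothetical: a list of criteria that says stop if any member does."""

    def __call__(self, input_ids: Sequence[Sequence[int]], scores=None) -> bool:
        return any(criterion(input_ids, scores) for criterion in self)


def generate_sketch(prompt: List[int], next_token: Callable[[List[int]], int], criteria) -> List[int]:
    """Toy decoding loop: append one token per step until a criterion fires."""
    input_ids = [list(prompt)]  # a batch containing a single sequence
    while not criteria(input_ids):
        input_ids[0].append(next_token(input_ids[0]))
    return input_ids[0]


# Length-bounded: the 2-second budget is never reached, so the length check stops it.
tokens = generate_sketch([0, 1, 2], len, CriteriaListSketch([MaxLengthSketch(10), MaxTimeSketch(2.0)]))
print(tokens)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


def slow_next_token(sequence: List[int]) -> int:
    time.sleep(0.01)  # pretend each decoding step costs about 10 ms
    return 0


# Time-bounded only: keeps decoding until roughly 0.05 s have passed.
budgeted = generate_sketch([0], slow_next_token, CriteriaListSketch([MaxTimeSketch(0.05)]))
```

Passing only a time criterion, in the spirit of the `max_tokens=None, max_time=2` combination proposed in the message, keeps decoding until the budget is spent, so the number of tokens produced depends on how fast each step runs.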


import time
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device

from .test_modeling_common import ids_tensor


if is_torch_available():
    import torch

    from transformers.generation_stopping_criteria import (
        MaxLengthCriteria,
        MaxTimeCriteria,
        StoppingCriteriaList,
        validate_stopping_criteria,
    )


@require_torch
class StoppingCriteriaTestCase(unittest.TestCase):
    def _get_tensors(self, length):
        batch_size = 3
        vocab_size = 250

        # Random token ids plus uniform dummy scores of shape (batch_size, length).
        input_ids = ids_tensor((batch_size, length), vocab_size)
        scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
        return input_ids, scores

    def test_list_criteria(self):
        input_ids, scores = self._get_tensors(5)

        criteria = StoppingCriteriaList(
            [
                MaxLengthCriteria(max_length=10),
                MaxTimeCriteria(max_time=0.1),
            ]
        )

        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(9)
        self.assertFalse(criteria(input_ids, scores))

        # The length criterion fires as soon as the length reaches max_length.
        input_ids, scores = self._get_tensors(10)
        self.assertTrue(criteria(input_ids, scores))

    def test_max_length_criteria(self):
        criteria = MaxLengthCriteria(max_length=10)

        input_ids, scores = self._get_tensors(5)
        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(9)
        self.assertFalse(criteria(input_ids, scores))

        input_ids, scores = self._get_tensors(10)
        self.assertTrue(criteria(input_ids, scores))

    def test_max_time_criteria(self):
        input_ids, scores = self._get_tensors(5)

        criteria = MaxTimeCriteria(max_time=0.1)
        self.assertFalse(criteria(input_ids, scores))

        # Backdate the start time so the 0.1 s budget is already exhausted.
        criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
        self.assertTrue(criteria(input_ids, scores))

    def test_validate_stopping_criteria(self):
        validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)

        with self.assertWarns(UserWarning):
            validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)

        stopping_criteria = StoppingCriteriaList()
        validate_stopping_criteria(stopping_criteria, 11)

        self.assertEqual(len(stopping_criteria), 1)
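The last test pins down three behaviors of `validate_stopping_criteria`: a matching `MaxLengthCriteria` passes silently, a conflicting one emits a `UserWarning`, and an empty list gains one criterion. Below is a plain-Python sketch that satisfies those checks, assuming the list is mutated in place (as the final `len(stopping_criteria) == 1` assertion suggests). `MaxLengthCriteriaSketch` and `validate_stopping_criteria_sketch` are hypothetical stand-ins, not the library source.

```python
import warnings


class MaxLengthCriteriaSketch:
    """Hypothetical stand-in for MaxLengthCriteria; only `max_length` matters here."""

    def __init__(self, max_length: int):
        self.max_length = max_length


def validate_stopping_criteria_sketch(stopping_criteria: list, max_length: int) -> None:
    """Reconcile `stopping_criteria` with a separate `max_length`, mutating the list."""
    found = False
    for criterion in stopping_criteria:
        if isinstance(criterion, MaxLengthCriteriaSketch):
            found = True
            if criterion.max_length != max_length:
                # Two conflicting limits: keep the list as-is but warn the caller.
                warnings.warn(
                    "`max_length` differs from the one set in `stopping_criteria`",
                    UserWarning,
                )
    if not found:
        # No length limit yet: append one so generation stays bounded.
        stopping_criteria.append(MaxLengthCriteriaSketch(max_length))


# Empty list: a length criterion is appended in place.
criteria = []
validate_stopping_criteria_sketch(criteria, 11)
print(len(criteria), criteria[0].max_length)  # 1 11

# Matching limit: no warning. Conflicting limit: one UserWarning.
with warnings.catch_warnings(record=True) as quiet:
    warnings.simplefilter("always")
    validate_stopping_criteria_sketch([MaxLengthCriteriaSketch(10)], 10)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validate_stopping_criteria_sketch([MaxLengthCriteriaSketch(10)], 11)
print(len(quiet), len(caught))  # 0 1
```

Warning instead of raising is a design choice that keeps callers who pass both arguments working; the sketch leaves the existing criterion untouched in that case.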