Refactoring the generate() function (#6949)

* first draft

* show design proposition for new generate method

* up

* make better readable

* make first version

* gpt2 tests pass

* make beam search for gpt2 work

* add first encoder-decoder code

* delete typo

* make t5 work

* save indermediate

* make bart work with beam search

* finish beam search bart / t5

* add default kwargs

* make more tests pass

* fix no bad words sampler

* some fixes and tests for all distribution processors

* fix test

* fix rag slow tests

* merge to master

* add nograd to generate

* make all slow tests pass

* speed up generate

* fix edge case bug

* small fix

* correct typo

* add type hints and docstrings

* fix typos in tests

* add beam search tests

* add tests for beam scorer

* fix test rag

* finish beam search tests

* move generation tests in seperate file

* fix generation tests

* more tests

* add aggressive generation tests

* fix tests

* add gpt2 sample test

* add more docstring

* add more docs

* finish doc strings

* apply some more of sylvains and sams comments

* fix some typos

* make fix copies

* apply lysandres and sylvains comments

* final corrections on examples

* small fix for reformer
This commit is contained in:
Patrick von Platen 2020-11-03 16:04:22 +01:00 committed by GitHub
parent b63beb743c
commit a1bbcf3f6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 3022 additions and 1191 deletions

View File

@ -272,3 +272,4 @@ conversion utilities for the following models:
internal/pipelines_utils
internal/tokenization_utils
internal/trainer_utils
internal/generation_utils

View File

@ -0,0 +1,50 @@
Utilities for Generation
-----------------------------------------------------------------------------------------------------------------------
This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`,
:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`,
:meth:`~transformers.PretrainedModel.beam_search`, and :meth:`~transformers.PretrainedModel.beam_sample`.
Most of those are only useful if you are studying the code of the generate methods in the library.
LogitsProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A :class:`~transformers.LogitsProcessor` can be used to modify the prediction scores of a language model head for
generation.
.. autoclass:: transformers.LogitsProcessor
:members: __call__
.. autoclass:: transformers.LogitsProcessorList
:members: __call__
.. autoclass:: transformers.MinLengthLogitsProcessor
:members: __call__
.. autoclass:: transformers.TemperatureLogitsWarper
:members: __call__
.. autoclass:: transformers.RepetitionPenaltyLogitsProcessor
:members: __call__
.. autoclass:: transformers.TopPLogitsWarper
:members: __call__
.. autoclass:: transformers.TopKLogitsWarper
:members: __call__
.. autoclass:: transformers.NoRepeatNGramLogitsProcessor
:members: __call__
.. autoclass:: transformers.NoBadWordsLogitsProcessor
:members: __call__
BeamSearch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BeamScorer
:members: process, finalize
.. autoclass:: transformers.BeamSearchScorer
:members: process, finalize

View File

@ -45,7 +45,7 @@ TFModelUtilsMixin
:members:
Generative models
Generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.generation_utils.GenerationMixin

View File

@ -299,6 +299,19 @@ if is_torch_available():
TextDataset,
TextDatasetForNextSentencePrediction,
)
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_process import (
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from .generation_utils import top_k_top_p_filtering
from .modeling_albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,

View File

@ -0,0 +1,357 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple
import torch
from .file_utils import add_start_docstrings
PROCESS_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses.
next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
:obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses.
next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
Return:
:obj:`UserDict`: A dictionary composed of the fields as defined above:
- **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated
scores of all non-finished beams.
- **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens
to be added to the non-finished beam_hypotheses.
- **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices
indicating to which beam the next tokens shall be added.
"""
FINALIZE_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The final scores of all non-finished beams.
final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The last tokens to be added to the non-finished beam_hypotheses.
final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added.
pad_token_id (:obj:`int`, `optional`):
The id of the `padding` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
Return:
:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
batches finished early due to the :obj:`eos_token_id`.
"""
class BeamScorer(ABC):
"""
Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
:meth:`~transformers.PretrainedModel.beam_sample`.
"""
@abstractmethod
@add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
**kwargs
) -> Tuple[torch.Tensor]:
raise NotImplementedError("This is an abstract method.")
@abstractmethod
@add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
def finalize(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
**kwargs
) -> torch.LongTensor:
raise NotImplementedError("This is an abstract method.")
class BeamSearchScorer(BeamScorer):
r"""
:class:`transformers.BeamScorer` implementing standard beam search decoding.
Adapted in part from `Facebook's XLM beam search code
<https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
Args:
batch_size (:obj:`int`):
Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
max_length (:obj:`int`):
The maximum length of the sequence to be generated.
num_beams (:obj:`int`):
Number of beams for beam search.
device (:obj:`torch.device`):
Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
:obj:`BeamSearchScorer` will be allocated.
length_penalty (:obj:`float`, `optional`, defaults to 1.0):
Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
sequences.
do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
The number of beam hypotheses that shall be returned upon calling
:meth:`~transformer.BeamSearchScorer.finalize`.
"""
def __init__(
self,
batch_size: int,
max_length: int,
num_beams: int,
device: torch.device,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
):
self.max_length = max_length
self.num_beams = num_beams
self.device = device
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self._is_init = False
self._beam_hyps = [
BeamHypotheses(
num_beams=self.num_beams,
max_length=self.max_length,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
)
for _ in range(batch_size)
]
self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
)
@property
def is_done(self) -> bool:
return self._done.all()
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps)
assert batch_size == (input_ids.shape[0] // self.num_beams)
device = input_ids.device
next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
assert (
len(beam_hyp) >= self.num_beams
), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
assert (
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
# pad the batch
next_beam_scores[batch_idx, :] = 0
next_beam_tokens[batch_idx, :] = pad_token_id
next_beam_indices[batch_idx, :] = 0
continue
# next tokens for this sentence
beam_idx = 0
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
batch_beam_idx = batch_idx * self.num_beams + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
if is_beam_token_worse_than_top_num_beams:
continue
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
)
else:
# add next predicted token since it is not eos_token
next_beam_scores[batch_idx, beam_idx] = next_score
next_beam_tokens[batch_idx, beam_idx] = next_token
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
beam_idx += 1
# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.num_beams:
break
if beam_idx < self.num_beams:
raise ValueError(
f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
)
return UserDict(
{
"next_beam_scores": next_beam_scores.view(-1),
"next_beam_tokens": next_beam_tokens.view(-1),
"next_beam_indices": next_beam_indices.view(-1),
}
)
def finalize(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.LongTensor:
batch_size = len(self._beam_hyps)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
continue
# need to add best num_beams hypotheses to generated hyps
for beam_id in range(self.num_beams):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score)
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
best = []
# retrieve best hypotheses
for i, beam_hyp in enumerate(self._beam_hyps):
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp = sorted_hyps.pop()[1]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
best.append(best_hyp)
# prepare for adding eos
sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)
# fill with hypotheses and eos_token_id if the latter fits in
for i, hypo in enumerate(best):
decoded[i, : sent_lengths[i]] = hypo
if sent_lengths[i] < self.max_length:
decoded[i, sent_lengths[i]] = eos_token_id
return decoded
class BeamHypotheses:
def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(self, hyp: torch.LongTensor, sum_logprobs: float):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp))
if len(self) > self.num_beams:
sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
"""
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret

View File

@ -0,0 +1,374 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Iterable, List
import numpy as np
import torch
from torch.nn import functional as F
from .file_utils import add_start_docstrings
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
or scores for each vocabulary token after SoftMax.
Return:
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
"""
class LogitsProcessor(ABC):
"""Abstract base class for all logit processors that can be applied during generation."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for processing logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class LogitsWarper(ABC):
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
"""Torch method for warping logits."""
raise NotImplementedError(
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
)
class LogitsProcessorList(list):
"""
This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
:class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
:class:`~transformers.LogitsProcessor` to the inputs.
"""
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for processor in self:
scores = processor(input_ids, scores)
return scores
class MinLengthLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
Args:
min_length (:obj:`int`):
The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
scores[:, self.eos_token_id] = -float("inf")
return scores
class TemperatureLogitsWarper(LogitsWarper):
r"""
:class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).
Args:
temperature (:obj:`float`):
The value used to module the logits distribution.
"""
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
self.temperature = temperature
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores = scores / self.temperature
return scores
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.
Args:
repetition_penalty (:obj:`float`):
The parameter for repetition penalty. 1.0 means no penalty. See `this paper
<https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
"""
def __init__(self, penalty: float):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for i in range(scores.shape[0]):
for previous_token in set(input_ids[i].tolist()):
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if scores[i, previous_token] < 0:
scores[i, previous_token] *= self.penalty
else:
scores[i, previous_token] /= self.penalty
return scores
class TopPLogitsWarper(LogitsWarper):
"""
:class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <=
prob_cut_off.
Args:
top_p (:obj:`float`):
If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
kept for generation.
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > self.top_p
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores[indices_to_remove] = self.filter_value
return scores
class TopKLogitsWarper(LogitsWarper):
r"""
:class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.
Args:
top_k (:obj:`int`):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
All filtered values will be set to this float value.
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
Minimum number of tokens that cannot be filtered.
"""
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = top_k
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
scores[indices_to_remove] = self.filter_value
return scores
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
r"""
:class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
<https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.
Args:
ngram_size (:obj:`int`):
All ngrams of size :obj:`ngram_size` can only occur once.
"""
def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
self.ngram_size = ngram_size
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
return scores
def _calc_banned_ngram_tokens(
self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
"""Copied from fairseq for no_repeat_ngram in beam_search"""
if cur_len + 1 < self.ngram_size:
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
return [[] for _ in range(num_hypos)]
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
def _get_generated_ngrams(hypo_idx):
# Before decoding the next token, prevent decoding of ngrams that have already appeared
start_idx = cur_len + 1 - self.ngram_size
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
return generated_ngrams[hypo_idx].get(ngram_idx, [])
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
return banned_tokens
class NoBadWordsLogitsProcessor(LogitsProcessor):
"""
:class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled.
Args:
bad_words_ids (:obj:`List[List[int]]`):
List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
if any(
any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
for bad_word_ids in bad_words_ids
):
raise ValueError(
f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
)
self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
for banned_token_seq in self.bad_words_ids:
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
bad_words_ids
)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
banned_tokens = self._calc_banned_bad_words_ids(input_ids)
scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer then prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
banned_tokens = []
for prev_input_ids_slice in prev_input_ids:
banned_tokens_slice = []
for banned_token_seq in self.bad_words_ids:
if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
banned_tokens_slice.append(banned_token_seq[-1])
banned_tokens.append(banned_tokens_slice)
return banned_tokens
def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
"""
Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
Args:
scores: logits distribution of shape (batch size, vocabulary size)
banned_tokens: list of list of tokens to ban of length (batch_size)
"""
banned_mask_list = []
for idx, batch_banned_tokens in enumerate(banned_tokens):
for token in batch_banned_tokens:
banned_mask_list.append([idx, token])
if not banned_mask_list:
return scores
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]
# [ 1 0 0 ]
banned_mask = (
torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
)
scores = scores.masked_fill(banned_mask, -float("inf"))
return scores

File diff suppressed because it is too large Load Diff

View File

@ -1084,7 +1084,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
)
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed

View File

@ -514,12 +514,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(

View File

@ -431,7 +431,7 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {

View File

@ -1107,7 +1107,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
)
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed

View File

@ -1800,7 +1800,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
return loss
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."

View File

@ -22,6 +22,7 @@ import torch
from .configuration_rag import RagConfig
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from .generation_beam_search import BeamSearchScorer
from .modeling_outputs import ModelOutput
from .modeling_utils import PreTrainedModel
from .retrieval_rag import RagRetriever
@ -825,7 +826,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
num_return_sequences=None, # defaults to 1
num_beams=None, # defaults to 1
n_docs=None,
**kwargs
**model_kwargs
):
"""
Implements RAG sequence "thorough" decoding. Read the :meth:`~transformers.PreTrainedModel.generate``
@ -872,7 +873,6 @@ class RagSequenceForGeneration(RagPreTrainedModel):
)
num_beams = num_beams if num_beams is not None else self.config.num_beams
# TODO(patrick) - clean up generate here
if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
context_input_ids = self.retriever(
@ -887,10 +887,9 @@ class RagSequenceForGeneration(RagPreTrainedModel):
context_input_ids = context_input_ids.to(input_ids)
hypos = []
kwargs["num_beams"] = num_beams
kwargs["num_return_sequences"] = num_beams
kwargs["attention_mask"] = None
kwargs["n_docs"] = n_docs
model_kwargs["num_beams"] = num_beams
model_kwargs["num_return_sequences"] = num_beams
model_kwargs["attention_mask"] = None
for index in range(len(input_ids)):
# first, generate beams from documents:
@ -898,7 +897,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
output_sequences = self.generator.generate(
generator_input_ids,
**kwargs,
**model_kwargs,
) # n_docs * n_beam, tgt_len
if do_deduplication:
# do_deduplication, max_output_len
@ -1018,7 +1017,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, n_docs=None, **kwargs
self,
decoder_input_ids,
past=None,
attention_mask=None,
use_cache=None,
encoder_outputs=None,
doc_scores=None,
n_docs=None,
**kwargs
):
return {
"input_ids": None,
@ -1222,11 +1229,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
repetition_penalty=None,
bad_words_ids=None,
num_return_sequences=None,
decoder_start_token_id=None,
n_docs=None,
**kwargs
**model_kwargs
):
"""
Implements RAG token decoding.
@ -1307,22 +1315,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
"""
# set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
max_length = max_length if max_length is not None else self.config.max_length
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
use_cache = use_cache if use_cache is not None else self.config.use_cache
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
@ -1365,7 +1366,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
encoder = self.rag.generator.get_encoder()
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
decoder_input_ids = torch.full(
input_ids = torch.full(
(batch_size * num_beams, 1),
decoder_start_token_id,
dtype=torch.long,
@ -1388,64 +1389,57 @@ class RagTokenForGeneration(RagPreTrainedModel):
doc_scores = doc_scores.repeat_interleave(num_beams, dim=0)
# define start_len & additional parameters
cur_len = 1
vocab_size = self.config.generator.vocab_size
kwargs["doc_scores"] = doc_scores
kwargs["encoder_outputs"] = encoder_outputs
kwargs["n_docs"] = n_docs
model_kwargs["doc_scores"] = doc_scores
model_kwargs["encoder_outputs"] = encoder_outputs
model_kwargs["attention_mask"] = context_attention_mask
model_kwargs["n_docs"] = n_docs
# not needed. TODO(PVP): change after generate refactor
do_sample = False
temperature = self.config.temperature
top_k = self.config.top_k
top_p = self.config.top_p
repetition_penalty = self.config.repetition_penalty
pre_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
min_length=min_length,
eos_token_id=eos_token_id,
)
if num_beams > 1:
return self._generate_beam_search(
decoder_input_ids,
cur_len=cur_len,
if num_beams == 1:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
)
return self.greedy_search(
input_ids,
pre_processor=pre_processor,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
early_stopping=early_stopping,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
**model_kwargs,
)
elif num_beams > 1:
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_return_sequences=num_return_sequences,
length_penalty=length_penalty,
max_length=max_length,
num_beams=num_beams,
vocab_size=vocab_size,
attention_mask=context_attention_mask,
use_cache=use_cache,
model_kwargs=kwargs,
device=self.device,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
)
return self.beam_search(
input_ids,
beam_scorer,
pre_processor=pre_processor,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
**model_kwargs,
)
else:
return self._generate_no_beam_search(
decoder_input_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=batch_size,
attention_mask=context_attention_mask,
use_cache=use_cache,
model_kwargs=kwargs,
)
raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
def get_input_embeddings(self):
return self.rag.generator.get_input_embeddings()

View File

@ -638,7 +638,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2)
# create a random self.attention_head_size x num_hashes x num_buckets/2
random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype)
# Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2
rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations)
@ -1471,7 +1470,9 @@ class ReformerLayer(nn.Module):
# every forward pass we sample a different seed
# for dropout and save for forward fn in backward pass
# to have correct dropout
self._init_attention_seed()
if self.training:
self._init_attention_seed()
attn_outputs = self.attention(
hidden_states=hidden_states,
head_mask=head_mask,
@ -1494,7 +1495,8 @@ class ReformerLayer(nn.Module):
# every forward pass we sample a different seed
# for dropout and save seed for forward fn in backward
# to have correct dropout
self._init_feed_forward_seed()
if self.training:
self._init_feed_forward_seed()
# Y_2 = X_2 + g(Y_1)
hidden_states = hidden_states + self.feed_forward(attn_output)
@ -2263,7 +2265,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
attentions=reformer_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, num_hashes=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past is not None:
input_ids = input_ids[:, -1:]
@ -2271,12 +2273,10 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
inputs_dict = {
"input_ids": input_ids,
"past_buckets_states": past,
"use_cache": kwargs["use_cache"],
"use_cache": use_cache,
"num_hashes": num_hashes,
}
if "num_hashes" in kwargs:
inputs_dict["num_hashes"] = kwargs["num_hashes"]
return inputs_dict
def _reorder_cache(self, past, beam_idx):

View File

@ -1232,7 +1232,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs):
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
# cut decoder_input_ids if past is used
if past is not None:

View File

@ -1091,7 +1091,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
else:
return self.crit.out_layers[-1]
def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
inputs = {}
# if past is defined in model kwargs then use it for faster decoding

View File

@ -1300,7 +1300,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = input_ids.shape[0]
@ -1333,7 +1333,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_cache": kwargs["use_cache"],
"use_cache": use_cache,
}
# if past is defined in model kwargs then use it for faster decoding

View File

@ -88,8 +88,8 @@ def is_pipeline_test(test_case):
"""
Decorator marking a test as a pipeline test.
Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TEST environment variable to
a truthy value and selecting the is_pipeline_test pytest mark.
Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TESTS environment variable
to a truthy value and selecting the is_pipeline_test pytest mark.
"""
if not _run_pipeline_tests:

View File

@ -104,6 +104,66 @@ class TextDatasetForNextSentencePrediction:
requires_pytorch(self)
class BeamScorer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class BeamSearchScorer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsProcessorList:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class LogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class MinLengthLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class NoBadWordsLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class NoRepeatNGramLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class RepetitionPenaltyLogitsProcessor:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TemperatureLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TopKLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class TopPLogitsWarper:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
def top_k_top_p_filtering(*args, **kwargs):
requires_pytorch(top_k_top_p_filtering)

View File

@ -0,0 +1,239 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer
class BeamSearchTester:
def __init__(
self,
parent,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
# cannot be randomely generated
self.eos_token_id = vocab_size + 1
def prepare_beam_scorer(self, **kwargs):
return BeamSearchScorer(
batch_size=kwargs.get("batch_size", self.batch_size),
max_length=kwargs.get("max_length", self.max_length),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
beam_hyp = beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check for early stopping deactivated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
beam_hyp = beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
# check too many eos tokens
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
# check all batches are done
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
# beam scorer should be done
self.parent.assertTrue(beam_scorer.is_done)
# check
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
# check all outptus
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
)
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
beam_scorer = self.prepare_beam_scorer(
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
)
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
decoded = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
)
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length])
# first batch has to finish with eos_token
self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id)
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
beam_scorer.num_beam_hyps_to_keep = self.num_beams
decoded = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
)
self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length])
@require_torch
class BeamSearchTest(unittest.TestCase):
def setUp(self):
self.beam_search_tester = BeamSearchTester(self)
def test_beam_hypotheses(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_hypotheses(*inputs)
def test_beam_scorer_update(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scorer_update(*inputs)
def test_beam_scorer_finalize(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scores_finalize(*inputs)

View File

@ -0,0 +1,283 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_modeling_common import ids_tensor
if is_torch_available():
import torch
import torch.nn.functional as F
from transformers.generation_logits_process import (
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
@require_torch
class LogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
return scores
def test_min_lenght_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
# check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
# check that min length is not applied anymore at length 15
input_ids = ids_tensor((batch_size, 15), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores_before_min_length).any())
def test_temperature_dist_warper(self):
input_ids = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
# compute softmax
probs = F.softmax(scores, dim=-1)
temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = F.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
warped_prob_smooth = F.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
# uniform distribution stays uniform
self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
self.assertTrue(torch.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
# sharp peaks get higher, valleys get lower
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
# smooth peaks get lower, valleys get higher
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
def test_repetition_penalty_dist_process(self):
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
vocab_size = 10
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
# give values special values
scores[0, 0] = -(1 / vocab_size)
scores[1, 5] = 4 / vocab_size
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
scores = rep_penalty_proc(input_ids, scores.clone())
# check that values were correctly changed
self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2)
self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create ramp distribution
ramp_logits = (
torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
)
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
top_k_warp = TopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits)
# check that correct tokens are filtered
self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
self.assertListEqual(torch.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
# check special cases
length = 5
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
scores = top_k_warp_safety_check(input_ids, logits)
# uniform dist is not changed
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
ramp_logits = torch.arange(length, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
scores = top_k_warp_safety_check(input_ids, ramp_logits)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
def test_top_p_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = torch.log(
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
)
top_p_warp = TopPLogitsWarper(0.7)
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
)
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check edge cases with negative and extreme logits
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
batch_size, 1
) - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
input_ids = torch.tensor([[1, 1, 2, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
scores = self._get_uniform_logits(batch_size, vocab_size)
no_repeat_proc_2_gram = NoRepeatNGramLogitsProcessor(2)
no_repeat_proc_3_gram = NoRepeatNGramLogitsProcessor(3)
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(
torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
)
def test_no_bad_words_dist_processor(self):
vocab_size = 5
batch_size = 2
eos_token_id = 4
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
scores = self._get_uniform_logits(batch_size, vocab_size)
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
# batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
# batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
# Note that 5th element cannot be forbidden as it is EOS token
self.assertListEqual(
torch.isinf(filtered_scores).tolist(), [[True, True, False, True, False], [True, True, True, False, False]]
)
# check edge case
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[4]], eos_token_id=eos_token_id)
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
def test_processor_list(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 0
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids_comp = input_ids.clone()
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_comp = scores.clone()
# instantiate all dist processors
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TopKLogitsWarper(3)
top_p_warp = TopPLogitsWarper(0.8)
no_repeat_proc = NoRepeatNGramLogitsProcessor(2)
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
# no processor list
scores = min_dist_proc(input_ids, scores)
scores = temp_dist_warp(input_ids, scores)
scores = rep_penalty_proc(input_ids, scores)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
scores = no_repeat_proc(input_ids, scores)
scores = no_bad_words_dist_proc(input_ids, scores)
# with processor list
processor = LogitsProcessorList(
[
min_dist_proc,
temp_dist_warp,
rep_penalty_proc,
top_k_warp,
top_p_warp,
no_repeat_proc,
no_bad_words_dist_proc,
]
)
scores_comp = processor(input_ids, scores_comp)
# scores should be equal
self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3))
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())

View File

@ -0,0 +1,510 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
if is_torch_available():
import torch
from transformers import top_k_top_p_filtering
from transformers.generation_beam_search import BeamSearchScorer
from transformers.generation_logits_process import (
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
class GenerationTesterMixin:
model_tester = None
all_generative_model_classes = ()
def _get_input_ids_and_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
attention_mask = torch.ones_like(input_ids)
# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
attention_mask = attention_mask[:max_batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length
@staticmethod
def _get_logits_processor_and_kwargs(input_length, eos_token_id):
process_kwargs = {
"min_length": input_length + 1,
"bad_words_ids": [[1, 0]],
"no_repeat_ngram_size": 2,
"repetition_penalty": 1.2,
}
logits_processor = LogitsProcessorList(
(
[
MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
]
if eos_token_id is not None
else []
)
+ [
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
]
)
return process_kwargs, logits_processor
@staticmethod
def _get_warper_and_kwargs(num_beams):
warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
logits_warper = LogitsProcessorList(
[
TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
TemperatureLogitsWarper(warp_kwargs["temperature"]),
]
)
return warp_kwargs, logits_warper
@staticmethod
def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": 2,
"num_return_sequences": num_return_sequences,
}
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=beam_kwargs["num_beams"],
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
do_early_stopping=beam_kwargs["early_stopping"],
num_beam_hyps_to_keep=num_return_sequences,
)
return beam_kwargs, beam_scorer
@staticmethod
def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1):
encoder = model.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
attention_mask = None
return encoder_outputs, input_ids, attention_mask
def test_greedy_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
)
model = model_class(config).to(torch_device)
model.eval()
# check `generate()` and `greedy_search()` are equal
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
max_length = 4
output_ids_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
num_beams=1,
max_length=max_length,
**logits_process_kwargs,
)
with torch.no_grad():
output_ids_greedy = model.greedy_search(
input_ids,
max_length=max_length,
attention_mask=attention_mask,
logits_processor=logits_processor,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_greedy.tolist())
def test_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
model = model_class(config).to(torch_device)
model.eval()
# check `generate()` and `sample()` are equal
if model.config.is_encoder_decoder:
max_length = 4
torch.manual_seed(0)
output_ids_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_length=max_length,
attention_mask=attention_mask,
**logits_warper_kwargs,
**process_kwargs,
)
torch.manual_seed(0)
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
else:
attention_mask_clone = attention_mask
input_ids_clone = input_ids
with torch.no_grad():
output_ids_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist())
# check `generate()` and `sample()` yield equal results for `num_return_sequences`
num_return_sequences = 3
if model.config.is_encoder_decoder:
max_length = 4
torch.manual_seed(0)
output_ids_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_length=max_length,
num_return_sequences=num_return_sequences,
attention_mask=attention_mask,
**logits_warper_kwargs,
**process_kwargs,
)
torch.manual_seed(0)
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model, input_ids, attention_mask, num_interleave=num_return_sequences
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
with torch.no_grad():
output_ids_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist())
def test_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
)
model = model_class(config).to(torch_device)
model.eval()
# check `generate()` and `beam_search()` are equal
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
output_ids_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
**beam_kwargs,
**logits_process_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
with torch.no_grad():
output_ids_beam_search = model.beam_search(
input_ids_clone,
beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
# check `generate()` and `beam_search()` are equal for `num_return_sequences`
num_return_sequences = 2
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
)
output_ids_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
**beam_kwargs,
**logits_process_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
with torch.no_grad():
output_ids_beam_search = model.beam_search(
input_ids_clone,
beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
def test_beam_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
model = model_class(config).to(torch_device)
model.eval()
# check `generate()` and `beam_search()` are equal
# change `num_return_sequences = 2` but not for `beam_scorer`
num_return_sequences = 2
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
input_ids.shape[0] * num_return_sequences, max_length
)
beam_kwargs["num_return_sequences"] = num_return_sequences
torch.manual_seed(0)
output_ids_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=True,
max_length=max_length,
**beam_kwargs,
**logits_warper_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams * num_return_sequences
)
kwargs["encoder_outputs"] = encoder_outputs
else:
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
torch.manual_seed(0)
with torch.no_grad():
output_ids_beam_sample = model.beam_sample(
input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
beam_scorer,
max_length=max_length,
attention_mask=attention_mask,
logits_warper=logits_warper,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_sample.tolist())
def test_generate_without_input_ids(self):
config, _, _, max_length = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
return
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(
do_sample=False,
max_length=max_length,
)
self.assertIsNotNone(output_ids_generate)
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
# tests whether the top_k_top_p function behaves as expected
def test_top_k_top_p_filtering(self):
logits = torch.tensor(
[
[
8.2220991, # 3rd highest value; idx. 0
-0.5620044,
5.23229752,
4.0386393,
-6.8798378,
-0.54785802,
-3.2012153,
2.92777176,
1.88171953,
7.35341276,
8.43207833, # 2nd highest value; idx. 10
-9.85711836,
-5.96209236,
-1.13039161,
-7.1115294,
-0.8369633,
-5.3186408,
7.06427407,
0.81369344,
-0.82023817,
-5.9179796,
0.58813443,
-6.99778438,
4.71551189,
-0.18771637,
7.44020759, # 4th highest value; idx. 25
9.38450987, # 1st highest value; idx. 26
2.12662941,
-9.32562038,
2.35652522,
], # cummulative prob of 4 highest values <= 0.6
[
0.58425518,
4.53139238,
-5.57510464,
-6.28030699,
-7.19529503,
-4.02122551,
1.39337037,
-6.06707057,
1.59480517,
-9.643119,
0.03907799,
0.67231762,
-8.88206726,
6.27115922, # 4th highest value; idx. 13
2.28520723,
4.82767506,
4.30421368,
8.8275313, # 2nd highest value; idx. 17
5.44029958,
-4.4735794,
7.38579536, # 3rd highest value; idx. 20
-2.91051663,
2.61946077,
-2.5674762,
-9.48959302,
-4.02922645,
-1.35416918,
9.67702323, # 1st highest value; idx. 27
-5.89478553,
1.85370467,
], # cummulative prob of 4 highest values <= 0.6
],
dtype=torch.float,
device=torch_device,
)
non_inf_expected_idx = torch.tensor(
[[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
dtype=torch.long,
device=torch_device,
) # expected non filtered idx as noted above
non_inf_expected_output = torch.tensor(
[
8.2221,
8.4321,
7.4402,
9.3845,
6.2712,
8.8275,
7.3858,
9.6770,
], # expected non filtered values as noted above
dtype=torch.float,
device=torch_device,
)
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
non_inf_output = output[output != -float("inf")].to(device=torch_device)
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))

View File

@ -23,6 +23,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
@ -128,7 +129,7 @@ def prepare_bart_inputs_dict(
@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
if is_torch_available()

View File

@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@ -357,11 +358,12 @@ class BertModelTester:
@require_torch
class BertModelTest(ModelTesterMixin, unittest.TestCase):
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(
BertModel,
BertLMHeadModel,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
@ -373,6 +375,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
def setUp(self):
self.model_tester = BertModelTester(self)

View File

@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@ -183,9 +184,10 @@ class BertGenerationEncoderTester:
@require_torch
class BertGenerationEncoderTest(ModelTesterMixin, unittest.TestCase):
class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BertGenerationEncoder, BertGenerationDecoder) if is_torch_available() else ()
all_generative_model_classes = (BertGenerationDecoder,) if is_torch_available() else ()
def setUp(self):
self.model_tester = BertGenerationEncoderTester(self)

View File

@ -147,6 +147,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
src_text = ["Sam"]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
@ -156,6 +157,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
generated_ids = model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
@ -187,6 +189,9 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
]
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
# model does not have "token_type_ids"
model_inputs.pop("token_type_ids")
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
generated_ids = self.model.generate(**model_inputs)[0]
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
@ -198,10 +203,11 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
def test_90_generation_from_short_input(self):
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
generated_utterances = self.model.generate(**model_inputs)
# generated_txt = self.tokenizer.decode(generated_utterances[0])
# assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__"
# model does not have "token_type_ids"
model_inputs.pop("token_type_ids")
generated_utterances = self.model.generate(**model_inputs)
clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW)
assert clean_txt in (
"have you ever been to a sam club? it's a great club in the south.",

View File

@ -44,7 +44,6 @@ if is_torch_available():
BertModel,
PretrainedConfig,
PreTrainedModel,
top_k_top_p_filtering,
)
@ -882,126 +881,6 @@ class ModelTesterMixin:
with torch.no_grad():
model(**inputs)[0]
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]
# iterate over all generative models
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
if config.bos_token_id is None:
# if bos token id is not defined, model needs input_ids
with self.assertRaises(AssertionError):
model.generate(do_sample=True, max_length=5)
# num_return_sequences = 1
self._check_generated_ids(model.generate(input_ids, do_sample=True))
else:
# num_return_sequences = 1
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
with self.assertRaises(AssertionError):
# generating multiple sequences when no beam search generation
# is not allowed as it would always generate the same sequences
model.generate(input_ids, do_sample=False, num_beams=1, num_return_sequences=2)
# num_return_sequences > 1, sample
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [
self._generate_random_bad_tokens(1, model.config),
self._generate_random_bad_tokens(2, model.config),
]
output_tokens = model.generate(
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
)
# only count generated tokens
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
def test_lm_head_model_random_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
torch_device
)
# make sure that input_ids is at most of size 15
input_ids = input_ids[..., :15]
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
if config.bos_token_id is None:
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
else:
# num_return_sequences = 1
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
with self.assertRaises(AssertionError):
# generating more sequences than having beams leads is not possible
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# num_return_sequences > 1, sample
self._check_generated_ids(
model.generate(
input_ids,
do_sample=True,
num_beams=2,
num_return_sequences=2,
)
)
# num_return_sequences > 1, greedy
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids = [
self._generate_random_bad_tokens(1, model.config),
self._generate_random_bad_tokens(2, model.config),
]
output_tokens = model.generate(
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
)
# only count generated tokens
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
# special tokens cannot be bad tokens
special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None]
# create random bad tokens that are not special tokens
bad_tokens = []
while len(bad_tokens) < num_bad_tokens:
token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0]
if token not in special_tokens:
bad_tokens.append(token)
return bad_tokens
def _check_generated_ids(self, output_ids):
for token_id in output_ids[0].tolist():
self.assertGreaterEqual(token_id, 0)
self.assertLess(token_id, self.model_tester.vocab_size)
def _check_match_tokens(self, generated_ids, bad_words_ids):
# for all bad word tokens
for bad_word_ids in bad_words_ids:
# for all slices in batch
for generated_ids_slice in generated_ids:
# for all word idx
for i in range(len(bad_word_ids), len(generated_ids_slice)):
# if tokens match
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
return True
return False
@require_torch_multigpu
def test_multigpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -1094,110 +973,3 @@ class ModelUtilsTest(unittest.TestCase):
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
# tests whether the top_k_top_p function behaves as expected
def test_top_k_top_p_filtering(self):
logits = torch.tensor(
[
[
8.2220991, # 3rd highest value; idx. 0
-0.5620044,
5.23229752,
4.0386393,
-6.8798378,
-0.54785802,
-3.2012153,
2.92777176,
1.88171953,
7.35341276, # 5th highest value; idx. 9
8.43207833, # 2nd highest value; idx. 10
-9.85711836,
-5.96209236,
-1.13039161,
-7.1115294,
-0.8369633,
-5.3186408,
7.06427407,
0.81369344,
-0.82023817,
-5.9179796,
0.58813443,
-6.99778438,
4.71551189,
-0.18771637,
7.44020759, # 4th highest value; idx. 25
9.38450987, # 1st highest value; idx. 26
2.12662941,
-9.32562038,
2.35652522,
], # cumulative prob of 5 highest values <= 0.6
[
0.58425518,
4.53139238,
-5.57510464,
-6.28030699,
-7.19529503,
-4.02122551,
1.39337037,
-6.06707057,
1.59480517,
-9.643119,
0.03907799,
0.67231762,
-8.88206726,
6.27115922, # 4th highest value; idx. 13
2.28520723,
4.82767506,
4.30421368,
8.8275313, # 2nd highest value; idx. 17
5.44029958, # 5th highest value; idx. 18
-4.4735794,
7.38579536, # 3rd highest value; idx. 20
-2.91051663,
2.61946077,
-2.5674762,
-9.48959302,
-4.02922645,
-1.35416918,
9.67702323, # 1st highest value; idx. 27
-5.89478553,
1.85370467,
], # cumulative prob of 5 highest values <= 0.6
],
dtype=torch.float,
device=torch_device,
)
non_inf_expected_idx = torch.tensor(
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
dtype=torch.long,
device=torch_device,
) # expected non filtered idx as noted above
non_inf_expected_output = torch.tensor(
[
8.2221,
7.3534,
8.4321,
7.4402,
9.3845,
6.2712,
8.8275,
5.4403,
7.3858,
9.6770,
], # expected non filtered values as noted above
dtype=torch.float,
device=torch_device,
)
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
non_inf_output = output[output != -float("inf")].to(device=torch_device)
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))

View File

@ -19,6 +19,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@ -151,7 +152,7 @@ class CTRLModelTester:
@require_torch
class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()

View File

@ -24,6 +24,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
@ -120,7 +121,7 @@ def prepare_fsmt_inputs_dict(
@require_torch
class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True

View File

@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@ -377,7 +378,7 @@ class GPT2ModelTester:
@require_torch
class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification)
@ -510,32 +511,17 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@slow
def test_lm_generate_distilgpt2(self):
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
def test_gpt2_sample(self):
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(torch_device)
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
expected_output_ids = [
464,
1893,
286,
262,
1578,
1829,
11,
290,
262,
1893,
286,
262,
1578,
7526,
11,
423,
587,
287,
262,
2635,
] # The president of the United States, and the president of the United Kingdom, have been in the White
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
torch.manual_seed(0)
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device)
output_ids = model.generate(input_ids, do_sample=True)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STR = (
"Today is a nice day and if you don't know anything about the state of play during your holiday"
)
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)

View File

@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
@ -170,7 +171,7 @@ class OpenAIGPTModelTester:
@require_torch
class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification)

View File

@ -22,6 +22,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -853,7 +854,7 @@ class ProphetNetStandaloneEncoderModelTester:
@require_torch
class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
test_pruning = False
@ -917,7 +918,7 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, unittest.TestCase):
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (ProphetNetDecoder, ProphetNetForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (ProphetNetForCausalLM,) if is_torch_available() else ()
test_pruning = False

View File

@ -26,6 +26,7 @@ from transformers.testing_utils import (
)
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@ -196,11 +197,14 @@ class ReformerModelTester:
)
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
if not self.is_training:
return
config.is_decoder = False
config.lsh_num_chunks_after = 1
model = ReformerForMaskedLM(config=config)
model.to(torch_device)
model.eval()
model.train()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"]
loss.backward()
@ -569,7 +573,7 @@ class ReformerTesterMixin:
@require_torch
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
if is_torch_available()
@ -629,7 +633,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
@require_torch
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
if is_torch_available()

View File

@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@ -267,7 +268,7 @@ class RobertaModelTester:
@require_torch
class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(
@ -282,6 +283,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
def setUp(self):
self.model_tester = RobertaModelTester(self)

View File

@ -23,6 +23,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
@ -466,7 +467,7 @@ class T5ModelTester:
@require_torch
class T5ModelTest(ModelTesterMixin, unittest.TestCase):
class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
@ -592,6 +593,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
do_sample=False,
early_stopping=True,
)
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertListEqual(
expected_summaries,

View File

@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor
@ -156,7 +157,7 @@ class TransfoXLModelTester:
@require_torch
class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
test_pruning = False

View File

@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@ -331,7 +332,7 @@ class XLMModelTester:
@require_torch
class XLMModelTest(ModelTesterMixin, unittest.TestCase):
class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(

View File

@ -21,6 +21,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@ -479,7 +480,7 @@ class XLNetModelTester:
@require_torch
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(
XLNetModel,