Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30)
Refactoring the generate() function (#6949)
* first draft
* show design proposition for new generate method
* up
* make better readable
* make first version
* gpt2 tests pass
* make beam search for gpt2 work
* add first encoder-decoder code
* delete typo
* make t5 work
* save intermediate
* make bart work with beam search
* finish beam search bart / t5
* add default kwargs
* make more tests pass
* fix no bad words sampler
* some fixes and tests for all distribution processors
* fix test
* fix rag slow tests
* merge to master
* add no_grad to generate
* make all slow tests pass
* speed up generate
* fix edge case bug
* small fix
* correct typo
* add type hints and docstrings
* fix typos in tests
* add beam search tests
* add tests for beam scorer
* fix test rag
* finish beam search tests
* move generation tests into a separate file
* fix generation tests
* more tests
* add aggressive generation tests
* fix tests
* add gpt2 sample test
* add more docstrings
* add more docs
* finish doc strings
* apply some more of Sylvain's and Sam's comments
* fix some typos
* make fix-copies
* apply Lysandre's and Sylvain's comments
* final corrections on examples
* small fix for reformer
This commit is contained in:
parent b63beb743c
commit a1bbcf3f6c
@@ -272,3 +272,4 @@ conversion utilities for the following models:
    internal/pipelines_utils
    internal/tokenization_utils
    internal/trainer_utils
+   internal/generation_utils
50  docs/source/internal/generation_utils.rst  Normal file
@@ -0,0 +1,50 @@
Utilities for Generation
-----------------------------------------------------------------------------------------------------------------------

This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`,
:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`,
:meth:`~transformers.PretrainedModel.beam_search`, and :meth:`~transformers.PretrainedModel.beam_sample`.

Most of those are only useful if you are studying the code of the generate methods in the library.

LogitsProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

A :class:`~transformers.LogitsProcessor` can be used to modify the prediction scores of a language model head for
generation.

.. autoclass:: transformers.LogitsProcessor
    :members: __call__

.. autoclass:: transformers.LogitsProcessorList
    :members: __call__

.. autoclass:: transformers.MinLengthLogitsProcessor
    :members: __call__

.. autoclass:: transformers.TemperatureLogitsWarper
    :members: __call__

.. autoclass:: transformers.RepetitionPenaltyLogitsProcessor
    :members: __call__

.. autoclass:: transformers.TopPLogitsWarper
    :members: __call__

.. autoclass:: transformers.TopKLogitsWarper
    :members: __call__

.. autoclass:: transformers.NoRepeatNGramLogitsProcessor
    :members: __call__

.. autoclass:: transformers.NoBadWordsLogitsProcessor
    :members: __call__

BeamSearch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BeamScorer
    :members: process, finalize

.. autoclass:: transformers.BeamSearchScorer
    :members: process, finalize
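Editor's usage sketch (not part of this commit): the processor and warper classes documented above can be chained through a LogitsProcessorList. All tensor values and ids below are invented; the constructor arguments follow the class definitions added in this PR.

import torch

from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

# Hypothetical values: a batch of 2 partially generated sequences over a 100-token vocabulary.
input_ids = torch.tensor([[5, 8, 8], [3, 9, 1]], dtype=torch.long)
scores = torch.randn(2, 100)

# Processors constrain the distribution; warpers reshape it for multinomial sampling.
processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=10, eos_token_id=2),
        NoRepeatNGramLogitsProcessor(ngram_size=2),
    ]
)
warpers = LogitsProcessorList([TemperatureLogitsWarper(0.7), TopKLogitsWarper(top_k=50)])

scores = processors(input_ids, scores)  # e.g. EOS (id 2) is masked out while cur_len < 10
scores = warpers(input_ids, scores)
next_tokens = torch.multinomial(torch.softmax(scores, dim=-1), num_samples=1)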
@@ -45,7 +45,7 @@ TFModelUtilsMixin
    :members:


-Generative models
+Generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.generation_utils.GenerationMixin
@@ -299,6 +299,19 @@ if is_torch_available():
        TextDataset,
        TextDatasetForNextSentencePrediction,
    )
    from .generation_beam_search import BeamScorer, BeamSearchScorer
    from .generation_logits_process import (
        LogitsProcessor,
        LogitsProcessorList,
        LogitsWarper,
        MinLengthLogitsProcessor,
        NoBadWordsLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
    from .generation_utils import top_k_top_p_filtering
    from .modeling_albert import (
        ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
357  src/transformers/generation_beam_search.py  Normal file
@@ -0,0 +1,357 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple

import torch

from .file_utils import add_start_docstrings


PROCESS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses.
        next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses.
        next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.

    Return:
        :obj:`UserDict`: A dictionary composed of the fields as defined above:

            - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated
              scores of all non-finished beams.
            - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens
              to be added to the non-finished beam_hypotheses.
            - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices
              indicating to which beam the next tokens shall be added.

"""

FINALIZE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
            The final scores of all non-finished beams.
        final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
            The last tokens to be added to the non-finished beam_hypotheses.
        final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
            The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.

    Return:
        :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
        sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
        batches finished early due to the :obj:`eos_token_id`.

"""


class BeamScorer(ABC):
    """
    Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
    :meth:`~transformers.PretrainedModel.beam_sample`.
    """

    @abstractmethod
    @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> Tuple[torch.Tensor]:
        raise NotImplementedError("This is an abstract method.")

    @abstractmethod
    @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
    def finalize(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> torch.LongTensor:
        raise NotImplementedError("This is an abstract method.")

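Editor's illustrative sketch (not part of this commit): any scorer only has to implement the `process`/`finalize` contract documented above. The class below is a hypothetical pass-through scorer that simply keeps the first `num_beams` of the `2 * num_beams` candidates it receives and converts the per-batch beam indices into flat `batch_size * num_beams` row indices, mirroring what BeamSearchScorer returns.

# Hypothetical example, assuming BeamScorer and UserDict from this module are in scope.
class PassThroughBeamScorer(BeamScorer):
    def process(self, input_ids, next_scores, next_tokens, next_indices, **kwargs):
        batch_size, num_candidates = next_scores.shape
        num_beams = num_candidates // 2
        # Convert per-batch beam indices into flat row indices of input_ids.
        offsets = torch.arange(batch_size, device=next_indices.device).unsqueeze(1) * num_beams
        flat_indices = (next_indices[:, :num_beams] + offsets).reshape(-1)
        return UserDict(
            {
                "next_beam_scores": next_scores[:, :num_beams].reshape(-1),
                "next_beam_tokens": next_tokens[:, :num_beams].reshape(-1),
                "next_beam_indices": flat_indices,
            }
        )

    def finalize(self, input_ids, final_beam_scores, final_beam_tokens, final_beam_indices, **kwargs):
        # A real scorer would rank and pad hypotheses here; this sketch just returns them as-is.
        return input_ids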
class BeamSearchScorer(BeamScorer):
    r"""
    :class:`transformers.BeamScorer` implementing standard beam search decoding.

    Adapted in part from `Facebook's XLM beam search code
    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.

    Args:
        batch_size (:obj:`int`):
            Batch size of :obj:`input_ids` for which beam search decoding is run in parallel.
        max_length (:obj:`int`):
            The maximum length of the sequence to be generated.
        num_beams (:obj:`int`):
            Number of beams for beam search.
        device (:obj:`torch.device`):
            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
            :obj:`BeamSearchScorer` will be allocated.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
            sequences.
        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
            The number of beam hypotheses that shall be returned upon calling
            :meth:`~transformer.BeamSearchScorer.finalize`.
    """

    def __init__(
        self,
        batch_size: int,
        max_length: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
    ):
        self.max_length = max_length
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep

        self._is_init = False
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                max_length=self.max_length,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
            )

    @property
    def is_done(self) -> bool:
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor]:
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        assert batch_size == (input_ids.shape[0] // self.num_beams)

        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                assert (
                    len(beam_hyp) >= self.num_beams
                ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.num_beams + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.num_beams:
                    break

            if beam_idx < self.num_beams:
                raise ValueError(
                    f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.LongTensor:
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score)

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
                best.append(best_hyp)

        # prepare for adding eos
        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < self.max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
        return decoded

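Editor's sketch of the process/finalize flow above (all scores and token ids are made up; not part of the commit). It runs one beam step for a single batch entry with two beams and then finalizes:

import torch

# Toy setup: 1 batch entry, 2 beams, current length 3.
scorer = BeamSearchScorer(batch_size=1, max_length=8, num_beams=2, device=torch.device("cpu"))
input_ids = torch.tensor([[0, 4, 7], [0, 4, 9]])        # (batch_size * num_beams, seq_len)
next_scores = torch.tensor([[-0.5, -0.7, -1.2, -2.0]])  # top 2 * num_beams candidate scores
next_tokens = torch.tensor([[5, 3, 5, 1]])              # candidate token ids
next_indices = torch.tensor([[0, 0, 1, 1]])             # beam each candidate comes from

out = scorer.process(
    input_ids, next_scores, next_tokens, next_indices, pad_token_id=0, eos_token_id=2
)
beam_ids = out["next_beam_indices"]                     # which rows of input_ids to extend
input_ids = torch.cat([input_ids[beam_ids, :], out["next_beam_tokens"].unsqueeze(-1)], dim=-1)

# After the last step, finalize ranks the hypotheses and pads / eos-terminates them.
sequences = scorer.finalize(
    input_ids, out["next_beam_scores"], out["next_beam_tokens"], beam_ids,
    pad_token_id=0, eos_token_id=2,
)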
class BeamHypotheses:
    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_next_scores[0][1]]
                self.worst_score = sorted_next_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        """
        If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
        one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
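Editor's worked example of the length-normalised score used by `add` and `is_done` above (numbers invented):

# Made-up numbers: a finished hypothesis of 10 tokens with summed log-probs -6.0,
# scored with length_penalty=2.0.
sum_logprobs, length, length_penalty = -6.0, 10, 2.0
score = sum_logprobs / (length ** length_penalty)     # -6.0 / 100 = -0.06

# is_done(): a still-growing beam whose best achievable score is above the worst kept
# score could still displace it, so the batch entry keeps searching.
worst_kept = -0.06
best_possible = -8.0 / (12 ** length_penalty)          # ~ -0.0556 > -0.06 -> not done yet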
374  src/transformers/generation_logits_process.py  Normal file
@@ -0,0 +1,374 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import Iterable, List

import numpy as np
import torch
from torch.nn import functional as F

from .file_utils import add_start_docstrings


LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
            or scores for each vocabulary token after SoftMax.

    Return:
        :obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.

"""


class LogitsProcessor(ABC):
    """Abstract base class for all logit processors that can be applied during generation."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Torch method for processing logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class LogitsWarper(ABC):
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Torch method for warping logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class LogitsProcessorList(list):
    """
    This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
    :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
    list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
    :class:`~transformers.LogitsWarper` to the inputs.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        for processor in self:
            scores = processor(input_ids, scores)
        return scores


class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (:obj:`int`):
            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, min_length: int, eos_token_id: int):
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len < self.min_length:
            scores[:, self.eos_token_id] = -float("inf")
        return scores


class TemperatureLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).

    Args:
        temperature (:obj:`float`):
            The value used to modulate the logits distribution.
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores = scores / self.temperature
        return scores

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        for i in range(scores.shape[0]):
            for previous_token in set(input_ids[i].tolist()):
                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
                if scores[i, previous_token] < 0:
                    scores[i, previous_token] *= self.penalty
                else:
                    scores[i, previous_token] /= self.penalty
        return scores

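Editor's numeric illustration (values invented) of why the penalty divides positive logits and multiplies negative ones: both moves push an already-seen token's logit down.

penalty = 1.2

# Already-generated token with a positive logit: dividing shrinks it towards 0.
positive_logit = 2.4
print(positive_logit / penalty)   # 2.0  -> less likely than before

# Already-generated token with a negative logit: multiplying pushes it further down.
negative_logit = -2.4
print(negative_logit * penalty)   # -2.88 -> also less likely than before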
class TopPLogitsWarper(LogitsWarper):
    """
    :class:`transformers.LogitsWarper` that performs top-p filtering, i.e. restricting to the smallest set of most
    probable tokens whose cumulative probability reaches :obj:`top_p`.

    Args:
        top_p (:obj:`float`):
            If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
            kept for generation.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 is kept)
        sorted_indices_to_remove = cumulative_probs > self.top_p
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores[indices_to_remove] = self.filter_value
        return scores

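Editor's toy run of the warper above (made-up logits over a 4-token vocabulary), showing which entries survive nucleus filtering:

import torch

warper = TopPLogitsWarper(top_p=0.8)
# Softmax of these logits is exactly [0.64, 0.24, 0.09, 0.03].
scores = torch.log(torch.tensor([[0.64, 0.24, 0.09, 0.03]]))
filtered = warper(input_ids=None, scores=scores)
# The first two tokens already cover ~0.88 >= 0.8, so tokens 2 and 3 are set to -inf.
print(filtered)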
class TopKLogitsWarper(LogitsWarper):
    r"""
    :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (:obj:`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = top_k
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores[indices_to_remove] = self.filter_value
        return scores

class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
    <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.

    Args:
        ngram_size (:obj:`int`):
            All ngrams of size :obj:`ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

        return scores

    def _calc_banned_ngram_tokens(
        self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
    ) -> List[Iterable[int]]:
        """Copied from fairseq for no_repeat_ngram in beam_search"""
        if cur_len + 1 < self.ngram_size:
            # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
            return [[] for _ in range(num_hypos)]
        generated_ngrams = [{} for _ in range(num_hypos)]
        for idx in range(num_hypos):
            gen_tokens = prev_input_ids[idx].tolist()
            generated_ngram = generated_ngrams[idx]
            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
                prev_ngram_tuple = tuple(ngram[:-1])
                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

        def _get_generated_ngrams(hypo_idx):
            # Before decoding the next token, prevent decoding of ngrams that have already appeared
            start_idx = cur_len + 1 - self.ngram_size
            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
            return generated_ngrams[hypo_idx].get(ngram_idx, [])

        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
        return banned_tokens

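Editor's quick illustration (toy ids, not from the commit) of the banning logic above: after generating `7 8 7`, a bigram ban forbids `8`, because the bigram `7 8` has already occurred.

import torch

processor = NoRepeatNGramLogitsProcessor(ngram_size=2)
input_ids = torch.tensor([[7, 8, 7]])   # the last token is 7, and "7 8" was already generated
scores = torch.zeros(1, 10)
scores = processor(input_ids, scores)
print(scores[0, 8])                      # -inf: emitting 8 would repeat the bigram "7 8"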
class NoBadWordsLogitsProcessor(LogitsProcessor):
    """
    :class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled.

    Args:
        bad_words_ids (:obj:`List[List[int]]`):
            List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
            that should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int):

        if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )

        self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))

        for banned_token_seq in self.bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        banned_tokens = self._calc_banned_bad_words_ids(input_ids)
        scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens)

        return scores

    def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        elif len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than prev input_ids they can't be equal
            return False
        elif prev_tokens[-len(tokens) :].tolist() == tokens:
            # if tokens match
            return True
        else:
            return False

    def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
        banned_tokens = []
        for prev_input_ids_slice in prev_input_ids:
            banned_tokens_slice = []
            for banned_token_seq in self.bad_words_ids:
                if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
                    # if tokens do not match continue
                    continue

                banned_tokens_slice.append(banned_token_seq[-1])

            banned_tokens.append(banned_tokens_slice)

        return banned_tokens

    def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
        """
        Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
        list of list of banned tokens to ban in the format [[batch index, vocabulary position],...

        Args:
            scores: logits distribution of shape (batch size, vocabulary size)
            banned_tokens: list of list of tokens to ban of length (batch_size)
        """
        banned_mask_list = []
        for idx, batch_banned_tokens in enumerate(banned_tokens):
            for token in batch_banned_tokens:
                banned_mask_list.append([idx, token])
        if not banned_mask_list:
            return scores

        banned_mask = torch.LongTensor(banned_mask_list)
        indices = torch.ones(len(banned_mask))
        # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
        # [ 0 1 1 ]
        # [ 0 0 0 ]
        # [ 1 0 0 ]

        banned_mask = (
            torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
        )
        scores = scores.masked_fill(banned_mask, -float("inf"))
        return scores
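Editor's usage sketch for the processor above (token ids invented): a single-token bad word is banned outright, while a multi-token bad word is only banned once its prefix has just been generated.

import torch

processor = NoBadWordsLogitsProcessor(bad_words_ids=[[42], [15, 3]], eos_token_id=2)

scores = torch.zeros(2, 50)
input_ids = torch.tensor([[7, 15],      # ends with 15, so 3 would complete the banned pair [15, 3]
                          [7, 9]])      # no banned prefix here
scores = processor(input_ids, scores)

print(scores[0, 42], scores[0, 3])      # both -inf for the first sequence
print(scores[1, 42], scores[1, 3])      # -inf and 0.0 for the second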
(File diff suppressed because it is too large.)
@@ -1084,7 +1084,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
        )

    def prepare_inputs_for_generation(
-        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
+        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -514,12 +514,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
    def get_output_embeddings(self):
        return self.lm_head

-    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

-        return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
+        return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}

    @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
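Editor's note on the pattern behind these signature changes: the refactored generation methods forward whatever model kwargs they carry as keyword arguments, so every argument of `prepare_inputs_for_generation` other than the ids needs a default. The sketch below is schematic only (the real loop lives in the suppressed generation_utils.py diff) and all names are assumptions.

import torch

# Schematic greedy step, assuming `model` follows the prepare_inputs_for_generation contract.
def greedy_step(model, input_ids, **model_kwargs):
    # model_kwargs may or may not contain "past", "attention_mask", "use_cache", ...
    # which is why the overridden signatures above give every such argument a None default.
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    outputs = model(**model_inputs, return_dict=True)
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return torch.cat([input_ids, next_token], dim=-1)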
@@ -431,7 +431,7 @@ class EncoderDecoderModel(PreTrainedModel):
            encoder_attentions=encoder_outputs.attentions,
        )

-    def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs):
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
@@ -1107,7 +1107,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
        )

    def prepare_inputs_for_generation(
-        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
+        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -1800,7 +1800,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
        return loss

    def prepare_inputs_for_generation(
-        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
+        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."

@@ -22,6 +22,7 @@ import torch
from .configuration_rag import RagConfig
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
+from .generation_beam_search import BeamSearchScorer
from .modeling_outputs import ModelOutput
from .modeling_utils import PreTrainedModel
from .retrieval_rag import RagRetriever
@@ -825,7 +826,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
        num_return_sequences=None,  # defaults to 1
        num_beams=None,  # defaults to 1
        n_docs=None,
-        **kwargs
+        **model_kwargs
    ):
        """
        Implements RAG sequence "thorough" decoding. Read the :meth:`~transformers.PreTrainedModel.generate``
@@ -872,7 +873,6 @@ class RagSequenceForGeneration(RagPreTrainedModel):
        )
        num_beams = num_beams if num_beams is not None else self.config.num_beams

        # TODO(patrick) - clean up generate here
        if self.retriever is not None and context_input_ids is None:
            question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
            context_input_ids = self.retriever(
@@ -887,10 +887,9 @@ class RagSequenceForGeneration(RagPreTrainedModel):
        context_input_ids = context_input_ids.to(input_ids)

        hypos = []
-        kwargs["num_beams"] = num_beams
-        kwargs["num_return_sequences"] = num_beams
-        kwargs["attention_mask"] = None
-        kwargs["n_docs"] = n_docs
+        model_kwargs["num_beams"] = num_beams
+        model_kwargs["num_return_sequences"] = num_beams
+        model_kwargs["attention_mask"] = None

        for index in range(len(input_ids)):
            # first, generate beams from documents:
@@ -898,7 +897,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):

            output_sequences = self.generator.generate(
                generator_input_ids,
-                **kwargs,
+                **model_kwargs,
            )  # n_docs * n_beam, tgt_len
            if do_deduplication:
                # do_deduplication, max_output_len
@@ -1018,7 +1017,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
        return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length)

    def prepare_inputs_for_generation(
-        self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, n_docs=None, **kwargs
+        self,
+        decoder_input_ids,
+        past=None,
+        attention_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        doc_scores=None,
+        n_docs=None,
+        **kwargs
    ):
        return {
            "input_ids": None,
@@ -1222,11 +1229,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
        eos_token_id=None,
        length_penalty=None,
        no_repeat_ngram_size=None,
        repetition_penalty=None,
        bad_words_ids=None,
        num_return_sequences=None,
        decoder_start_token_id=None,
        n_docs=None,
-        **kwargs
+        **model_kwargs
    ):
        """
        Implements RAG token decoding.
@@ -1307,22 +1315,15 @@ class RagTokenForGeneration(RagPreTrainedModel):
        """
        # set default parameters
        n_docs = n_docs if n_docs is not None else self.config.n_docs
        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        max_length = max_length if max_length is not None else self.config.max_length
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
@@ -1365,7 +1366,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
        encoder = self.rag.generator.get_encoder()
        encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)

-        decoder_input_ids = torch.full(
+        input_ids = torch.full(
            (batch_size * num_beams, 1),
            decoder_start_token_id,
            dtype=torch.long,
@@ -1388,64 +1389,57 @@ class RagTokenForGeneration(RagPreTrainedModel):
        doc_scores = doc_scores.repeat_interleave(num_beams, dim=0)

        # define start_len & additional parameters
        cur_len = 1
        vocab_size = self.config.generator.vocab_size
        kwargs["doc_scores"] = doc_scores
        kwargs["encoder_outputs"] = encoder_outputs
        kwargs["n_docs"] = n_docs
        model_kwargs["doc_scores"] = doc_scores
        model_kwargs["encoder_outputs"] = encoder_outputs
        model_kwargs["attention_mask"] = context_attention_mask
        model_kwargs["n_docs"] = n_docs

        # not needed. TODO(PVP): change after generate refactor
        do_sample = False
        temperature = self.config.temperature
        top_k = self.config.top_k
        top_p = self.config.top_p
        repetition_penalty = self.config.repetition_penalty
        pre_processor = self._get_logits_processor(
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            min_length=min_length,
            eos_token_id=eos_token_id,
        )

        if num_beams > 1:
            return self._generate_beam_search(
                decoder_input_ids,
                cur_len=cur_len,
        if num_beams == 1:
            if num_return_sequences > 1:
                raise ValueError(
                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                )
            return self.greedy_search(
                input_ids,
                pre_processor=pre_processor,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                **model_kwargs,
            )
        elif num_beams > 1:
            length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
            early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
            if num_return_sequences > num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                max_length=max_length,
                num_beams=num_beams,
                vocab_size=vocab_size,
                attention_mask=context_attention_mask,
                use_cache=use_cache,
                model_kwargs=kwargs,
                device=self.device,
                length_penalty=length_penalty,
                do_early_stopping=early_stopping,
                num_beam_hyps_to_keep=num_return_sequences,
            )
            return self.beam_search(
                input_ids,
                beam_scorer,
                pre_processor=pre_processor,
                max_length=max_length,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                **model_kwargs,
            )
        else:
            return self._generate_no_beam_search(
                decoder_input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                batch_size=batch_size,
                attention_mask=context_attention_mask,
                use_cache=use_cache,
                model_kwargs=kwargs,
            )
            raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")

    def get_input_embeddings(self):
        return self.rag.generator.get_input_embeddings()
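Editor's usage sketch for the refactored RAG generation above: callers keep using `generate`, which now dispatches to the shared `greedy_search`/`beam_search` helpers instead of the old private `_generate_*` methods. The checkpoint names and retriever settings below are taken from the usual RAG model card and are assumptions, not part of this diff; the input question is invented.

from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

# Assumed checkpoint names (facebook/rag-token-nq) and dummy-index retriever settings.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer.question_encoder("who holds the record in 100m freestyle", return_tensors="pt")
# num_beams > 1 takes the BeamSearchScorer + beam_search() branch shown above;
# num_beams == 1 would take the greedy_search() branch instead.
generated = model.generate(input_ids=inputs["input_ids"], num_beams=4, num_return_sequences=2, max_length=20)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))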
@@ -638,7 +638,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
        rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2)
        # create a random self.attention_head_size x num_hashes x num_buckets/2
        random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype)

        # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2
        rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations)

@@ -1471,7 +1470,9 @@ class ReformerLayer(nn.Module):
            # every forward pass we sample a different seed
            # for dropout and save for forward fn in backward pass
            # to have correct dropout
-            self._init_attention_seed()
+            if self.training:
+                self._init_attention_seed()

            attn_outputs = self.attention(
                hidden_states=hidden_states,
                head_mask=head_mask,
@@ -1494,7 +1495,8 @@ class ReformerLayer(nn.Module):
            # every forward pass we sample a different seed
            # for dropout and save seed for forward fn in backward
            # to have correct dropout
-            self._init_feed_forward_seed()
+            if self.training:
+                self._init_feed_forward_seed()
            # Y_2 = X_2 + g(Y_1)
            hidden_states = hidden_states + self.feed_forward(attn_output)

@@ -2263,7 +2265,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
            attentions=reformer_outputs.attentions,
        )

-    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, num_hashes=None, **kwargs):
        # only last token for inputs_ids if past is defined in kwargs
        if past is not None:
            input_ids = input_ids[:, -1:]
@@ -2271,12 +2273,10 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
        inputs_dict = {
            "input_ids": input_ids,
            "past_buckets_states": past,
-            "use_cache": kwargs["use_cache"],
+            "use_cache": use_cache,
+            "num_hashes": num_hashes,
        }

-        if "num_hashes" in kwargs:
-            inputs_dict["num_hashes"] = kwargs["num_hashes"]

        return inputs_dict

    def _reorder_cache(self, past, beam_idx):
@@ -1232,7 +1232,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
            encoder_attentions=encoder_outputs.attentions,
        )

-    def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs):
+    def prepare_inputs_for_generation(
+        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+    ):

        # cut decoder_input_ids if past is used
        if past is not None:
@@ -1091,7 +1091,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
        else:
            return self.crit.out_layers[-1]

-    def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
        inputs = {}

        # if past is defined in model kwargs then use it for faster decoding
@@ -1300,7 +1300,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
    def get_output_embeddings(self):
        return self.lm_loss

-    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
        # Add dummy token at the end (no attention on this one)

        effective_batch_size = input_ids.shape[0]
@@ -1333,7 +1333,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            "input_ids": input_ids,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
-            "use_cache": kwargs["use_cache"],
+            "use_cache": use_cache,
        }

        # if past is defined in model kwargs then use it for faster decoding
@@ -88,8 +88,8 @@ def is_pipeline_test(test_case):
    """
    Decorator marking a test as a pipeline test.

-    Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TEST environment variable to
-    a truthy value and selecting the is_pipeline_test pytest mark.
+    Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TESTS environment variable
+    to a truthy value and selecting the is_pipeline_test pytest mark.

    """
    if not _run_pipeline_tests:
@@ -104,6 +104,66 @@ class TextDatasetForNextSentencePrediction:
        requires_pytorch(self)


class BeamScorer:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class BeamSearchScorer:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class LogitsProcessor:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class LogitsProcessorList:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class LogitsWarper:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class MinLengthLogitsProcessor:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class NoBadWordsLogitsProcessor:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class NoRepeatNGramLogitsProcessor:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class RepetitionPenaltyLogitsProcessor:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class TemperatureLogitsWarper:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class TopKLogitsWarper:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class TopPLogitsWarper:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


def top_k_top_p_filtering(*args, **kwargs):
    requires_pytorch(top_k_top_p_filtering)

239  tests/test_generation_beam_search.py  Normal file
@@ -0,0 +1,239 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device

from .test_modeling_common import floats_tensor, ids_tensor


if is_torch_available():
    import torch

    from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer


class BeamSearchTester:
    def __init__(
        self,
        parent,
        batch_size=3,
        sequence_length=10,
        vocab_size=99,
        pad_token_id=0,
        max_length=20,
        num_beams=4,
        length_penalty=2.0,
        do_early_stopping=True,
        num_beam_hyps_to_keep=2,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep

        # cannot be randomly generated
        self.eos_token_id = vocab_size + 1

    def prepare_beam_scorer(self, **kwargs):
        return BeamSearchScorer(
            batch_size=kwargs.get("batch_size", self.batch_size),
            max_length=kwargs.get("max_length", self.max_length),
            num_beams=kwargs.get("num_beams", self.num_beams),
            device=torch_device,
            length_penalty=kwargs.get("length_penalty", self.length_penalty),
            do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
            num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
        )

    def prepare_inputs(self):
        input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
        next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
        next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
        next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
        return (input_ids, next_tokens, next_indices, next_scores)

    def check_beam_hypotheses(self, input_ids, *args):
        # check that correct number of beam hypotheses is set in beam scorer
        beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
        beam_hyp = beam_scorer._beam_hyps[0]

        self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)

        # check correct type
        self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))

        # check that num_beams is correctly set
        self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)

        # check for early stopping deactivated
        for beam_idx in range(self.num_beams):
            beam_hyp.add(input_ids[beam_idx], -10.0)

        # if early stopping True -> score does not matter
        self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))

        # re-init
        beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
        beam_hyp = beam_scorer._beam_hyps[0]

        # add `num_beams + 1` beams to change `worst_score`
        for beam_idx in range(self.num_beams + 1):
            beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))

        # -10.0 is removed => -9.0 is worst score
        self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))

        # -5.0 is better than worst score => should not be finished
        self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))

        # -20.0 is worse than worst score => should be finished
|
||||
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
|
||||
|
||||
def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
# check too many eos tokens
|
||||
beam_scorer = self.prepare_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[0, :] = self.eos_token_id
|
||||
|
||||
with self.parent.assertRaises(ValueError):
|
||||
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
|
||||
|
||||
# check all batches are done
|
||||
beam_scorer = self.prepare_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, : self.num_beams] = self.eos_token_id
|
||||
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
|
||||
# beam scorer should be done
|
||||
self.parent.assertTrue(beam_scorer.is_done)
|
||||
|
||||
# check
|
||||
beam_scorer = self.prepare_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, 1] = self.eos_token_id
|
||||
beam_outputs = beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
output_indices = beam_outputs["next_beam_indices"]
|
||||
|
||||
def cut_expected_tensor(tensor):
|
||||
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
|
||||
|
||||
# check all outptus
|
||||
# cut out id of eos token and take best `num_beams` outputs
|
||||
expected_output_tokens = cut_expected_tensor(tokens)
|
||||
expected_output_scores = cut_expected_tensor(next_scores)
|
||||
|
||||
# add num_beams * batch_idx
|
||||
expected_output_indices = (
|
||||
cut_expected_tensor(next_indices)
|
||||
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
|
||||
)
|
||||
|
||||
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
|
||||
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
|
||||
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
|
||||
|
||||
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
|
||||
for batch_idx in range(self.batch_size):
|
||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||
self.parent.assertListEqual(
|
||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
||||
)
|
||||
|
||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
||||
max_length = self.sequence_length + 1
|
||||
beam_scorer = self.prepare_beam_scorer(
|
||||
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
|
||||
# update beams and append to input_ids
|
||||
tokens = next_tokens.clone()
|
||||
# first batch, first output has to finish with eos token id since scores are correctly sorted
|
||||
tokens[0, 0] = self.eos_token_id
|
||||
# make sure corresponding score is as good as possible to surely be picked first
|
||||
next_scores[0, 0] = 0.0
|
||||
beam_outputs = beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
output_indices = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
# finalize
|
||||
decoded = beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
output_tokens,
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
|
||||
self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length])
|
||||
|
||||
# first batch has to finish with eos_token
|
||||
self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id)
|
||||
|
||||
# other batches cannot finish with eos token
|
||||
self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id)
|
||||
self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id)
|
||||
|
||||
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
|
||||
beam_scorer.num_beam_hyps_to_keep = self.num_beams
|
||||
decoded = beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
output_tokens,
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length])
|
||||
|
||||
|
||||
@require_torch
|
||||
class BeamSearchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.beam_search_tester = BeamSearchTester(self)
|
||||
|
||||
def test_beam_hypotheses(self):
|
||||
inputs = self.beam_search_tester.prepare_inputs()
|
||||
self.beam_search_tester.check_beam_hypotheses(*inputs)
|
||||
|
||||
def test_beam_scorer_update(self):
|
||||
inputs = self.beam_search_tester.prepare_inputs()
|
||||
self.beam_search_tester.check_beam_scorer_update(*inputs)
|
||||
|
||||
def test_beam_scorer_finalize(self):
|
||||
inputs = self.beam_search_tester.prepare_inputs()
|
||||
self.beam_search_tester.check_beam_scores_finalize(*inputs)
|
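Editor's note: the tester above drives BeamSearchScorer.process() and finalize() by hand with prepared tensors. As a rough, hedged sketch of the same loop outside the test suite (using only the constructor and method signatures exercised above; random scores stand in for a real model, and the pad/eos token ids are arbitrary):

import torch

from transformers.generation_beam_search import BeamSearchScorer

batch_size, num_beams, vocab_size, max_length = 2, 3, 50, 8

scorer = BeamSearchScorer(batch_size=batch_size, max_length=max_length, num_beams=num_beams, device="cpu")

# one row per (batch, beam) pair, all starting from an arbitrary token 0
input_ids = torch.zeros((batch_size * num_beams, 1), dtype=torch.long)
beam_scores = torch.zeros((batch_size, num_beams))
beam_scores[:, 1:] = -1e9  # only the first beam counts at step 0
beam_scores = beam_scores.view(-1)

while input_ids.shape[-1] < max_length and not scorer.is_done:
    # random scores stand in for a model's next-token logits
    next_scores = torch.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
    next_scores = (next_scores + beam_scores[:, None]).view(batch_size, num_beams * vocab_size)
    next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1)
    next_indices = next_tokens // vocab_size  # which beam each candidate came from
    next_tokens = next_tokens % vocab_size

    out = scorer.process(input_ids, next_scores, next_tokens, next_indices, pad_token_id=0, eos_token_id=1)
    beam_scores = out["next_beam_scores"]
    beam_tokens = out["next_beam_tokens"]
    beam_idx = out["next_beam_indices"]
    input_ids = torch.cat([input_ids[beam_idx, :], beam_tokens.unsqueeze(-1)], dim=-1)

sequences = scorer.finalize(input_ids, beam_scores, beam_tokens, beam_idx, pad_token_id=0, eos_token_id=1)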
283  tests/test_generation_logits_process.py  Normal file
@ -0,0 +1,283 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from .test_modeling_common import ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class LogitsProcessorTest(unittest.TestCase):
|
||||
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
|
||||
return scores
|
||||
|
||||
def test_min_length_dist_processor(self):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
|
||||
# check that min length is applied at length 5
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_before_min_length = min_dist_processor(input_ids, scores)
|
||||
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
|
||||
|
||||
# check that min length is not applied anymore at length 15
|
||||
input_ids = ids_tensor((batch_size, 15), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_before_min_length = min_dist_processor(input_ids, scores)
|
||||
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
||||
|
||||
def test_temperature_dist_warper(self):
|
||||
input_ids = None
|
||||
length = 20
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||
|
||||
# tweak scores to not be uniform anymore
|
||||
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
|
||||
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
|
||||
|
||||
# compute softmax
|
||||
probs = F.softmax(scores, dim=-1)
|
||||
|
||||
temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5)
|
||||
temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3)
|
||||
|
||||
warped_prob_sharp = F.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
|
||||
warped_prob_smooth = F.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
|
||||
|
||||
# uniform distribution stays uniform
|
||||
self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
|
||||
self.assertTrue(torch.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
|
||||
|
||||
# sharp peaks get higher, valleys get lower
|
||||
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
|
||||
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
|
||||
|
||||
# smooth peaks get lower, valleys get higher
|
||||
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
|
||||
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
|
||||
|
||||
def test_repetition_penalty_dist_process(self):
|
||||
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
|
||||
vocab_size = 10
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
|
||||
|
||||
# give values special values
|
||||
scores[0, 0] = -(1 / vocab_size)
|
||||
scores[1, 5] = 4 / vocab_size
|
||||
|
||||
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
|
||||
scores = rep_penalty_proc(input_ids, scores.clone())
|
||||
|
||||
# check that values were correctly changed
|
||||
self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2)
|
||||
self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) / 2)
|
||||
|
||||
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
|
||||
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
|
||||
|
||||
def test_top_k_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create ramp distribution
|
||||
ramp_logits = (
|
||||
torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
|
||||
)
|
||||
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
|
||||
|
||||
top_k_warp = TopKLogitsWarper(3)
|
||||
|
||||
scores = top_k_warp(input_ids, ramp_logits)
|
||||
|
||||
# check that correct tokens are filtered
|
||||
self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
|
||||
self.assertListEqual(torch.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
|
||||
|
||||
# check special cases
|
||||
length = 5
|
||||
|
||||
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
|
||||
top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
|
||||
|
||||
scores = top_k_warp_safety_check(input_ids, logits)
|
||||
# uniform dist is not changed
|
||||
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
|
||||
|
||||
ramp_logits = torch.arange(length, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
|
||||
scores = top_k_warp_safety_check(input_ids, ramp_logits)
|
||||
|
||||
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
|
||||
|
||||
def test_top_p_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||
dist = torch.log(
|
||||
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
|
||||
)
|
||||
|
||||
top_p_warp = TopPLogitsWarper(0.7)
|
||||
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = torch.tensor(
|
||||
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
|
||||
)
|
||||
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||
|
||||
# check edge cases with negative and extreme logits
|
||||
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
|
||||
batch_size, 1
|
||||
) - (vocab_size // 2)
|
||||
|
||||
# make ramp_logits more extreme
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept
|
||||
top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||
filtered_dist = top_p_warp(input_ids, ramp_logits)
|
||||
|
||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
|
||||
|
||||
def test_no_repeat_ngram_dist_processor(self):
|
||||
vocab_size = 3
|
||||
batch_size = 2
|
||||
|
||||
input_ids = torch.tensor([[1, 1, 2, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_repeat_proc_2_gram = NoRepeatNGramLogitsProcessor(2)
|
||||
no_repeat_proc_3_gram = NoRepeatNGramLogitsProcessor(3)
|
||||
|
||||
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
|
||||
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
|
||||
|
||||
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
|
||||
|
||||
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
|
||||
)
|
||||
|
||||
def test_no_bad_words_dist_processor(self):
|
||||
vocab_size = 5
|
||||
batch_size = 2
|
||||
eos_token_id = 4
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
|
||||
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
|
||||
|
||||
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
|
||||
|
||||
# batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
|
||||
# batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
|
||||
# Note that 5th element cannot be forbidden as it is EOS token
|
||||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores).tolist(), [[True, True, False, True, False], [True, True, True, False, False]]
|
||||
)
|
||||
|
||||
# check edge case
|
||||
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[4]], eos_token_id=eos_token_id)
|
||||
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
|
||||
self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
|
||||
|
||||
def test_processor_list(self):
|
||||
batch_size = 4
|
||||
sequence_length = 10
|
||||
vocab_size = 15
|
||||
eos_token_id = 0
|
||||
|
||||
# dummy input_ids and scores
|
||||
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
||||
input_ids_comp = input_ids.clone()
|
||||
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_comp = scores.clone()
|
||||
|
||||
# instantiate all dist processors
|
||||
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
|
||||
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
top_k_warp = TopKLogitsWarper(3)
|
||||
top_p_warp = TopPLogitsWarper(0.8)
|
||||
no_repeat_proc = NoRepeatNGramLogitsProcessor(2)
|
||||
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
|
||||
|
||||
# no processor list
|
||||
scores = min_dist_proc(input_ids, scores)
|
||||
scores = temp_dist_warp(input_ids, scores)
|
||||
scores = rep_penalty_proc(input_ids, scores)
|
||||
scores = top_k_warp(input_ids, scores)
|
||||
scores = top_p_warp(input_ids, scores)
|
||||
scores = no_repeat_proc(input_ids, scores)
|
||||
scores = no_bad_words_dist_proc(input_ids, scores)
|
||||
|
||||
# with processor list
|
||||
processor = LogitsProcessorList(
|
||||
[
|
||||
min_dist_proc,
|
||||
temp_dist_warp,
|
||||
rep_penalty_proc,
|
||||
top_k_warp,
|
||||
top_p_warp,
|
||||
no_repeat_proc,
|
||||
no_bad_words_dist_proc,
|
||||
]
|
||||
)
|
||||
scores_comp = processor(input_ids, scores_comp)
|
||||
|
||||
# scores should be equal
|
||||
self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3))
|
||||
|
||||
# input_ids should never be changed
|
||||
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
|
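Editor's note: every processor and warper in the test above is called as processor(input_ids, scores) and returns the modified scores, which is what lets LogitsProcessorList chain them. A hedged sketch of a custom processor following that calling convention (the class name and the penalty rule are invented for illustration):

import torch

from transformers.generation_logits_process import LogitsProcessor, LogitsProcessorList, TopKLogitsWarper


class PenalizeTokenLogitsProcessor(LogitsProcessor):
    """Illustrative processor: force one chosen token id to -inf."""

    def __init__(self, token_id: int):
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.token_id] = -float("inf")
        return scores


processors = LogitsProcessorList([PenalizeTokenLogitsProcessor(token_id=3), TopKLogitsWarper(top_k=5)])
filtered = processors(torch.zeros((2, 4), dtype=torch.long), torch.randn(2, 10))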
510  tests/test_generation_utils.py  Normal file
@ -0,0 +1,510 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import top_k_top_p_filtering
|
||||
from transformers.generation_beam_search import BeamSearchScorer
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
class GenerationTesterMixin:
|
||||
model_tester = None
|
||||
all_generative_model_classes = ()
|
||||
|
||||
def _get_input_ids_and_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# cut to half length & take max batch_size 2
|
||||
max_batch_size = 2
|
||||
sequence_length = input_ids.shape[-1] // 2
|
||||
input_ids = input_ids[:max_batch_size, :sequence_length]
|
||||
attention_mask = attention_mask[:max_batch_size, :sequence_length]
|
||||
|
||||
# generate max 3 tokens
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask, max_length
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_kwargs(input_length, eos_token_id):
|
||||
process_kwargs = {
|
||||
"min_length": input_length + 1,
|
||||
"bad_words_ids": [[1, 0]],
|
||||
"no_repeat_ngram_size": 2,
|
||||
"repetition_penalty": 1.2,
|
||||
}
|
||||
logits_processor = LogitsProcessorList(
|
||||
(
|
||||
[
|
||||
MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
|
||||
]
|
||||
if eos_token_id is not None
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
|
||||
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
|
||||
RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
|
||||
]
|
||||
)
|
||||
return process_kwargs, logits_processor
|
||||
|
||||
@staticmethod
|
||||
def _get_warper_and_kwargs(num_beams):
|
||||
warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
|
||||
logits_warper = LogitsProcessorList(
|
||||
[
|
||||
TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
|
||||
TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
|
||||
TemperatureLogitsWarper(warp_kwargs["temperature"]),
|
||||
]
|
||||
)
|
||||
return warp_kwargs, logits_warper
|
||||
|
||||
@staticmethod
|
||||
def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
"num_beams": 2,
|
||||
"num_return_sequences": num_return_sequences,
|
||||
}
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
do_early_stopping=beam_kwargs["early_stopping"],
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1):
|
||||
encoder = model.get_encoder()
|
||||
encoder_outputs = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||
num_interleave, dim=0
|
||||
)
|
||||
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
def test_greedy_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `greedy_search()` are equal
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
max_length = 4
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output_ids_greedy = model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_greedy.tolist())
|
||||
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `sample()` are equal
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
else:
|
||||
attention_mask_clone = attention_mask
|
||||
input_ids_clone = input_ids
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist())
|
||||
|
||||
# check `generate()` and `sample()` yield equal results for `num_return_sequences`
|
||||
num_return_sequences = 3
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
attention_mask=attention_mask,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=num_return_sequences
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist())
|
||||
|
||||
def test_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `beam_search()` are equal
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
|
||||
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_beam_search = model.beam_search(
|
||||
input_ids_clone,
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
|
||||
|
||||
# check `generate()` and `beam_search()` are equal for `num_return_sequences`
|
||||
num_return_sequences = 2
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
|
||||
)
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_beam_search = model.beam_search(
|
||||
input_ids_clone,
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
|
||||
|
||||
def test_beam_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `beam_search()` are equal
|
||||
# change `num_return_sequences = 2` but not for `beam_scorer`
|
||||
num_return_sequences = 2
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0] * num_return_sequences, max_length
|
||||
)
|
||||
beam_kwargs["num_return_sequences"] = num_return_sequences
|
||||
torch.manual_seed(0)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=True,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_warper_kwargs,
|
||||
)
|
||||
# beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams * num_return_sequences
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
else:
|
||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
|
||||
|
||||
torch.manual_seed(0)
|
||||
with torch.no_grad():
|
||||
output_ids_beam_sample = model.beam_sample(
|
||||
input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
logits_warper=logits_warper,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_sample.tolist())
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# if no bos token id => cannot generate from None
|
||||
if config.bos_token_id is None:
|
||||
return
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
# tests whether the top_k_top_p function behaves as expected
|
||||
def test_top_k_top_p_filtering(self):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[
|
||||
8.2220991, # 3rd highest value; idx. 0
|
||||
-0.5620044,
|
||||
5.23229752,
|
||||
4.0386393,
|
||||
-6.8798378,
|
||||
-0.54785802,
|
||||
-3.2012153,
|
||||
2.92777176,
|
||||
1.88171953,
|
||||
7.35341276,
|
||||
8.43207833, # 2nd highest value; idx. 10
|
||||
-9.85711836,
|
||||
-5.96209236,
|
||||
-1.13039161,
|
||||
-7.1115294,
|
||||
-0.8369633,
|
||||
-5.3186408,
|
||||
7.06427407,
|
||||
0.81369344,
|
||||
-0.82023817,
|
||||
-5.9179796,
|
||||
0.58813443,
|
||||
-6.99778438,
|
||||
4.71551189,
|
||||
-0.18771637,
|
||||
7.44020759, # 4th highest value; idx. 25
|
||||
9.38450987, # 1st highest value; idx. 26
|
||||
2.12662941,
|
||||
-9.32562038,
|
||||
2.35652522,
|
||||
], # cumulative prob of 4 highest values <= 0.6
|
||||
[
|
||||
0.58425518,
|
||||
4.53139238,
|
||||
-5.57510464,
|
||||
-6.28030699,
|
||||
-7.19529503,
|
||||
-4.02122551,
|
||||
1.39337037,
|
||||
-6.06707057,
|
||||
1.59480517,
|
||||
-9.643119,
|
||||
0.03907799,
|
||||
0.67231762,
|
||||
-8.88206726,
|
||||
6.27115922, # 4th highest value; idx. 13
|
||||
2.28520723,
|
||||
4.82767506,
|
||||
4.30421368,
|
||||
8.8275313, # 2nd highest value; idx. 17
|
||||
5.44029958,
|
||||
-4.4735794,
|
||||
7.38579536, # 3rd highest value; idx. 20
|
||||
-2.91051663,
|
||||
2.61946077,
|
||||
-2.5674762,
|
||||
-9.48959302,
|
||||
-4.02922645,
|
||||
-1.35416918,
|
||||
9.67702323, # 1st highest value; idx. 27
|
||||
-5.89478553,
|
||||
1.85370467,
|
||||
], # cumulative prob of 4 highest values <= 0.6
|
||||
],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
non_inf_expected_idx = torch.tensor(
|
||||
[[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
) # expected non filtered idx as noted above
|
||||
|
||||
non_inf_expected_output = torch.tensor(
|
||||
[
|
||||
8.2221,
|
||||
8.4321,
|
||||
7.4402,
|
||||
9.3845,
|
||||
6.2712,
|
||||
8.8275,
|
||||
7.3858,
|
||||
9.6770,
|
||||
], # expected non filtered values as noted above
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||
non_inf_output = output[output != -float("inf")].to(device=torch_device)
|
||||
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
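Editor's note: GenerationTesterMixin checks that the high-level generate() call and the matching low-level method (greedy_search, sample, beam_search, beam_sample) produce identical ids. A condensed, hedged example of the same equivalence for greedy decoding with GPT-2 (model name, prompt, and min_length are chosen for illustration; equality should hold in this setting because generate() builds the same processor internally):

import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.generation_logits_process import LogitsProcessorList, MinLengthLogitsProcessor

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tokenizer("The president", return_tensors="pt").input_ids
max_length = input_ids.shape[-1] + 5

# high-level call
generated = model.generate(input_ids, do_sample=False, num_beams=1, max_length=max_length, min_length=3)

# equivalent low-level call with an explicit processor list
processors = LogitsProcessorList([MinLengthLogitsProcessor(3, eos_token_id=model.config.eos_token_id)])
with torch.no_grad():
    greedy = model.greedy_search(input_ids, max_length=max_length, logits_processor=processors)

assert generated.tolist() == greedy.tolist()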
@ -23,6 +23,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor


@ -128,7 +129,7 @@ def prepare_bart_inputs_dict(


@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (
        (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
        if is_torch_available()
@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@ -357,11 +358,12 @@ class BertModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
BertModel,
|
||||
BertLMHeadModel,
|
||||
BertForMaskedLM,
|
||||
BertForMultipleChoice,
|
||||
BertForNextSentencePrediction,
|
||||
@ -373,6 +375,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BertModelTester(self)
|
||||
|
@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@ -183,9 +184,10 @@ class BertGenerationEncoderTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertGenerationEncoderTest(ModelTesterMixin, unittest.TestCase):
|
||||
class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (BertGenerationEncoder, BertGenerationDecoder) if is_torch_available() else ()
|
||||
all_generative_model_classes = (BertGenerationDecoder,) if is_torch_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BertGenerationEncoderTester(self)
|
||||
|
@ -147,6 +147,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
|
||||
|
||||
src_text = ["Sam"]
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
|
||||
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
|
||||
|
||||
@ -156,6 +157,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
|
||||
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
|
||||
|
||||
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
|
||||
|
||||
generated_ids = model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
|
||||
@ -187,6 +189,9 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
# model does not have "token_type_ids"
|
||||
model_inputs.pop("token_type_ids")
|
||||
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
|
||||
generated_ids = self.model.generate(**model_inputs)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
@ -198,10 +203,11 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_90_generation_from_short_input(self):
|
||||
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
|
||||
generated_utterances = self.model.generate(**model_inputs)
|
||||
# generated_txt = self.tokenizer.decode(generated_utterances[0])
|
||||
|
||||
# assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__"
|
||||
# model does not have "token_type_ids"
|
||||
model_inputs.pop("token_type_ids")
|
||||
generated_utterances = self.model.generate(**model_inputs)
|
||||
|
||||
clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW)
|
||||
assert clean_txt in (
|
||||
"have you ever been to a sam club? it's a great club in the south.",
|
||||
|
@ -44,7 +44,6 @@ if is_torch_available():
|
||||
BertModel,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
|
||||
@ -882,126 +881,6 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
|
||||
# make sure that input_ids is at most of size 15
|
||||
input_ids = input_ids[..., :15]
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined, model needs input_ids
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
||||
else:
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_ids, do_sample=False, num_beams=1, num_return_sequences=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [
|
||||
self._generate_random_bad_tokens(1, model.config),
|
||||
self._generate_random_bad_tokens(2, model.config),
|
||||
]
|
||||
output_tokens = model.generate(
|
||||
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
||||
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# make sure that input_ids is at most of size 15
|
||||
input_ids = input_ids[..., :15]
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
|
||||
else:
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating more sequences than we have beams is not possible
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [
|
||||
self._generate_random_bad_tokens(1, model.config),
|
||||
self._generate_random_bad_tokens(2, model.config),
|
||||
]
|
||||
output_tokens = model.generate(
|
||||
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
||||
|
||||
def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
|
||||
# special tokens cannot be bad tokens
|
||||
special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None]
|
||||
# create random bad tokens that are not special tokens
|
||||
bad_tokens = []
|
||||
while len(bad_tokens) < num_bad_tokens:
|
||||
token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0]
|
||||
if token not in special_tokens:
|
||||
bad_tokens.append(token)
|
||||
return bad_tokens
|
||||
|
||||
def _check_generated_ids(self, output_ids):
|
||||
for token_id in output_ids[0].tolist():
|
||||
self.assertGreaterEqual(token_id, 0)
|
||||
self.assertLess(token_id, self.model_tester.vocab_size)
|
||||
|
||||
def _check_match_tokens(self, generated_ids, bad_words_ids):
|
||||
# for all bad word tokens
|
||||
for bad_word_ids in bad_words_ids:
|
||||
# for all slices in batch
|
||||
for generated_ids_slice in generated_ids:
|
||||
# for all word idx
|
||||
for i in range(len(bad_word_ids), len(generated_ids_slice)):
|
||||
# if tokens match
|
||||
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
|
||||
return True
|
||||
return False
|
||||
|
||||
@require_torch_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@ -1094,110 +973,3 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
# tests whether the top_k_top_p function behaves as expected
|
||||
def test_top_k_top_p_filtering(self):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[
|
||||
8.2220991, # 3rd highest value; idx. 0
|
||||
-0.5620044,
|
||||
5.23229752,
|
||||
4.0386393,
|
||||
-6.8798378,
|
||||
-0.54785802,
|
||||
-3.2012153,
|
||||
2.92777176,
|
||||
1.88171953,
|
||||
7.35341276, # 5th highest value; idx. 9
|
||||
8.43207833, # 2nd highest value; idx. 10
|
||||
-9.85711836,
|
||||
-5.96209236,
|
||||
-1.13039161,
|
||||
-7.1115294,
|
||||
-0.8369633,
|
||||
-5.3186408,
|
||||
7.06427407,
|
||||
0.81369344,
|
||||
-0.82023817,
|
||||
-5.9179796,
|
||||
0.58813443,
|
||||
-6.99778438,
|
||||
4.71551189,
|
||||
-0.18771637,
|
||||
7.44020759, # 4th highest value; idx. 25
|
||||
9.38450987, # 1st highest value; idx. 26
|
||||
2.12662941,
|
||||
-9.32562038,
|
||||
2.35652522,
|
||||
], # cumulative prob of 5 highest values <= 0.6
|
||||
[
|
||||
0.58425518,
|
||||
4.53139238,
|
||||
-5.57510464,
|
||||
-6.28030699,
|
||||
-7.19529503,
|
||||
-4.02122551,
|
||||
1.39337037,
|
||||
-6.06707057,
|
||||
1.59480517,
|
||||
-9.643119,
|
||||
0.03907799,
|
||||
0.67231762,
|
||||
-8.88206726,
|
||||
6.27115922, # 4th highest value; idx. 13
|
||||
2.28520723,
|
||||
4.82767506,
|
||||
4.30421368,
|
||||
8.8275313, # 2nd highest value; idx. 17
|
||||
5.44029958, # 5th highest value; idx. 18
|
||||
-4.4735794,
|
||||
7.38579536, # 3rd highest value; idx. 20
|
||||
-2.91051663,
|
||||
2.61946077,
|
||||
-2.5674762,
|
||||
-9.48959302,
|
||||
-4.02922645,
|
||||
-1.35416918,
|
||||
9.67702323, # 1st highest value; idx. 27
|
||||
-5.89478553,
|
||||
1.85370467,
|
||||
], # cumulative prob of 5 highest values <= 0.6
|
||||
],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
non_inf_expected_idx = torch.tensor(
|
||||
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
) # expected non filtered idx as noted above
|
||||
|
||||
non_inf_expected_output = torch.tensor(
|
||||
[
|
||||
8.2221,
|
||||
7.3534,
|
||||
8.4321,
|
||||
7.4402,
|
||||
9.3845,
|
||||
6.2712,
|
||||
8.8275,
|
||||
5.4403,
|
||||
7.3858,
|
||||
9.6770,
|
||||
], # expected non filtered values as noted above
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||
non_inf_output = output[output != -float("inf")].to(device=torch_device)
|
||||
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||
|
@ -19,6 +19,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@ -151,7 +152,7 @@ class CTRLModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
|
||||
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
|
||||
|
@ -24,6 +24,7 @@ from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@ -120,7 +121,7 @@ def prepare_fsmt_inputs_dict(
|
||||
|
||||
|
||||
@require_torch
|
||||
class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
|
@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@ -377,7 +378,7 @@ class GPT2ModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification)
|
||||
@ -510,32 +511,17 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)

    @slow
-    def test_lm_generate_distilgpt2(self):
-        model = GPT2LMHeadModel.from_pretrained("distilgpt2")
+    def test_gpt2_sample(self):
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
        model.to(torch_device)
-        input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device)  # The president
-        expected_output_ids = [
-            464,
-            1893,
-            286,
-            262,
-            1578,
-            1829,
-            11,
-            290,
-            262,
-            1893,
-            286,
-            262,
-            1578,
-            7526,
-            11,
-            423,
-            587,
-            287,
-            262,
-            2635,
-        ]  # The president of the United States, and the president of the United Kingdom, have been in the White
-
-        output_ids = model.generate(input_ids, do_sample=False)
-        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
+
+        torch.manual_seed(0)
+        input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device)
+        output_ids = model.generate(input_ids, do_sample=True)
+        output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        EXPECTED_OUTPUT_STR = (
+            "Today is a nice day and if you don't know anything about the state of play during your holiday"
+        )
+        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
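The removed test pinned a hard-coded greedy continuation from distilgpt2, while its replacement seeds the RNG and pins a sampled continuation from gpt2. A minimal sketch of the difference between the two decoding calls, using the same public API:

    import torch

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids

    # Greedy decoding is deterministic: the highest-probability token is taken at every step.
    greedy_ids = model.generate(input_ids, do_sample=False)

    # Sampling is stochastic, so a test has to fix the seed to make the output reproducible.
    torch.manual_seed(0)
    sampled_ids = model.generate(input_ids, do_sample=True)

    print(tokenizer.decode(greedy_ids[0], skip_special_tokens=True))
    print(tokenizer.decode(sampled_ids[0], skip_special_tokens=True))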
@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
+from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor

@ -170,7 +171,7 @@ class OpenAIGPTModelTester:


@require_torch
-class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
+class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

    all_model_classes = (
        (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification)
@ -22,6 +22,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
+from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@ -853,7 +854,7 @@ class ProphetNetStandaloneEncoderModelTester:


@require_torch
-class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
+class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
    test_pruning = False
@ -917,7 +918,7 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):


@require_torch
-class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, unittest.TestCase):
+class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (ProphetNetDecoder, ProphetNetForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (ProphetNetForCausalLM,) if is_torch_available() else ()
    test_pruning = False
@ -26,6 +26,7 @@ from transformers.testing_utils import (
)

from .test_configuration_common import ConfigTester
+from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask

@ -196,11 +197,14 @@ class ReformerModelTester:
        )

    def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
+        if not self.is_training:
+            return
+
        config.is_decoder = False
        config.lsh_num_chunks_after = 1
        model = ReformerForMaskedLM(config=config)
        model.to(torch_device)
-        model.eval()
+        model.train()
        loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"]
        loss.backward()
@ -569,7 +573,7 @@ class ReformerTesterMixin:


@require_torch
-class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
+class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
        if is_torch_available()
@ -629,7 +633,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest


@require_torch
-class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
+class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (
        (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
        if is_torch_available()
@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
+from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask

@ -267,7 +268,7 @@ class RobertaModelTester:


@require_torch
-class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
+class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
@ -282,6 +283,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
        if is_torch_available()
        else ()
    )
+    all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()

    def setUp(self):
        self.model_tester = RobertaModelTester(self)
@ -23,6 +23,7 @@ from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
+from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor

@ -466,7 +467,7 @@ class T5ModelTester:


@require_torch
-class T5ModelTest(ModelTesterMixin, unittest.TestCase):
+class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

    all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
@ -592,6 +593,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
            do_sample=False,
+            early_stopping=True,
        )

        decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        self.assertListEqual(
            expected_summaries,
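The added early_stopping keyword is one of the beam-search controls exposed by generate(). For reference, a hedged sketch of the kind of summarization call this integration test wraps; the checkpoint, beam count, and length limit below are illustrative rather than copied from the test:

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tok = T5Tokenizer.from_pretrained("t5-small")  # illustrative checkpoint
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    article = "summarize: " + "The quick brown fox jumps over the lazy dog. " * 10
    inputs = tok(article, return_tensors="pt")

    # With early_stopping=True each beam hypothesis stops as soon as it emits EOS,
    # instead of being extended all the way to max_length.
    summary_ids = model.generate(
        inputs.input_ids,
        num_beams=4,
        max_length=60,
        do_sample=False,
        early_stopping=True,
    )
    print(tok.batch_decode(summary_ids, skip_special_tokens=True))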
@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device

from .test_configuration_common import ConfigTester
+from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor

@ -156,7 +157,7 @@ class TransfoXLModelTester:


@require_torch
-class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
+class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
    all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
    test_pruning = False
@ -20,6 +20,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
+from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask

@ -331,7 +332,7 @@ class XLMModelTester:


@require_torch
-class XLMModelTest(ModelTesterMixin, unittest.TestCase):
+class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
@ -21,6 +21,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
+from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask

@ -479,7 +480,7 @@ class XLNetModelTester:


@require_torch
-class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
+class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            XLNetModel,