From a1bbcf3f6c20e15fe799a8659d6b7bd36fdf11ed Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 3 Nov 2020 16:04:22 +0100 Subject: [PATCH] Refactoring the generate() function (#6949) * first draft * show design proposition for new generate method * up * make better readable * make first version * gpt2 tests pass * make beam search for gpt2 work * add first encoder-decoder code * delete typo * make t5 work * save indermediate * make bart work with beam search * finish beam search bart / t5 * add default kwargs * make more tests pass * fix no bad words sampler * some fixes and tests for all distribution processors * fix test * fix rag slow tests * merge to master * add nograd to generate * make all slow tests pass * speed up generate * fix edge case bug * small fix * correct typo * add type hints and docstrings * fix typos in tests * add beam search tests * add tests for beam scorer * fix test rag * finish beam search tests * move generation tests in seperate file * fix generation tests * more tests * add aggressive generation tests * fix tests * add gpt2 sample test * add more docstring * add more docs * finish doc strings * apply some more of sylvains and sams comments * fix some typos * make fix copies * apply lysandres and sylvains comments * final corrections on examples * small fix for reformer --- docs/source/index.rst | 1 + docs/source/internal/generation_utils.rst | 50 + docs/source/main_classes/model.rst | 2 +- src/transformers/__init__.py | 13 + src/transformers/generation_beam_search.py | 357 ++++ src/transformers/generation_logits_process.py | 374 ++++ src/transformers/generation_utils.py | 1811 +++++++++-------- src/transformers/modeling_bart.py | 2 +- src/transformers/modeling_ctrl.py | 4 +- src/transformers/modeling_encoder_decoder.py | 2 +- src/transformers/modeling_fsmt.py | 2 +- src/transformers/modeling_prophetnet.py | 2 +- src/transformers/modeling_rag.py | 136 +- src/transformers/modeling_reformer.py | 16 +- src/transformers/modeling_t5.py | 4 +- src/transformers/modeling_transfo_xl.py | 2 +- src/transformers/modeling_xlnet.py | 4 +- src/transformers/testing_utils.py | 4 +- src/transformers/utils/dummy_pt_objects.py | 60 + tests/test_generation_beam_search.py | 239 +++ tests/test_generation_logits_process.py | 283 +++ tests/test_generation_utils.py | 510 +++++ tests/test_modeling_bart.py | 3 +- tests/test_modeling_bert.py | 5 +- tests/test_modeling_bert_generation.py | 4 +- tests/test_modeling_blenderbot.py | 12 +- tests/test_modeling_common.py | 228 --- tests/test_modeling_ctrl.py | 3 +- tests/test_modeling_fsmt.py | 3 +- tests/test_modeling_gpt2.py | 42 +- tests/test_modeling_openai.py | 3 +- tests/test_modeling_prophetnet.py | 5 +- tests/test_modeling_reformer.py | 10 +- tests/test_modeling_roberta.py | 4 +- tests/test_modeling_t5.py | 4 +- tests/test_modeling_transfo_xl.py | 3 +- tests/test_modeling_xlm.py | 3 +- tests/test_modeling_xlnet.py | 3 +- 38 files changed, 3022 insertions(+), 1191 deletions(-) create mode 100644 docs/source/internal/generation_utils.rst create mode 100644 src/transformers/generation_beam_search.py create mode 100644 src/transformers/generation_logits_process.py create mode 100644 tests/test_generation_beam_search.py create mode 100644 tests/test_generation_logits_process.py create mode 100644 tests/test_generation_utils.py diff --git a/docs/source/index.rst b/docs/source/index.rst index f1f8ab66ad6..737f562f663 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -272,3 +272,4 @@ conversion utilities for the 
following models: internal/pipelines_utils internal/tokenization_utils internal/trainer_utils + internal/generation_utils diff --git a/docs/source/internal/generation_utils.rst b/docs/source/internal/generation_utils.rst new file mode 100644 index 00000000000..9496827a5e1 --- /dev/null +++ b/docs/source/internal/generation_utils.rst @@ -0,0 +1,50 @@ +Utilities for Generation +----------------------------------------------------------------------------------------------------------------------- + +This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`, +:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`, +:meth:`~transformers.PretrainedModel.beam_search`, and :meth:`~transformers.PretrainedModel.beam_sample`. + +Most of those are only useful if you are studying the code of the generate methods in the library. + +LogitsProcessor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A :class:`~transformers.LogitsProcessor` can be used to modify the prediction scores of a language model head for +generation. + +.. autoclass:: transformers.LogitsProcessor + :members: __call__ + +.. autoclass:: transformers.LogitsProcessorList + :members: __call__ + +.. autoclass:: transformers.MinLengthLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.TemperatureLogitsWarper + :members: __call__ + +.. autoclass:: transformers.RepetitionPenaltyLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.TopPLogitsWarper + :members: __call__ + +.. autoclass:: transformers.TopKLogitsWarper + :members: __call__ + +.. autoclass:: transformers.NoRepeatNGramLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.NoBadWordsLogitsProcessor + :members: __call__ + +BeamSearch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BeamScorer + :members: process, finalize + +.. autoclass:: transformers.BeamSearchScorer + :members: process, finalize diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst index 9fa9a4899bc..668b10176f7 100644 --- a/docs/source/main_classes/model.rst +++ b/docs/source/main_classes/model.rst @@ -45,7 +45,7 @@ TFModelUtilsMixin :members: -Generative models +Generation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
autoclass:: transformers.generation_utils.GenerationMixin diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6662b2011e2..7285a41eb23 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -299,6 +299,19 @@ if is_torch_available(): TextDataset, TextDatasetForNextSentencePrediction, ) + from .generation_beam_search import BeamScorer, BeamSearchScorer + from .generation_logits_process import ( + LogitsProcessor, + LogitsProcessorList, + LogitsWarper, + MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) from .generation_utils import top_k_top_p_filtering from .modeling_albert import ( ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/generation_beam_search.py b/src/transformers/generation_beam_search.py new file mode 100644 index 00000000000..135227895d8 --- /dev/null +++ b/src/transformers/generation_beam_search.py @@ -0,0 +1,357 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections import UserDict +from typing import Optional, Tuple + +import torch + +from .file_utils import add_start_docstrings + + +PROCESS_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses. + next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses. + next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): + Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + + Return: + :obj:`UserDict`: A dictionary composed of the fields as defined above: + + - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated + scores of all non-finished beams. + - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens + to be added to the non-finished beam_hypotheses. 
+ - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices + indicating to which beam the next tokens shall be added. + +""" + +FINALIZE_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The final scores of all non-finished beams. + final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The last tokens to be added to the non-finished beam_hypotheses. + final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): + The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. + +""" + + +class BeamScorer(ABC): + """ + Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and + :meth:`~transformers.PretrainedModel.beam_sample`. + """ + + @abstractmethod + @add_start_docstrings(PROCESS_INPUTS_DOCSTRING) + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + **kwargs + ) -> Tuple[torch.Tensor]: + raise NotImplementedError("This is an abstract method.") + + @abstractmethod + @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING) + def finalize( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + **kwargs + ) -> torch.LongTensor: + raise NotImplementedError("This is an abstract method.") + + +class BeamSearchScorer(BeamScorer): + r""" + :class:`transformers.BeamScorer` implementing standard beam search decoding. + + Adapted in part from `Facebook's XLM beam search code + `__. + + Args: + batch_size (:obj:`int`): + Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. + max_length (:obj:`int`): + The maximum length of the sequence to be generated. + num_beams (:obj:`int`): + Number of beams for beam search. + device (:obj:`torch.device`): + Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of + :obj:`BeamSearchScorer` will be allocated. + length_penalty (:obj:`float`, `optional`, defaults to 1.0): + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. + do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. 
+ num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): + The number of beam hypotheses that shall be returned upon calling + :meth:`~transformer.BeamSearchScorer.finalize`. + """ + + def __init__( + self, + batch_size: int, + max_length: int, + num_beams: int, + device: torch.device, + length_penalty: Optional[float] = 1.0, + do_early_stopping: Optional[bool] = False, + num_beam_hyps_to_keep: Optional[int] = 1, + ): + self.max_length = max_length + self.num_beams = num_beams + self.device = device + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + + self._is_init = False + self._beam_hyps = [ + BeamHypotheses( + num_beams=self.num_beams, + max_length=self.max_length, + length_penalty=self.length_penalty, + early_stopping=self.do_early_stopping, + ) + for _ in range(batch_size) + ] + self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) + + if not isinstance(num_beams, int) or num_beams <= 1: + raise ValueError( + f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead." + ) + + @property + def is_done(self) -> bool: + return self._done.all() + + def process( + self, + input_ids: torch.LongTensor, + next_scores: torch.FloatTensor, + next_tokens: torch.LongTensor, + next_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> Tuple[torch.Tensor]: + cur_len = input_ids.shape[-1] + batch_size = len(self._beam_hyps) + assert batch_size == (input_ids.shape[0] // self.num_beams) + + device = input_ids.device + next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device) + next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device) + next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device) + + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + assert ( + len(beam_hyp) >= self.num_beams + ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams) + assert ( + eos_token_id is not None and pad_token_id is not None + ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" + # pad the batch + next_beam_scores[batch_idx, :] = 0 + next_beam_tokens[batch_idx, :] = pad_token_id + next_beam_indices[batch_idx, :] = 0 + continue + + # next tokens for this sentence + beam_idx = 0 + for beam_token_rank, (next_token, next_score, next_index) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) + ): + batch_beam_idx = batch_idx * self.num_beams + next_index + # add to generated hypotheses if end of sentence + if (eos_token_id is not None) and (next_token.item() == eos_token_id): + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams + if is_beam_token_worse_than_top_num_beams: + continue + beam_hyp.add( + input_ids[batch_beam_idx].clone(), + next_score.item(), + ) + else: + # add next predicted token since it is not eos_token + next_beam_scores[batch_idx, beam_idx] = next_score + next_beam_tokens[batch_idx, beam_idx] = next_token + next_beam_indices[batch_idx, beam_idx] = batch_beam_idx + beam_idx += 1 + + # once the beam for next step is full, don't add more tokens to 
it. + if beam_idx == self.num_beams: + break + + if beam_idx < self.num_beams: + raise ValueError( + f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." + ) + + # Check if we are done so that we can save a pad step if all(done) + self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( + next_scores[batch_idx].max().item(), cur_len + ) + + return UserDict( + { + "next_beam_scores": next_beam_scores.view(-1), + "next_beam_tokens": next_beam_tokens.view(-1), + "next_beam_indices": next_beam_indices.view(-1), + } + ) + + def finalize( + self, + input_ids: torch.LongTensor, + final_beam_scores: torch.FloatTensor, + final_beam_tokens: torch.LongTensor, + final_beam_indices: torch.LongTensor, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> torch.LongTensor: + batch_size = len(self._beam_hyps) + + # finalize all open beam hypotheses and add to generated hypotheses + for batch_idx, beam_hyp in enumerate(self._beam_hyps): + if self._done[batch_idx]: + continue + + # need to add best num_beams hypotheses to generated hyps + for beam_id in range(self.num_beams): + batch_beam_idx = batch_idx * self.num_beams + beam_id + final_score = final_beam_scores[batch_beam_idx].item() + final_tokens = input_ids[batch_beam_idx] + beam_hyp.add(final_tokens, final_score) + + # select the best hypotheses + sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) + best = [] + + # retrieve best hypotheses + for i, beam_hyp in enumerate(self._beam_hyps): + sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) + for j in range(self.num_beam_hyps_to_keep): + best_hyp = sorted_hyps.pop()[1] + sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) + best.append(best_hyp) + + # prepare for adding eos + sent_max_len = min(sent_lengths.max().item() + 1, self.max_length) + decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len) + # shorter batches are padded if needed + if sent_lengths.min().item() != sent_lengths.max().item(): + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) + + # fill with hypotheses and eos_token_id if the latter fits in + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < self.max_length: + decoded[i, sent_lengths[i]] = eos_token_id + return decoded + + +class BeamHypotheses: + def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool): + """ + Initialize n-best list of hypotheses. + """ + self.max_length = max_length - 1 # ignoring bos_token + self.length_penalty = length_penalty + self.early_stopping = early_stopping + self.num_beams = num_beams + self.beams = [] + self.worst_score = 1e9 + + def __len__(self): + """ + Number of hypotheses in the list. + """ + return len(self.beams) + + def add(self, hyp: torch.LongTensor, sum_logprobs: float): + """ + Add a new hypothesis to the list. 
+ """ + score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) + if len(self) < self.num_beams or score > self.worst_score: + self.beams.append((score, hyp)) + if len(self) > self.num_beams: + sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) + del self.beams[sorted_next_scores[0][1]] + self.worst_score = sorted_next_scores[1][0] + else: + self.worst_score = min(score, self.worst_score) + + def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: + """ + If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst + one in the heap, then we are done with this sentence. + """ + + if len(self) < self.num_beams: + return False + elif self.early_stopping: + return True + else: + cur_score = best_sum_logprobs / cur_len ** self.length_penalty + ret = self.worst_score >= cur_score + return ret diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py new file mode 100644 index 00000000000..4b64a13fab5 --- /dev/null +++ b/src/transformers/generation_logits_process.py @@ -0,0 +1,374 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from typing import Iterable, List + +import numpy as np +import torch +from torch.nn import functional as F + +from .file_utils import add_start_docstrings + + +LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.BertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`): + Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax + or scores for each vocabulary token after SoftMax. + + Return: + :obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. + +""" + + +class LogitsProcessor(ABC): + """Abstract base class for all logit processors that can be applied during generation.""" + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + """Torch method for processing logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." 
+ ) + + +class LogitsWarper(ABC): + """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + """Torch method for warping logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class LogitsProcessorList(list): + """ + This class can be used to create a list of :class:`~transformers.LogitsProcessor` or + :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from + list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or + :class:`~transformers.LogitsProcessor` to the inputs. + """ + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + for processor in self: + scores = processor(input_ids, scores) + return scores + + +class MinLengthLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (:obj:`int`): + The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. + """ + + def __init__(self, min_length: int, eos_token_id: int): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") + + if not isinstance(eos_token_id, int) or eos_token_id < 0: + raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}") + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + cur_len = input_ids.shape[-1] + if cur_len < self.min_length: + scores[:, self.eos_token_id] = -float("inf") + return scores + + +class TemperatureLogitsWarper(LogitsWarper): + r""" + :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution). + + Args: + temperature (:obj:`float`): + The value used to module the logits distribution. + """ + + def __init__(self, temperature: float): + if not isinstance(temperature, float) or not (temperature > 0): + raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") + + self.temperature = temperature + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + scores = scores / self.temperature + return scores + + +class RepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences. + + Args: + repetition_penalty (:obj:`float`): + The parameter for repetition penalty. 1.0 means no penalty. See `this paper + `__ for more details. 
+ """ + + def __init__(self, penalty: float): + if not isinstance(penalty, float) or not (penalty > 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = penalty + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + for i in range(scores.shape[0]): + for previous_token in set(input_ids[i].tolist()): + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + if scores[i, previous_token] < 0: + scores[i, previous_token] *= self.penalty + else: + scores[i, previous_token] /= self.penalty + return scores + + +class TopPLogitsWarper(LogitsWarper): + """ + :class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= + prob_cut_off. + + Args: + top_p (:obj:`float`): + If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are + kept for generation. + filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): + raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") + + self.top_p = top_p + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + sorted_logits, sorted_indices = torch.sort(scores, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > self.top_p + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + scores[indices_to_remove] = self.filter_value + return scores + + +class TopKLogitsWarper(LogitsWarper): + r""" + :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements. + + Args: + top_k (:obj:`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") + + self.top_k = top_k + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] + scores[indices_to_remove] = self.filter_value + return scores + + +class NoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq + `__. + + Args: + ngram_size (:obj:`int`): + All ngrams of size :obj:`ngram_size` can only occur once. + """ + + def __init__(self, ngram_size: int): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + num_batch_hypotheses = scores.shape[0] + cur_len = input_ids.shape[-1] + banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len) + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + def _calc_banned_ngram_tokens( + self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int + ) -> List[Iterable[int]]: + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < self.ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + + def _get_generated_ngrams(hypo_idx): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - self.ngram_size + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) + return generated_ngrams[hypo_idx].get(ngram_idx, []) + + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] + return banned_tokens + + +class NoBadWordsLogitsProcessor(LogitsProcessor): + """ + :class:`transformers.LogitsProcessor` that enforces that specified sequences will never be sampled. + + Args: + bad_words_ids (:obj:`List[List[int]]`): + List of list of token ids that are not allowed to be generated. In order to get the tokens of the words + that should not appear in the generated text, use :obj:`tokenizer(bad_word, + add_prefix_space=True).input_ids`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. 
+ """ + + def __init__(self, bad_words_ids: Iterable[Iterable[int]], eos_token_id: int): + + if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: + raise ValueError(f"`bad_words_ids` has to be a non-emtpy list, but is {bad_words_ids}.") + if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): + raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") + if any( + any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids) + for bad_word_ids in bad_words_ids + ): + raise ValueError( + f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." + ) + + self.bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) + + for banned_token_seq in self.bad_words_ids: + assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format( + bad_words_ids + ) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + banned_tokens = self._calc_banned_bad_words_ids(input_ids) + scores = self._set_scores_to_inf_for_banned_tokens(scores, banned_tokens) + + return scores + + def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: + if len(tokens) == 0: + # if bad word tokens is just one token always ban it + return True + elif len(tokens) > len(prev_tokens): + # if bad word tokens are longer then prev input_ids they can't be equal + return False + elif prev_tokens[-len(tokens) :].tolist() == tokens: + # if tokens match + return True + else: + return False + + def _calc_banned_bad_words_ids(self, prev_input_ids: Iterable[int]) -> Iterable[int]: + banned_tokens = [] + for prev_input_ids_slice in prev_input_ids: + banned_tokens_slice = [] + for banned_token_seq in self.bad_words_ids: + if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: + # if tokens do not match continue + continue + + banned_tokens_slice.append(banned_token_seq[-1]) + + banned_tokens.append(banned_tokens_slice) + + return banned_tokens + + def _set_scores_to_inf_for_banned_tokens(self, scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: + """ + Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a + list of list of banned tokens to ban in the format [[batch index, vocabulary position],... + + Args: + scores: logits distribution of shape (batch size, vocabulary size) + banned_tokens: list of list of tokens to ban of length (batch_size) + """ + banned_mask_list = [] + for idx, batch_banned_tokens in enumerate(banned_tokens): + for token in batch_banned_tokens: + banned_mask_list.append([idx, token]) + if not banned_mask_list: + return scores + + banned_mask = torch.LongTensor(banned_mask_list) + indices = torch.ones(len(banned_mask)) + # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. 
A conversion to dense tensor generates: + # [ 0 1 1 ] + # [ 0 0 0 ] + # [ 1 0 0 ] + + banned_mask = ( + torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() + ) + scores = scores.masked_fill(banned_mask, -float("inf")) + return scores diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index e85166a8158..206658da98a 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -1,6 +1,6 @@ # coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,13 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch -from torch import Tensor from torch.nn import functional as F from .file_utils import ModelOutput +from .generation_beam_search import BeamScorer, BeamSearchScorer +from .generation_logits_process import ( + LogitsProcessorList, + MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) from .utils import logging @@ -33,85 +43,245 @@ class GenerationMixin: :class:`~transformers.PreTrainedModel`. """ - def prepare_inputs_for_generation(self, input_ids, **kwargs): + def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: """ Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to prepare inputs in the generate method. """ return {"input_ids": input_ids} - def adjust_logits_during_generation(self, logits, **kwargs): + def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: """ Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in the generate method. """ return logits - def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): - """ - Enforce the repetition penalty (from the `CTRL paper `__). 
- """ - for i in range(batch_size * num_beams): - for previous_token in set(prev_output_tokens[i].tolist()): - # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability - if lprobs[i, previous_token] < 0: - lprobs[i, previous_token] *= repetition_penalty - else: - lprobs[i, previous_token] /= repetition_penalty + def _prepare_input_ids_for_generation(self, bos_token_id: int) -> torch.LongTensor: + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id - def postprocess_next_token_scores( + def _prepare_attention_mask_for_generation( + self, input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int + ) -> torch.LongTensor: + is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: + return input_ids.ne(pad_token_id).long() + return input_ids.new_ones(input_ids.shape) + + def _prepare_encoder_decoder_kwargs_for_generation( + self, input_ids: torch.LongTensor, model_kwargs + ) -> Dict[str, Any]: + # retrieve encoder hidden states + encoder = self.get_encoder() + encoder_kwargs = { + argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_") + } + model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs) + return model_kwargs + + def _prepare_decoder_input_ids_for_generation( + self, input_ids: torch.LongTensor, decoder_start_token_id: int = None, bos_token_id: int = None, **model_kwargs + ) -> torch.LongTensor: + + if "decoder_input_ids" in model_kwargs: + return model_kwargs["decoder_input_ids"] + + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + decoder_input_ids = ( + torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device) + * decoder_start_token_id + ) + return decoder_input_ids + + def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: + if pad_token_id is None and eos_token_id is not None: + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + pad_token_id = eos_token_id + return pad_token_id + + def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "decoder_start_token_id") + and self.config.decoder.decoder_start_token_id is not None + ): + return self.config.decoder.decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "bos_token_id") + and self.config.decoder.bos_token_id is not None + ): + return self.config.decoder.bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." 
+ ) + + @staticmethod + def _expand_inputs_for_generation( + input_ids: torch.LongTensor, + expand_size: int = 1, + is_encoder_decoder: bool = False, + attention_mask: torch.LongTensor = None, + encoder_outputs: ModelOutput = None, + **model_kwargs + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + expanded_return_idx = ( + torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) + + if is_encoder_decoder: + assert encoder_outputs is not None + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + 0, expanded_return_idx + ) + model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs + + @staticmethod + def _init_sequence_length_for_generation( + input_ids: torch.LongTensor, max_length: int + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + sequence_lengths = input_ids.new(input_ids.shape[0]).fill_(max_length) + + cur_len = input_ids.shape[-1] + return sequence_lengths, unfinished_sequences, cur_len + + @staticmethod + def _update_seq_length_for_generation( + sequence_lengths: torch.LongTensor, + unfinished_sequences: torch.LongTensor, + cur_len: int, + is_eos_in_next_token: torch.BoolTensor, + ) -> Tuple[torch.LongTensor, torch.LongTensor]: + # check if sentence is not finished yet + is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool() + + # update sentence length + sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len) + unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long()) + return sequence_lengths, unfinished_sequences + + @staticmethod + def _update_model_kwargs_for_generation( + outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False + ) -> Dict[str, Any]: + # update past + if "past_key_values" in outputs: + model_kwargs["past"] = outputs.past_key_values + elif "mems" in outputs: + model_kwargs["past"] = outputs.mems + elif "past_buckets_states" in outputs: + model_kwargs["past"] = outputs.past_buckets_states + else: + model_kwargs["past"] = None + + # update attention mask + if not is_encoder_decoder: + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return model_kwargs + + @staticmethod + def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]: + """ + This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every + generation step. + + For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in + subclasses of :class:`~transformers.PreTrainedModel`. 
+ """ + return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) + + def _get_logits_warper( + self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None + ) -> LogitsProcessorList: + """ + This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant + :obj:`~transformers.LogitsWarper` instances used for multinomial sampling. + """ + + # init warp parameters + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + temperature = temperature if temperature is not None else self.config.temperature + # instantiate warpers list + warpers = LogitsProcessorList() + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if top_k is not None and top_k != 0: + warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if top_p is not None and top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) + if temperature is not None and temperature != 1.0: + warpers.append(TemperatureLogitsWarper(temperature)) + return warpers + + def _get_logits_processor( self, - scores, - input_ids, - no_repeat_ngram_size, - bad_words_ids, - cur_len, - min_length, - max_length, - eos_token_id, - repetition_penalty, - batch_size, - num_beams, - ): - # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) - if repetition_penalty != 1.0: - self.enforce_repetition_penalty_( - scores, - batch_size, - num_beams, - input_ids, - repetition_penalty, - ) + repetition_penalty: float, + no_repeat_ngram_size: int, + bad_words_ids: List[List[int]], + min_length: int, + eos_token_id: int, + ) -> LogitsProcessorList: + """ + This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant + :obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head. 
+ """ - # set eos token prob to zero if min_length is not reached - if eos_token_id is not None and cur_len < min_length: - scores[:, eos_token_id] = -float("inf") - - if no_repeat_ngram_size > 0: - # calculate a list of banned tokens to prevent repetitively generating the same ngrams - num_batch_hypotheses = batch_size * num_beams - # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 - banned_batch_tokens = calc_banned_ngram_tokens( - input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len - ) - for i, banned_tokens in enumerate(banned_batch_tokens): - scores[i, banned_tokens] = -float("inf") + # init warp parameters + repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + min_length = min_length if min_length is not None else self.config.min_length + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + # instantiate processors list + processors = LogitsProcessorList() + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if repetition_penalty is not None and repetition_penalty != 1.0: + processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) + if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: + processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) if bad_words_ids is not None: - # Exclude EOS token (already processed) - bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids)) - # calculate a list of banned tokens according to bad words - banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids) - # Modify the scores in place by setting the banned tokens logits to `-inf` - set_scores_to_inf_for_banned_tokens(scores, banned_tokens) - - return scores + processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) + if min_length is not None and eos_token_id is not None and min_length > -1: + processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) + return processors @torch.no_grad() def generate( self, input_ids: Optional[torch.LongTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, @@ -128,17 +298,13 @@ class GenerationMixin: length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, - attention_mask: Optional[torch.LongTensor] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, **model_kwargs ) -> torch.LongTensor: r""" Generates sequences for models with a language modeling head. The method currently supports greedy decoding, - beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. - - Adapted in part from `Facebook's XLM beam search code - `__. + multinomial sampling, beam-search decoding, and beam-search multinomial sampling. 
Apart from :obj:`input_ids` and :obj:`attention_mask`, all the arguments below will default to the value of the attribute of the same name inside the :class:`~transformers.PretrainedConfig` of the model. The default values @@ -152,9 +318,6 @@ class GenerationMixin: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. - decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - initial input_ids for the decoder of encoder-decoder type models. If :obj:`None` then only - decoder_start_token_id is passed as the first token to the decoder. max_length (:obj:`int`, `optional`, defaults to 20): The maximum length of the sequence to be generated. min_length (:obj:`int`, `optional`, defaults to 10): @@ -170,7 +333,7 @@ class GenerationMixin: top_k (:obj:`int`, `optional`, defaults to 50): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (:obj:`float`, `optional`, defaults to 1.0): - If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or + If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation. repetition_penalty (:obj:`float`, `optional`, defaults to 1.0): The parameter for repetition penalty. 1.0 means no penalty. See `this paper @@ -182,792 +345,854 @@ class GenerationMixin: eos_token_id (:obj:`int`, `optional`): The id of the `end-of-sequence` token. length_penalty (:obj:`float`, `optional`, defaults to 1.0): - Exponential penalty to the length. 1.0 means no penalty. - - Set to values < 1.0 in order to encourage the model to generate shorter sequences, to a value > 1.0 in - order to encourage the model to produce longer sequences. + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0): If set to int > 0, all ngrams of that size can only occur once. - bad_words_ids(:obj:`List[int]`, `optional`): + bad_words_ids(:obj:`List[List[int]]`, `optional`): List of token ids that are not allowed to be generated. In order to get the tokens of the words that - should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`. + should not appear in the generated text, use :obj:`tokenizer(bad_word, + add_prefix_space=True).input_ids`. num_return_sequences(:obj:`int`, `optional`, defaults to 1): The number of independently computed returned sequences for each element in the batch. attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on padding token indices. Mask values are in ``[0, 1]``, 1 for - tokens that are not masked, and 0 for masked tokens. - - If not provided, will default to a tensor the same shape as :obj:`input_ids` that masks the pad token. - - `What are attention masks? <../glossary.html#attention-mask>`__ + tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same + shape as :obj:`input_ids` that masks the pad token. `What are attention masks? 
+ <../glossary.html#attention-mask>`__ decoder_start_token_id (:obj:`int`, `optional`): If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token. use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. model_kwargs: - Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the + model is an Encoder-Decoder model, encoder specific kwargs should not be prefixed and decoder specific + kwargs should be prefixed with `decoder_`. Return: - :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all batches finished early due to the :obj:`eos_token_id`. Examples:: - tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer - model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. - outputs = model.generate(max_length=40) # do greedy decoding - print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM - tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer - model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache. - input_context = 'The dog' - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context - outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog' - for i in range(3): # 3 output sequences were generated - print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> # do greedy decoding without providing a prompt + >>> outputs = model.generate(max_length=40) + >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) - tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer - model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache. - input_context = 'The dog' - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context - outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True) # generate 3 candidates using sampling - for i in range(3): # 3 output sequences were generated - print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True))) + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + >>> document = ( + ... "at least two people were killed in a suspected bomb attack on a passenger bus " + ... "in the strife-torn southern philippines on monday , the military said." + ... 
) + >>> # encode input contex + >>> input_ids = tokenizer(document, return_tensors="pt").input_ids + >>> # generate 3 independent sequences using beam search decoding (5 beams) + >>> # with T5 encoder-decoder model conditioned on short news article. + >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3) + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) - tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer - model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. - input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context - outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences - print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> input_context = "The dog" + >>> # encode input context + >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids + >>> # generate 3 candidates using sampling + >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True) + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) - tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer - model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. - input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl - bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] - input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context - outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated + >>> tokenizer = AutoTokenizer.from_pretrained("ctrl") + >>> model = AutoModelForCausalLM.from_pretrained("ctrl") + >>> # "Legal" is one of the control codes for ctrl + >>> input_context = "Legal My neighbor is" + >>> # encode input context + >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids + >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2) + >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> input_context = "My cute dog" + >>> # get tokens of words that should not be generated + >>> bad_words_ids = [tokenizer(bad_word, add_prefix_space=True).input_ids for bad_word in ["idiot", "stupid", "shut up"]] + >>> # encode input context + >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids + >>> # generate sequences without allowing bad_words to be generated + >>> outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids) + >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True)) """ - # We cannot generate if the model does not have a LM head - if self.get_output_embeddings() is None: - raise AttributeError( - "You tried to generate sequences with a model that does not have a LM Head." 
- "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )" - ) - - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length - do_sample = do_sample if do_sample is not None else self.config.do_sample - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - use_cache = use_cache if use_cache is not None else self.config.use_cache + # set init values num_beams = num_beams if num_beams is not None else self.config.num_beams - temperature = temperature if temperature is not None else self.config.temperature - top_k = top_k if top_k is not None else self.config.top_k - top_p = top_p if top_p is not None else self.config.top_p - repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - no_repeat_ngram_size = ( - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size - ) - bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + max_length = max_length if max_length is not None else self.config.max_length + do_sample = do_sample if do_sample is not None else self.config.do_sample num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) - decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id - ) - if input_ids is not None: - batch_size = input_ids.shape[0] # overridden by the input batch_size - else: - batch_size = 1 - - assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer." - assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." - assert isinstance(do_sample, bool), "`do_sample` should be a boolean." - assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." - assert isinstance(use_cache, bool), "`use_cache` should be a boolean." - assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." - assert temperature > 0, "`temperature` should be strictly positive." - assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." - assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1." - assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1." - assert input_ids is not None or ( - isinstance(bos_token_id, int) and bos_token_id >= 0 - ), "If input_ids is not defined, `bos_token_id` should be a positive integer." - assert pad_token_id is None or ( - isinstance(pad_token_id, int) and (pad_token_id >= 0) - ), "`pad_token_id` should be a positive integer." - assert (eos_token_id is None) or ( - isinstance(eos_token_id, int) and (eos_token_id >= 0) - ), "`eos_token_id` should be a positive integer." - assert length_penalty > 0, "`length_penalty` should be strictly positive." 
- assert ( - isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0 - ), "`no_repeat_ngram_size` should be a positive integer." - assert ( - isinstance(num_return_sequences, int) and num_return_sequences > 0 - ), "`num_return_sequences` should be a strictly positive integer." - assert ( - bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) - ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is None: - assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( - "you should either supply a context to complete as `input_ids` input " - "or a `bos_token_id` (integer >= 0) as a first token to start the generation." + # init `input_ids` with bos_token_id + input_ids = self._prepare_input_ids_for_generation(bos_token_id) + + if model_kwargs.get("attention_mask", None) is None: + # init `attention_mask` depending on `pad_token_id` + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + input_ids, pad_token_id, eos_token_id ) - input_ids = torch.full( - (batch_size, 1), - bos_token_id, - dtype=torch.long, - device=next(self.parameters()).device, - ) - else: - assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)." - # not allow to duplicate outputs when greedy decoding - if do_sample is False: - if num_beams == 1: - # no_beam_search greedy generation conditions - assert ( - num_return_sequences == 1 - ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1" - - else: - # beam_search greedy generation conditions - assert ( - num_beams >= num_return_sequences - ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" - - # create attention mask if necessary - # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 - if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): - attention_mask = input_ids.ne(pad_token_id).long() - elif attention_mask is None: - attention_mask = input_ids.new_ones(input_ids.shape) - - # set pad_token_id to eos_token_id if not set. 
Important that this is done after - # attention_mask is created + # special case if pad_token_id is not defined if pad_token_id is None and eos_token_id is not None: - logger.warning( - "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id) - ) + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") pad_token_id = eos_token_id - # vocab size - if hasattr(self.config, "vocab_size"): - vocab_size = self.config.vocab_size - elif ( - self.config.is_encoder_decoder - and hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "vocab_size") - ): - vocab_size = self.config.decoder.vocab_size - else: - raise ValueError("either self.config.vocab_size or self.config.decoder.vocab_size needs to be defined") - - # set effective batch size and effective batch multiplier according to do_sample - if do_sample: - effective_batch_size = batch_size * num_return_sequences - effective_batch_mult = num_return_sequences - else: - effective_batch_size = batch_size - effective_batch_mult = 1 - if self.config.is_encoder_decoder: - if decoder_start_token_id is None: - # see if BOS token can be used for decoder_start_token_id - if bos_token_id is not None: - decoder_start_token_id = bos_token_id - elif ( - hasattr(self.config, "decoder") - and hasattr(self.config.decoder, "bos_token_id") - and self.config.decoder.bos_token_id is not None - ): - decoder_start_token_id = self.config.decoder.bos_token_id - else: - raise ValueError( - "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" - ) + # add encoder_outputs to model_kwargs + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) - assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) - assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) - - # get encoder and store encoder outputs - encoder = self.get_encoder() - encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True) - - # Expand input ids if num_beams > 1 or num_return_sequences > 1 - if num_return_sequences > 1 or num_beams > 1: - input_ids_len = input_ids.shape[-1] - input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len) - attention_mask = attention_mask.unsqueeze(1).expand( - batch_size, effective_batch_mult * num_beams, input_ids_len + # set input_ids as decoder_input_ids + input_ids = self._prepare_decoder_input_ids_for_generation( + input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs ) - input_ids = input_ids.contiguous().view( - effective_batch_size * num_beams, input_ids_len - ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) - attention_mask = attention_mask.contiguous().view( - effective_batch_size * num_beams, input_ids_len - ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + if "encoder_outputs" not in model_kwargs or not isinstance(model_kwargs["encoder_outputs"], ModelOutput): + raise ValueError("Make sure that `model_kwargs` include `encoder_outputs` of type `ModelOutput`.") - if self.config.is_encoder_decoder: - device = next(self.parameters()).device - if decoder_input_ids is not None: - # give initial decoder input ids - input_ids = decoder_input_ids.repeat(effective_batch_size * num_beams, 1).to(device) - else: - # create empty decoder input_ids - input_ids = torch.full( - 
(effective_batch_size * num_beams, 1), - decoder_start_token_id, - dtype=torch.long, - device=device, + # determine generation mode + is_greedy_gen_mode = (num_beams == 1) and do_sample is False + is_sample_gen_mode = (num_beams == 1) and do_sample is True + is_beam_gen_mode = (num_beams > 1) and do_sample is False + is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True + + # set model_kwargs + model_kwargs["use_cache"] = use_cache + + # get distribution pre_processing samplers + logits_processor = self._get_logits_processor( + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + min_length=min_length, + eos_token_id=eos_token_id, + ) + + if is_greedy_gen_mode: + if num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) - cur_len = input_ids.shape[-1] - assert ( - batch_size == encoder_outputs.last_hidden_state.shape[0] - ), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} " - - # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) - expanded_batch_idxs = ( - torch.arange(batch_size) - .view(-1, 1) - .repeat(1, num_beams * effective_batch_mult) - .view(-1) - .to(input_ids.device) + # greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **model_kwargs, ) - # expand encoder_outputs - encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( - 0, expanded_batch_idxs + elif is_sample_gen_mode: + # get probability distribution warper + logits_warper = self._get_logits_warper( + top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams ) - # save encoder_outputs in `model_kwargs` - model_kwargs["encoder_outputs"] = encoder_outputs + # expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, + expand_size=num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) - else: - cur_len = input_ids.shape[-1] + # sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **model_kwargs, + ) + + elif is_beam_gen_mode: + batch_size = input_ids.shape[0] + + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + + if num_return_sequences > num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + max_length=max_length, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + ) + # interleave with `num_beams` + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs + ) + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + max_length=max_length, + pad_token_id=pad_token_id, + 
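For readers skimming the diff, the heart of the refactored `generate()` shown in this hunk is a small decision table: `num_beams` and `do_sample` alone pick one of the four new sub-methods. A tiny stand-alone sketch of that routing (plain Python, no transformers imports; `pick_generation_mode` is a hypothetical helper, not part of the library)::

    def pick_generation_mode(num_beams: int, do_sample: bool) -> str:
        # mirrors is_greedy_gen_mode / is_sample_gen_mode / is_beam_gen_mode /
        # is_beam_sample_gen_mode in the hunk above
        if num_beams == 1 and not do_sample:
            return "greedy_search"
        if num_beams == 1 and do_sample:
            return "sample"
        if num_beams > 1 and not do_sample:
            return "beam_search"
        return "beam_sample"  # num_beams > 1 and do_sample

    assert pick_generation_mode(1, False) == "greedy_search"
    assert pick_generation_mode(1, True) == "sample"
    assert pick_generation_mode(5, False) == "beam_search"
    assert pick_generation_mode(5, True) == "beam_sample"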
eos_token_id=eos_token_id, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + logits_warper = self._get_logits_warper( + top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams + ) + + batch_size = input_ids.shape[0] * num_return_sequences + + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + max_length=max_length, + num_beams=num_beams, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + ) + + # interleave with `num_beams * num_return_sequences` + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids, + expand_size=num_beams * num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **model_kwargs, + ) + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + **model_kwargs + ): + r""" + Generates sequences for models with a language modeling head using greedy decoding. + + Parameters: + + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty + :obj:`torch.LongTensor` of shape :obj:`(1,)`. + logits_processor (:obj:`LogitsProcessorList`, `optional`): + An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from + :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling + head applied at each generation step. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the + model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. + + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. + + Examples:: + + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList([ + ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), + ... 
]) + + >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor) + + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + """ + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + # init sequence length tensors + sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation( + input_ids, max_length + ) + + while cur_len < max_length: + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + scores = logits_processor(input_ids, next_token_logits) + + # argmax + next_tokens = torch.argmax(scores, dim=-1) + + # add code that transfomers next_tokens to tokens_to_add + if eos_token_id is not None: + assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined." + next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences) + + # add token and increase length by one + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + # update sequence length + if eos_token_id is not None: + sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation( + sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id + ) + + # update model kwargs + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sequences.max() == 0: + break + + # increase cur_len + cur_len = cur_len + 1 + + return input_ids + + def sample( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + **model_kwargs + ): + r""" + Generates sequences for models with a language modeling head using multinomial sampling. + + Parameters: + + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty + :obj:`torch.LongTensor` of shape :obj:`(1,)`. + logits_processor (:obj:`LogitsProcessorList`, `optional`): + An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from + :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling + head applied at each generation step. + logits_warper (:obj:`LogitsProcessorList`, `optional`): + An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from + :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language + modeling head applied before multinomial sampling at each generation step. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. 
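The `greedy_search` loop added above reduces to: run the model, take the argmax of the processed next-token logits, keep appending the pad token to sequences that already emitted EOS, and stop once every sequence is finished or `max_length` is hit. A toy, model-free sketch of that loop (everything below is illustrative; `fake_model` stands in for the real forward pass and the token ids are arbitrary)::

    import torch

    torch.manual_seed(0)
    vocab_size, eos_token_id, pad_token_id, max_length = 10, 9, 9, 8

    def fake_model(input_ids):
        # stand-in for the language model: random next-token logits per sequence
        return torch.randn(input_ids.shape[0], vocab_size)

    input_ids = torch.tensor([[3], [5]])                      # (batch_size, cur_len)
    unfinished = torch.ones(input_ids.shape[0], dtype=torch.long)

    while input_ids.shape[-1] < max_length:
        next_token_logits = fake_model(input_ids)
        next_tokens = torch.argmax(next_token_logits, dim=-1)
        # finished sequences only ever get the pad token appended
        next_tokens = next_tokens * unfinished + pad_token_id * (1 - unfinished)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        unfinished = unfinished * (next_tokens != eos_token_id).long()
        if unfinished.max() == 0:                             # every sequence hit EOS
            break

    print(input_ids)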
+ pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + model_kwargs: + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If + model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. + + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. + + Examples:: + + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... ) + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList([ + ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), + ... ]) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList([ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ]) + + >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper) + + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + """ + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + # init sequence length tensors + sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation( + input_ids, max_length + ) + + # auto-regressive generation + while cur_len < max_length: + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + scores = logits_processor(input_ids, next_token_logits) + scores = logits_warper(input_ids, scores) + + # sample + probs = F.softmax(scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # add code that transfomers next_tokens to tokens_to_add + if eos_token_id is not None: + assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined." 
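`sample()` above differs from greedy search only in how the next token is picked: the processed scores are additionally run through the warpers (top-k, top-p, temperature) and a token is drawn with `torch.multinomial`. A minimal sketch of just that step, assuming an installed transformers version that already contains the warper classes added by this PR (the logits tensor is dummy data, no model is involved)::

    import torch
    import torch.nn.functional as F
    from transformers import LogitsProcessorList, TopKLogitsWarper, TemperatureLogitsWarper

    torch.manual_seed(0)
    logits = torch.randn(2, 50)                    # dummy (batch_size, vocab_size) scores

    logits_warper = LogitsProcessorList([
        TopKLogitsWarper(10),                      # keep only the 10 highest-scoring tokens
        TemperatureLogitsWarper(0.7),              # sharpen the remaining distribution
    ])

    # the warpers used here ignore input_ids, so None can be passed in its place
    warped = logits_warper(None, logits)
    probs = F.softmax(warped, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    print(next_tokens)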
+ next_tokens = next_tokens * unfinished_sequences + (pad_token_id) * (1 - unfinished_sequences) + + # add token and increase length by one + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + cur_len = cur_len + 1 + + # update sequence length + if eos_token_id is not None: + sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation( + sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id + ) + + # stop when there is a in each sentence, or if we exceed the maximul length + if unfinished_sequences.max() == 0: + break + + # update model kwargs + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + return input_ids + + def beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + **model_kwargs + ): + r""" + Generates sequences for models with a language modeling head using beam search decoding. + + Parameters: + + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty + :obj:`torch.LongTensor` of shape :obj:`(1,)`. + beam_scorer (:obj:`BeamScorer`): + An derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are + constructed, stored and sorted during generation. For more information, the documentation of + :class:`~transformers.BeamScorer` should be read. + logits_processor (:obj:`LogitsProcessorList`, `optional`): + An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from + :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling + head applied at each generation step. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + model_kwargs: + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If + model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. + + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. + + Examples:: + + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... BeamSearchScorer, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + + >>> encoder_input_str = "translate English to German: How old are you?" 
+ >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids + + + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id + + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) + ... } + + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... max_length=model.config.max_length, + ... num_beams=num_beams, + ... device=model.device, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList([ + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), + ... ]) + + >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) + + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + """ + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape assert ( - cur_len < max_length - ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" + num_beams * batch_size == batch_beam_size + ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 
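One detail worth spelling out from the `beam_search` example above: when the sub-method is called directly, the caller owns the beam layout, i.e. every tensor has a leading dimension of `batch_size * num_beams` and the beams of one batch element stay adjacent. A small pure-torch sketch of that interleaving (no model involved; the sizes and the fake encoder states are made up)::

    import torch

    batch_size, num_beams, src_len, hidden = 2, 3, 4, 8

    # one fake "encoder output" row per batch element, filled with its batch index
    encoder_hidden = (
        torch.arange(batch_size).float().view(batch_size, 1, 1).expand(batch_size, src_len, hidden)
    )

    # repeat_interleave keeps the num_beams copies of each element adjacent:
    # rows 0..2 belong to element 0, rows 3..5 to element 1, which is the layout
    # BeamSearchScorer.process expects
    expanded = encoder_hidden.repeat_interleave(num_beams, dim=0)
    print(expanded.shape)        # torch.Size([6, 4, 8])
    print(expanded[:, 0, 0])     # tensor([0., 0., 0., 1., 1., 1.])

    # the decoder side starts as one start token per beam: (batch_size * num_beams, 1)
    decoder_start_token_id = 0
    decoder_input_ids = torch.full((batch_size * num_beams, 1), decoder_start_token_id, dtype=torch.long)
    print(decoder_input_ids.shape)   # torch.Size([6, 1])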
- if num_beams > 1: - output = self._generate_beam_search( - input_ids, - cur_len=cur_len, - max_length=max_length, - min_length=min_length, - do_sample=do_sample, - early_stopping=early_stopping, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - batch_size=effective_batch_size, - num_return_sequences=num_return_sequences, - length_penalty=length_penalty, - num_beams=num_beams, - vocab_size=vocab_size, - attention_mask=attention_mask, - use_cache=use_cache, - model_kwargs=model_kwargs, - ) - else: - output = self._generate_no_beam_search( - input_ids, - cur_len=cur_len, - max_length=max_length, - min_length=min_length, - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - batch_size=effective_batch_size, - attention_mask=attention_mask, - use_cache=use_cache, - model_kwargs=model_kwargs, - ) + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) - return output - - def _generate_no_beam_search( - self, - input_ids, - cur_len, - max_length, - min_length, - do_sample, - temperature, - top_k, - top_p, - repetition_penalty, - no_repeat_ngram_size, - bad_words_ids, - pad_token_id, - eos_token_id, - batch_size, - attention_mask, - use_cache, - model_kwargs, - ): - """ - Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated - independently. 
- """ - # length of generated sentences / unfinished sentences - unfinished_sents = input_ids.new(batch_size).fill_(1) - sent_lengths = input_ids.new(batch_size).fill_(max_length) - - past = None while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs - ) + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self(**model_inputs, return_dict=True) next_token_logits = outputs.logits[:, -1, :] - scores = self.postprocess_next_token_scores( - scores=next_token_logits, - input_ids=input_ids, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, - cur_len=cur_len, - min_length=min_length, - max_length=max_length, - eos_token_id=eos_token_id, - repetition_penalty=repetition_penalty, - batch_size=batch_size, - num_beams=1, + # adjust tokens for Bart, *e.g.* + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length ) - # if model has past, then set the past variable to speed up decoding - if "past_key_values" in outputs: - past = outputs.past_key_values - elif "mems" in outputs: - past = outputs.mems + next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) - if do_sample: - # Temperature (higher temperature => more likely to sample low probability tokens) - if temperature != 1.0: - scores = scores / temperature - # Top-p/top-k filtering - next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p) - # Sample - probs = F.softmax(next_token_logscores, dim=-1) - next_token = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - # Greedy decoding - next_token = torch.argmax(next_token_logits, dim=-1) + next_token_scores = logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - # update generations and finished sentences - if eos_token_id is not None: - # pad finished sentences if eos_token_id exist - tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents) - else: - tokens_to_add = next_token + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) - # add token and increase length by one - input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1) + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) cur_len = cur_len + 1 - if eos_token_id is not None: - eos_in_sents = tokens_to_add == eos_token_id - # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length - is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool() - sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len) - # unfinished_sents is set to zero if eos in 
sentence - unfinished_sents.mul_((~eos_in_sents).long()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) - # stop when there is a in each sentence, or if we exceed the maximum length - if unfinished_sents.max() == 0: + if beam_scorer.is_done: break - # extend attention_mask for new generated input if only decoder - if self.config.is_encoder_decoder is False: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - return input_ids - - def _generate_beam_search( - self, - input_ids, - cur_len, - max_length, - min_length, - do_sample, - early_stopping, - temperature, - top_k, - top_p, - repetition_penalty, - no_repeat_ngram_size, - bad_words_ids, - pad_token_id, - eos_token_id, - batch_size, - num_return_sequences, - length_penalty, - num_beams, - vocab_size, - attention_mask, - use_cache, - model_kwargs, - ): - """Generate sequences for each example with beam search.""" - - # generated hypotheses - generated_hyps = [ - BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping) - for _ in range(batch_size) - ] - - # scores for each sentence in the beam - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - - # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times - if do_sample is False: - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) - - # cache compute states - past = None - - # done sentences - done = [False for _ in range(batch_size)] - - while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation( - input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs - ) - outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size) - next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size) - - # if model has past, then set the past variable to speed up decoding - if "past_key_values" in outputs: - past = outputs.past_key_values - elif "mems" in outputs: - past = outputs.mems - - if self.config.is_encoder_decoder and do_sample is False: - # TODO (PVP) still a bit hacky here - there might be a better solution - next_token_logits = self.adjust_logits_during_generation( - next_token_logits, cur_len=cur_len, max_length=max_length - ) - - scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) - - scores = self.postprocess_next_token_scores( - scores=scores, - input_ids=input_ids, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, - cur_len=cur_len, - min_length=min_length, - max_length=max_length, - eos_token_id=eos_token_id, - repetition_penalty=repetition_penalty, - batch_size=batch_size, - num_beams=num_beams, - ) - - assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format( - scores.shape, (batch_size * num_beams, vocab_size) - ) - - if do_sample: - _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) - # Temperature - if temperature != 1.0: - _scores = _scores / temperature - # Top-p/top-k filtering - _scores = top_k_top_p_filtering( - _scores, top_k=top_k, top_p=top_p, 
min_tokens_to_keep=2 - ) # (batch_size * num_beams, vocab_size) - # re-organize to group the beam together to sample from all beam_idxs - _scores = _scores.contiguous().view( - batch_size, num_beams * vocab_size - ) # (batch_size, num_beams * vocab_size) - - # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) - probs = F.softmax(_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2) - # Compute next scores - next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2) - # sort the sampled vector to make sure that the first num_beams samples are the best - next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1) - next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2) - - else: - next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) - - # re-organize to group the beam together (we are keeping top hypothesis across beams) - next_scores = next_scores.view( - batch_size, num_beams * vocab_size - ) # (batch_size, num_beams * vocab_size) - - next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True) - - assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams) - - # next batch beam content - next_batch_beam = [] - - # for each sentence - for batch_idx in range(batch_size): - - # if we are done with this sentence, add a pad token - if done[batch_idx]: - assert ( - len(generated_hyps[batch_idx]) >= num_beams - ), "Batch can only be done if at least {} beams have been generated".format(num_beams) - assert ( - eos_token_id is not None and pad_token_id is not None - ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined" - next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch - continue - - # next sentence beam content, this will get added to next_batch_beam - next_sent_beam = [] - - # next tokens for this sentence - for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( - zip(next_tokens[batch_idx], next_scores[batch_idx]) - ): - # get beam and token IDs - beam_id = beam_token_id // vocab_size - token_id = beam_token_id % vocab_size - - effective_beam_id = batch_idx * num_beams + beam_id - # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (token_id.item() == eos_token_id): - # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams - if is_beam_token_worse_than_top_num_beams: - continue - generated_hyps[batch_idx].add( - input_ids[effective_beam_id].clone(), - beam_token_score.item(), - ) - else: - # add next predicted token since it is not eos_token - next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) - - # once the beam for next step is full, don't add more tokens to it. 
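Both the removed `_generate_beam_search` code in this hunk and the new `beam_search` loop earlier in the diff rely on the same index trick: the scores of all beams of a batch element are flattened into one row of length `num_beams * vocab_size`, the top candidates are taken from that row, and integer division / modulo recover which beam a candidate came from and which token it adds. A tiny sketch of that arithmetic on dummy scores::

    import torch

    torch.manual_seed(0)
    batch_size, num_beams, vocab_size = 1, 2, 5

    # per-beam next-token scores, shape (batch_size * num_beams, vocab_size)
    scores = torch.randn(batch_size * num_beams, vocab_size)

    # flatten the beams of each batch element into a single row
    flat = scores.view(batch_size, num_beams * vocab_size)

    # 2 * num_beams candidates are kept so hypotheses that end in EOS do not starve the beam
    top_scores, top_ids = torch.topk(flat, 2 * num_beams, dim=1)

    beam_indices = top_ids // vocab_size   # which beam each candidate extends
    token_ids = top_ids % vocab_size       # which vocabulary token it adds

    print(beam_indices, token_ids)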
- if len(next_sent_beam) == num_beams: - break - - # Check if we are done so that we can save a pad step if all(done) - done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( - next_scores[batch_idx].max().item(), cur_len - ) - - # update next beam content - assert len(next_sent_beam) == num_beams, "Beam should always be full" - next_batch_beam.extend(next_sent_beam) - assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step" - - # stop when we are done with each sentence - if all(done): - break - - # sanity check / prepare next batch - assert len(next_batch_beam) == batch_size * num_beams - beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) - beam_tokens = input_ids.new([x[1] for x in next_batch_beam]) - beam_idx = input_ids.new([x[2] for x in next_batch_beam]) - - # re-order batch and update current length - input_ids = input_ids[beam_idx, :] - input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1) - cur_len = cur_len + 1 - - # re-order internal states - if past is not None: - past = self._reorder_cache(past, beam_idx) - - # extend attention_mask for new generated input if only decoder - if self.config.is_encoder_decoder is False: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # finalize all open beam hypotheses and add to generated hypotheses - for batch_idx in range(batch_size): - if done[batch_idx]: - continue - - # test that beam scores match previously calculated scores if not eos and batch_idx not done - if eos_token_id is not None and all( - (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx] - ): - assert torch.all( - next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx] - ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format( - next_scores[:, :num_beams][batch_idx], - beam_scores.view(batch_size, num_beams)[batch_idx], - ) - - # need to add best num_beams hypotheses to generated hyps - for beam_id in range(num_beams): - effective_beam_id = batch_idx * num_beams + beam_id - final_score = beam_scores[effective_beam_id].item() - final_tokens = input_ids[effective_beam_id] - generated_hyps[batch_idx].add(final_tokens, final_score) - - # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch - output_batch_size = batch_size if do_sample else batch_size * num_return_sequences - output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences - - # select the best hypotheses - sent_lengths = input_ids.new(output_batch_size) - best = [] - - # retrieve best hypotheses - for i, hypotheses in enumerate(generated_hyps): - sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0]) - for j in range(output_num_return_sequences_per_batch): - effective_batch_idx = output_num_return_sequences_per_batch * i + j - best_hyp = sorted_hyps.pop()[1] - sent_lengths[effective_batch_idx] = len(best_hyp) - best.append(best_hyp) - - # prepare for adding eos - sent_max_len = min(sent_lengths.max().item() + 1, max_length) - decoded = input_ids.new(output_batch_size, sent_max_len) - # shorter batches are padded if needed - if sent_lengths.min().item() != sent_lengths.max().item(): - assert pad_token_id is not None, "`pad_token_id` has to be defined" - decoded.fill_(pad_token_id) - - # fill with hypotheses and eos_token_id if the latter fits 
in - for i, hypo in enumerate(best): - decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < max_length: - decoded[i, sent_lengths[i]] = eos_token_id + decoded = beam_scorer.finalize( + input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id + ) return decoded - @staticmethod - def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]: - return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) + def beam_sample( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + **model_kwargs + ): + r""" + Generates sequences for models with a language modeling head using beam search with multinomial sampling. + Parameters: -def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None: - """Copied from fairseq for no_repeat_ngram in beam_search""" - if cur_len + 1 < no_repeat_ngram_size: - # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet - return [[] for _ in range(num_hypos)] - generated_ngrams = [{} for _ in range(num_hypos)] - for idx in range(num_hypos): - gen_tokens = prev_input_ids[idx].tolist() - generated_ngram = generated_ngrams[idx] - for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): - prev_ngram_tuple = tuple(ngram[:-1]) - generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty + :obj:`torch.LongTensor` of shape :obj:`(1,)`. + beam_scorer (:obj:`BeamScorer`): + A derived instance of :class:`~transformers.BeamScorer` that defines how beam hypotheses are + constructed, stored and sorted during generation. For more information, the documentation of + :class:`~transformers.BeamScorer` should be read. + logits_processor (:obj:`LogitsProcessorList`, `optional`): + An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from + :class:`~transformers.LogitsProcessor` used to modify the prediction scores of the language modeling + head applied at each generation step. + logits_warper (:obj:`LogitsProcessorList`, `optional`): + An instance of :class:`~transformers.LogitsProcessorList`. List of instances of class derived from + :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language + modeling head applied before multinomial sampling at each generation step. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + model_kwargs: + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If + model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`. 
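The `beam_sample` method introduced in this hunk (and implemented further below) is beam search with one change in the candidate step: instead of a deterministic `torch.topk`, the `2 * num_beams` candidates per batch element are drawn by multinomial sampling from the warped score distribution and then sorted. A model-free sketch of that single step on dummy scores (sizes are arbitrary)::

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch_size, num_beams, vocab_size = 1, 2, 5
    beam_scores = torch.zeros(batch_size * num_beams)     # beam_sample starts all beams at 0

    next_token_scores = F.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
    next_token_scores = next_token_scores + beam_scores[:, None]
    flat = next_token_scores.view(batch_size, num_beams * vocab_size)

    # beam *search* would call torch.topk here; beam *sample* draws the candidates instead ...
    probs = F.softmax(flat, dim=-1)
    next_ids = torch.multinomial(probs, num_samples=2 * num_beams)
    next_scores = torch.gather(flat, -1, next_ids)

    # ... and then sorts them so the best sampled candidates come first
    next_scores, order = torch.sort(next_scores, descending=True, dim=1)
    next_ids = torch.gather(next_ids, -1, order)
    print(next_ids, next_scores)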
- def _get_generated_ngrams(hypo_idx): - # Before decoding the next token, prevent decoding of ngrams that have already appeared - start_idx = cur_len + 1 - no_repeat_ngram_size - ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist()) - return generated_ngrams[hypo_idx].get(ngram_idx, []) + Return: + :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated + sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all + batches finished early due to the :obj:`eos_token_id`. - banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] - return banned_tokens + Examples:: + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForSeq2SeqLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... BeamSearchScorer, + ... ) + >>> import torch -def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]: - banned_tokens = [] + >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - def _tokens_match(prev_tokens, tokens): - if len(tokens) == 0: - # if bad word tokens is just one token always ban it - return True - if len(tokens) > len(prev_tokens): - # if bad word tokens are longer than prev tokens they can't be equal - return False + >>> encoder_input_str = "translate English to German: How old are you?" + >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - if prev_tokens[-len(tokens) :] == tokens: - # if tokens match - return True - else: - return False + >>> # lets run beam search using 3 beams + >>> num_beams = 3 + >>> # define decoder start token ids + >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) + >>> input_ids = input_ids * model.config.decoder_start_token_id - for prev_input_ids_slice in prev_input_ids: - banned_tokens_slice = [] + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True) + ... } - for banned_token_seq in bad_words_ids: - assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format( - bad_words_ids + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... max_length=model.config.max_length, + ... num_beams=num_beams, + ... device=model.device, + ... ) + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList([ + ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id) + ... ]) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList([ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ]) + + >>> outputs = model.beam_sample( + ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs + ... 
) + + >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + """ + + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while cur_len < max_length: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + outputs = self(**model_inputs, return_dict=True) + next_token_logits = outputs.logits[:, -1, :] + + # adjust token scores (a no-op by default) + next_token_logits = self.adjust_logits_during_generation( + next_token_logits, cur_len=cur_len, max_length=max_length ) - if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False: - # if tokens do not match continue - continue + next_token_scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size) - banned_tokens_slice.append(banned_token_seq[-1]) + next_token_scores = logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) + next_token_scores = logits_warper(input_ids, next_token_scores) - banned_tokens.append(banned_tokens_slice) + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - return banned_tokens + probs = F.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) -def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None: - """ - Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a list - of list of banned tokens to ban in the format [[batch index, vocabulary position],... + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size - Args: - scores: logits distribution of shape (batch size, vocabulary size) - banned_tokens: list of list of tokens to ban of length (batch_size) - """ - banned_mask_list = [] - for idx, batch_banned_tokens in enumerate(banned_tokens): - for token in batch_banned_tokens: - banned_mask_list.append([idx, token]) - if not banned_mask_list: - return - banned_mask = torch.LongTensor(banned_mask_list) - indices = torch.ones(len(banned_mask)) - # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. 
A conversion to dense tensor generates: - # [ 0 1 1 ] - # [ 0 0 0 ] - # [ 1 0 0 ] + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] - banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool() - scores.masked_fill_(banned_mask, -float("inf")) + input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + cur_len = cur_len + 1 + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) + + if beam_scorer.is_done: + break + + decoded = beam_scorer.finalize( + input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id + ) + + return decoded def top_k_top_p_filtering( - logits: Tensor, + logits: torch.FloatTensor, top_k: int = 0, top_p: float = 1.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, -) -> Tensor: +) -> torch.FloatTensor: """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering @@ -980,73 +1205,11 @@ def top_k_top_p_filtering( From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ if top_k > 0: - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check - # Remove all tokens with a probability less than the last token of the top-k - indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] - logits[indices_to_remove] = filter_value + logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( + None, logits + ) - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + if 0 <= top_p <= 1.0: + logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) - # Remove tokens with cumulative probability above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs > top_p - if min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - logits[indices_to_remove] = filter_value return logits - - -class BeamHypotheses(object): - def __init__(self, num_beams, max_length, length_penalty, early_stopping): - """ - Initialize n-best list of hypotheses. - """ - self.max_length = max_length - 1 # ignoring bos_token - self.length_penalty = length_penalty - self.early_stopping = early_stopping - self.num_beams = num_beams - self.beams = [] - self.worst_score = 1e9 - - def __len__(self): - """ - Number of hypotheses in the list. 
- """ - return len(self.beams) - - def add(self, hyp, sum_logprobs): - """ - Add a new hypothesis to the list. - """ - score = sum_logprobs / len(hyp) ** self.length_penalty - if len(self) < self.num_beams or score > self.worst_score: - self.beams.append((score, hyp)) - if len(self) > self.num_beams: - sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)]) - del self.beams[sorted_scores[0][1]] - self.worst_score = sorted_scores[1][0] - else: - self.worst_score = min(score, self.worst_score) - - def is_done(self, best_sum_logprobs, cur_len): - """ - If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst - one in the heap, then we are done with this sentence. - """ - - if len(self) < self.num_beams: - return False - elif self.early_stopping: - return True - else: - cur_score = best_sum_logprobs / cur_len ** self.length_penalty - ret = self.worst_score >= cur_score - return ret diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index e90bbc37c7a..475ba0f386b 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -1084,7 +1084,7 @@ class BartForConditionalGeneration(PretrainedBartModel): ) def prepare_inputs_for_generation( - self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs + self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/modeling_ctrl.py b/src/transformers/modeling_ctrl.py index c3ee2f41496..10d6949b2cd 100644 --- a/src/transformers/modeling_ctrl.py +++ b/src/transformers/modeling_ctrl.py @@ -514,12 +514,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): def get_output_embeddings(self): return self.lm_head - def prepare_inputs_for_generation(self, input_ids, past, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs): # only last token for inputs_ids if past is defined in kwargs if past: input_ids = input_ids[:, -1].unsqueeze(-1) - return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]} + return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache} @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) @add_code_sample_docstrings( diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 8efd43f5552..50381ed7c6a 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -431,7 +431,7 @@ class EncoderDecoderModel(PreTrainedModel): encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs): decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids) decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None input_dict = { diff --git a/src/transformers/modeling_fsmt.py b/src/transformers/modeling_fsmt.py index 800eec9b7e2..ba7f18cbf38 100644 --- a/src/transformers/modeling_fsmt.py +++ b/src/transformers/modeling_fsmt.py @@ -1107,7 +1107,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): ) def prepare_inputs_for_generation( - self, decoder_input_ids, past, 
attention_mask, use_cache, encoder_outputs, **kwargs + self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/modeling_prophetnet.py b/src/transformers/modeling_prophetnet.py index 417111c409b..0a0c6b1be2f 100644 --- a/src/transformers/modeling_prophetnet.py +++ b/src/transformers/modeling_prophetnet.py @@ -1800,7 +1800,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): return loss def prepare_inputs_for_generation( - self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs + self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation." diff --git a/src/transformers/modeling_rag.py b/src/transformers/modeling_rag.py index a203511dcfc..a2d8ddcf26a 100644 --- a/src/transformers/modeling_rag.py +++ b/src/transformers/modeling_rag.py @@ -22,6 +22,7 @@ import torch from .configuration_rag import RagConfig from .configuration_utils import PretrainedConfig from .file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings +from .generation_beam_search import BeamSearchScorer from .modeling_outputs import ModelOutput from .modeling_utils import PreTrainedModel from .retrieval_rag import RagRetriever @@ -825,7 +826,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): num_return_sequences=None, # defaults to 1 num_beams=None, # defaults to 1 n_docs=None, - **kwargs + **model_kwargs ): """ Implements RAG sequence "thorough" decoding. Read the :meth:`~transformers.PreTrainedModel.generate`` @@ -872,7 +873,6 @@ class RagSequenceForGeneration(RagPreTrainedModel): ) num_beams = num_beams if num_beams is not None else self.config.num_beams - # TODO(patrick) - clean up generate here if self.retriever is not None and context_input_ids is None: question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0] context_input_ids = self.retriever( @@ -887,10 +887,9 @@ class RagSequenceForGeneration(RagPreTrainedModel): context_input_ids = context_input_ids.to(input_ids) hypos = [] - kwargs["num_beams"] = num_beams - kwargs["num_return_sequences"] = num_beams - kwargs["attention_mask"] = None - kwargs["n_docs"] = n_docs + model_kwargs["num_beams"] = num_beams + model_kwargs["num_return_sequences"] = num_beams + model_kwargs["attention_mask"] = None for index in range(len(input_ids)): # first, generate beams from documents: @@ -898,7 +897,7 @@ class RagSequenceForGeneration(RagPreTrainedModel): output_sequences = self.generator.generate( generator_input_ids, - **kwargs, + **model_kwargs, ) # n_docs * n_beam, tgt_len if do_deduplication: # do_deduplication, max_output_len @@ -1018,7 +1017,15 @@ class RagTokenForGeneration(RagPreTrainedModel): return self.rag.generator.adjust_logits_during_generation(logits, cur_len=cur_len, max_length=max_length) def prepare_inputs_for_generation( - self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, doc_scores, n_docs=None, **kwargs + self, + decoder_input_ids, + past=None, + attention_mask=None, + use_cache=None, + encoder_outputs=None, + doc_scores=None, + n_docs=None, + **kwargs ): return { "input_ids": None, @@ -1222,11 +1229,12 @@ class RagTokenForGeneration(RagPreTrainedModel): eos_token_id=None, length_penalty=None, no_repeat_ngram_size=None, + 
repetition_penalty=None, bad_words_ids=None, num_return_sequences=None, decoder_start_token_id=None, n_docs=None, - **kwargs + **model_kwargs ): """ Implements RAG token decoding. @@ -1307,22 +1315,15 @@ class RagTokenForGeneration(RagPreTrainedModel): """ # set default parameters n_docs = n_docs if n_docs is not None else self.config.n_docs - max_length = max_length if max_length is not None else self.config.max_length - min_length = min_length if min_length is not None else self.config.min_length - early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping - use_cache = use_cache if use_cache is not None else self.config.use_cache num_beams = num_beams if num_beams is not None else self.config.num_beams - bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id - pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id - length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty - no_repeat_ngram_size = ( - no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size - ) - bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids + max_length = max_length if max_length is not None else self.config.max_length num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) + bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id + pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id + use_cache = use_cache if use_cache is not None else self.config.use_cache decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None @@ -1365,7 +1366,7 @@ class RagTokenForGeneration(RagPreTrainedModel): encoder = self.rag.generator.get_encoder() encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True) - decoder_input_ids = torch.full( + input_ids = torch.full( (batch_size * num_beams, 1), decoder_start_token_id, dtype=torch.long, @@ -1388,64 +1389,57 @@ class RagTokenForGeneration(RagPreTrainedModel): doc_scores = doc_scores.repeat_interleave(num_beams, dim=0) # define start_len & additional parameters - cur_len = 1 - vocab_size = self.config.generator.vocab_size - kwargs["doc_scores"] = doc_scores - kwargs["encoder_outputs"] = encoder_outputs - kwargs["n_docs"] = n_docs + model_kwargs["doc_scores"] = doc_scores + model_kwargs["encoder_outputs"] = encoder_outputs + model_kwargs["attention_mask"] = context_attention_mask + model_kwargs["n_docs"] = n_docs - # not needed. 
TODO(PVP): change after generate refactor - do_sample = False - temperature = self.config.temperature - top_k = self.config.top_k - top_p = self.config.top_p - repetition_penalty = self.config.repetition_penalty + pre_processor = self._get_logits_processor( + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, + min_length=min_length, + eos_token_id=eos_token_id, + ) - if num_beams > 1: - return self._generate_beam_search( - decoder_input_ids, - cur_len=cur_len, + if num_beams == 1: + if num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." + ) + return self.greedy_search( + input_ids, + pre_processor=pre_processor, max_length=max_length, - min_length=min_length, - do_sample=do_sample, - early_stopping=early_stopping, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, pad_token_id=pad_token_id, eos_token_id=eos_token_id, + **model_kwargs, + ) + elif num_beams > 1: + length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping + if num_return_sequences > num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + beam_scorer = BeamSearchScorer( batch_size=batch_size, - num_return_sequences=num_return_sequences, - length_penalty=length_penalty, + max_length=max_length, num_beams=num_beams, - vocab_size=vocab_size, - attention_mask=context_attention_mask, - use_cache=use_cache, - model_kwargs=kwargs, + device=self.device, + length_penalty=length_penalty, + do_early_stopping=early_stopping, + num_beam_hyps_to_keep=num_return_sequences, + ) + return self.beam_search( + input_ids, + beam_scorer, + pre_processor=pre_processor, + max_length=max_length, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **model_kwargs, ) else: - return self._generate_no_beam_search( - decoder_input_ids, - cur_len=cur_len, - max_length=max_length, - min_length=min_length, - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - repetition_penalty=repetition_penalty, - no_repeat_ngram_size=no_repeat_ngram_size, - bad_words_ids=bad_words_ids, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - batch_size=batch_size, - attention_mask=context_attention_mask, - use_cache=use_cache, - model_kwargs=kwargs, - ) + raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}") def get_input_embeddings(self): return self.rag.generator.get_input_embeddings() diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 3110c591f55..434369934f5 100755 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -638,7 +638,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): rotations_shape = (self.num_attention_heads, vectors.shape[-1], num_hashes, rotation_size // 2) # create a random self.attention_head_size x num_hashes x num_buckets/2 random_rotations = torch.randn(rotations_shape, device=vectors.device, dtype=vectors.dtype) - # Output dim: Batch_Size x Num_Attn_Heads x Num_Hashes x Seq_Len x Num_Buckets/2 rotated_vectors = torch.einsum("bmtd,mdhr->bmhtr", vectors, random_rotations) @@ -1471,7 +1470,9 @@ class 
ReformerLayer(nn.Module): # every forward pass we sample a different seed # for dropout and save for forward fn in backward pass # to have correct dropout - self._init_attention_seed() + if self.training: + self._init_attention_seed() + attn_outputs = self.attention( hidden_states=hidden_states, head_mask=head_mask, @@ -1494,7 +1495,8 @@ class ReformerLayer(nn.Module): # every forward pass we sample a different seed # for dropout and save seed for forward fn in backward # to have correct dropout - self._init_feed_forward_seed() + if self.training: + self._init_feed_forward_seed() # Y_2 = X_2 + g(Y_1) hidden_states = hidden_states + self.feed_forward(attn_output) @@ -2263,7 +2265,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): attentions=reformer_outputs.attentions, ) - def prepare_inputs_for_generation(self, input_ids, past, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, num_hashes=None, **kwargs): # only last token for inputs_ids if past is defined in kwargs if past is not None: input_ids = input_ids[:, -1:] @@ -2271,12 +2273,10 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): inputs_dict = { "input_ids": input_ids, "past_buckets_states": past, - "use_cache": kwargs["use_cache"], + "use_cache": use_cache, + "num_hashes": num_hashes, } - if "num_hashes" in kwargs: - inputs_dict["num_hashes"] = kwargs["num_hashes"] - return inputs_dict def _reorder_cache(self, past, beam_idx): diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 8910be33214..d31524b31a6 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -1232,7 +1232,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs): + def prepare_inputs_for_generation( + self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): # cut decoder_input_ids if past is used if past is not None: diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index a33a0c1f27c..0f188533e75 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -1091,7 +1091,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): else: return self.crit.out_layers[-1] - def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs): inputs = {} # if past is defined in model kwargs then use it for faster decoding diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 2264cdf3024..fd3113fa263 100755 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -1300,7 +1300,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): def get_output_embeddings(self): return self.lm_loss - def prepare_inputs_for_generation(self, input_ids, past, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs): # Add dummy token at the end (no attention on this one) effective_batch_size = input_ids.shape[0] @@ -1333,7 +1333,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): "input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping, - "use_cache": kwargs["use_cache"], + "use_cache": use_cache, } # if past is defined in model kwargs then use it for faster decoding diff --git 
a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index e40b80a4493..3bd9314e92d 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -88,8 +88,8 @@ def is_pipeline_test(test_case): """ Decorator marking a test as a pipeline test. - Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TEST environment variable to - a truthy value and selecting the is_pipeline_test pytest mark. + Pipeline tests are skipped by default and we can run only them by setting RUN_PIPELINE_TESTS environment variable + to a truthy value and selecting the is_pipeline_test pytest mark. """ if not _run_pipeline_tests: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 9e4a8ad6f7b..9109c1a25d4 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -104,6 +104,66 @@ class TextDatasetForNextSentencePrediction: requires_pytorch(self) +class BeamScorer: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class BeamSearchScorer: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class LogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class LogitsProcessorList: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class LogitsWarper: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class MinLengthLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class NoBadWordsLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class NoRepeatNGramLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class RepetitionPenaltyLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class TemperatureLogitsWarper: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class TopKLogitsWarper: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class TopPLogitsWarper: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + def top_k_top_p_filtering(*args, **kwargs): requires_pytorch(top_k_top_p_filtering) diff --git a/tests/test_generation_beam_search.py b/tests/test_generation_beam_search.py new file mode 100644 index 00000000000..10a932395f9 --- /dev/null +++ b/tests/test_generation_beam_search.py @@ -0,0 +1,239 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
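The tests below drive `BeamSearchScorer` by hand. The scorer only does per-sentence bookkeeping (`process` at every step, `finalize` at the end, with the candidates tracked in `BeamHypotheses`); the caller owns the decoding loop. A minimal, model-free sketch of that loop — random scores stand in for the model logits and the pad/eos ids are placeholders, so only the control flow is meaningful::

    import torch

    from transformers import BeamSearchScorer

    batch_size, num_beams, vocab_size, max_length = 2, 3, 50, 8
    pad_token_id, eos_token_id = 0, 1  # placeholder ids for this sketch

    scorer = BeamSearchScorer(
        batch_size=batch_size, max_length=max_length, num_beams=num_beams, device="cpu"
    )
    input_ids = torch.full((batch_size * num_beams, 1), pad_token_id, dtype=torch.long)
    beam_scores = torch.zeros(batch_size * num_beams)

    while input_ids.shape[-1] < max_length and not scorer.is_done:
        # a real loop would score the beams with the model; random scores stand in here
        scores = torch.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
        scores = (scores + beam_scores[:, None]).view(batch_size, num_beams * vocab_size)
        topk_scores, topk_tokens = scores.topk(2 * num_beams, dim=1)
        next_indices, next_tokens = topk_tokens // vocab_size, topk_tokens % vocab_size

        # stateless step: the scorer returns the scores, tokens and beam indices to keep
        out = scorer.process(
            input_ids, topk_scores, next_tokens, next_indices,
            pad_token_id=pad_token_id, eos_token_id=eos_token_id,
        )
        beam_scores = out["next_beam_scores"]
        input_ids = torch.cat(
            [input_ids[out["next_beam_indices"], :], out["next_beam_tokens"].unsqueeze(-1)], dim=-1
        )

    # pad unfinished hypotheses and return the best ones per batch
    decoded = scorer.finalize(
        input_ids, beam_scores, next_tokens, next_indices,
        pad_token_id=pad_token_id, eos_token_id=eos_token_id,
    )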
+ + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, torch_device + +from .test_modeling_common import floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer + + +class BeamSearchTester: + def __init__( + self, + parent, + batch_size=3, + sequence_length=10, + vocab_size=99, + pad_token_id=0, + max_length=20, + num_beams=4, + length_penalty=2.0, + do_early_stopping=True, + num_beam_hyps_to_keep=2, + ): + self.parent = parent + self.batch_size = batch_size + self.sequence_length = sequence_length + self.vocab_size = vocab_size + self.pad_token_id = pad_token_id + self.max_length = max_length + self.num_beams = num_beams + self.length_penalty = length_penalty + self.do_early_stopping = do_early_stopping + self.num_beam_hyps_to_keep = num_beam_hyps_to_keep + + # cannot be randomely generated + self.eos_token_id = vocab_size + 1 + + def prepare_beam_scorer(self, **kwargs): + return BeamSearchScorer( + batch_size=kwargs.get("batch_size", self.batch_size), + max_length=kwargs.get("max_length", self.max_length), + num_beams=kwargs.get("num_beams", self.num_beams), + device=torch_device, + length_penalty=kwargs.get("length_penalty", self.length_penalty), + do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping), + num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep), + ) + + def prepare_inputs(self): + input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size) + next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device) + next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device) + next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True) + return (input_ids, next_tokens, next_indices, next_scores) + + def check_beam_hypotheses(self, input_ids, *args): + # check that correct number of beam hypotheses is set in beam scorer + beam_scorer = self.prepare_beam_scorer(do_early_stopping=True) + beam_hyp = beam_scorer._beam_hyps[0] + + self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size) + + # check correct type + self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses)) + + # check that num_beams is correctly set + self.parent.assertEqual(beam_hyp.num_beams, self.num_beams) + + # check for early stopping deactivated + for beam_idx in range(self.num_beams): + beam_hyp.add(input_ids[beam_idx], -10.0) + + # if early stopping True -> score does not matter + self.parent.assertTrue(beam_hyp.is_done(-10.0, 5)) + + # re-init + beam_scorer = self.prepare_beam_scorer(do_early_stopping=False) + beam_hyp = beam_scorer._beam_hyps[0] + + # add `num_beams + 1` beams to change `worst_score` + for beam_idx in range(self.num_beams + 1): + beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx)) + + # -10.0 is removed => -9.0 is worst score + self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty)) + + # -5.0 is better than worst score => should not be finished + self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length)) + + # -20.0 is worse than worst score => should be finished + self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length)) + + def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores): + # check too 
many eos tokens + beam_scorer = self.prepare_beam_scorer() + + tokens = next_tokens.clone() + tokens[0, :] = self.eos_token_id + + with self.parent.assertRaises(ValueError): + beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id) + + # check all batches are done + beam_scorer = self.prepare_beam_scorer() + + tokens = next_tokens.clone() + tokens[:, : self.num_beams] = self.eos_token_id + beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id) + # beam scorer should be done + self.parent.assertTrue(beam_scorer.is_done) + + # check + beam_scorer = self.prepare_beam_scorer() + + tokens = next_tokens.clone() + tokens[:, 1] = self.eos_token_id + beam_outputs = beam_scorer.process( + input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id + ) + output_scores = beam_outputs["next_beam_scores"] + output_tokens = beam_outputs["next_beam_tokens"] + output_indices = beam_outputs["next_beam_indices"] + + def cut_expected_tensor(tensor): + return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten() + + # check all outptus + # cut out id of eos token and take best `num_beams` outputs + expected_output_tokens = cut_expected_tensor(tokens) + expected_output_scores = cut_expected_tensor(next_scores) + + # add num_beams * batch_idx + expected_output_indices = ( + cut_expected_tensor(next_indices) + + (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams + ) + + self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist()) + self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist()) + self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3)) + + # make sure ids of eos token are correctly saved in beam_hyps of beam scorer + for batch_idx in range(self.batch_size): + correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1] + self.parent.assertListEqual( + input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist() + ) + + def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores): + # max_length should be only one more than current input_ids to check that eos is correctly appended + max_length = self.sequence_length + 1 + beam_scorer = self.prepare_beam_scorer( + num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False + ) + + # update beams and append to input_ids + tokens = next_tokens.clone() + # first batch, first output has to finish with eos token id since scores are correctly sorted + tokens[0, 0] = self.eos_token_id + # make sure corresponding score is as good as possible to surely be picked first + next_scores[0, 0] = 0.0 + beam_outputs = beam_scorer.process( + input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id + ) + output_scores = beam_outputs["next_beam_scores"] + output_tokens = beam_outputs["next_beam_tokens"] + output_indices = beam_outputs["next_beam_indices"] + + input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1) + + # finalize + decoded = beam_scorer.finalize( + input_ids, + output_scores, + output_tokens, + output_indices, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_token_id, + ) + # since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length` + self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length]) + + # 
first batch has to finish with eos_token + self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id) + + # other batches cannot finish with eos token + self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id) + self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id) + + # now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned + beam_scorer.num_beam_hyps_to_keep = self.num_beams + decoded = beam_scorer.finalize( + input_ids, + output_scores, + output_tokens, + output_indices, + pad_token_id=self.pad_token_id, + eos_token_id=self.eos_token_id, + ) + self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length]) + + +@require_torch +class BeamSearchTest(unittest.TestCase): + def setUp(self): + self.beam_search_tester = BeamSearchTester(self) + + def test_beam_hypotheses(self): + inputs = self.beam_search_tester.prepare_inputs() + self.beam_search_tester.check_beam_hypotheses(*inputs) + + def test_beam_scorer_update(self): + inputs = self.beam_search_tester.prepare_inputs() + self.beam_search_tester.check_beam_scorer_update(*inputs) + + def test_beam_scorer_finalize(self): + inputs = self.beam_search_tester.prepare_inputs() + self.beam_search_tester.check_beam_scores_finalize(*inputs) diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py new file mode 100644 index 00000000000..bf3ee067b32 --- /dev/null +++ b/tests/test_generation_logits_process.py @@ -0,0 +1,283 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
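The processors and warpers exercised below all share one call convention: `processor(input_ids, scores)` returns modified scores and never touches `input_ids`, which is what lets `LogitsProcessorList` apply them in order and lets `generate()` assemble them from keyword arguments. A minimal sketch with toy tensors (the sizes and thresholds are arbitrary)::

    import torch

    from transformers import (
        LogitsProcessorList,
        MinLengthLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
    )

    batch_size, vocab_size, eos_token_id = 2, 20, 0

    # toy state: 4 generated tokens so far, random next-token scores
    input_ids = torch.randint(1, vocab_size, (batch_size, 4))
    scores = torch.randn(batch_size, vocab_size)

    processors = LogitsProcessorList(
        [
            MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id),
            TemperatureLogitsWarper(temperature=0.7),
            TopKLogitsWarper(top_k=5),
        ]
    )

    # applies each processor in order; `input_ids` are read but never modified
    new_scores = processors(input_ids, scores)
    assert new_scores.shape == scores.shape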
+ + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, torch_device + +from .test_modeling_common import ids_tensor + + +if is_torch_available(): + import torch + import torch.nn.functional as F + + from transformers.generation_logits_process import ( + LogitsProcessorList, + MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) + + +@require_torch +class LogitsProcessorTest(unittest.TestCase): + def _get_uniform_logits(self, batch_size: int, length: int): + scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length + return scores + + def test_min_lenght_dist_processor(self): + vocab_size = 20 + batch_size = 4 + eos_token_id = 0 + + min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + + # check that min length is applied at length 5 + input_ids = ids_tensor((batch_size, 5), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = min_dist_processor(input_ids, scores) + self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")]) + + # check that min length is not applied anymore at length 15 + input_ids = ids_tensor((batch_size, 15), vocab_size=20) + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_before_min_length = min_dist_processor(input_ids, scores) + self.assertFalse(torch.isinf(scores_before_min_length).any()) + + def test_temperature_dist_warper(self): + input_ids = None + length = 20 + + scores = self._get_uniform_logits(batch_size=2, length=length) + + # tweak scores to not be uniform anymore + scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch + scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch + + # compute softmax + probs = F.softmax(scores, dim=-1) + + temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5) + temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3) + + warped_prob_sharp = F.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1) + warped_prob_smooth = F.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1) + + # uniform distribution stays uniform + self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)) + self.assertTrue(torch.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3)) + + # sharp peaks get higher, valleys get lower + self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max()) + self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min()) + + # smooth peaks get lower, valleys get higher + self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max()) + self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min()) + + def test_repetition_penalty_dist_process(self): + input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long) + vocab_size = 10 + + scores = self._get_uniform_logits(batch_size=2, length=vocab_size) + + # give values special values + scores[0, 0] = -(1 / vocab_size) + scores[1, 5] = 4 / vocab_size + + rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) + + scores = rep_penalty_proc(input_ids, scores.clone()) + + # check that values were correctly changed + self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2) + self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) / 2) + + 
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2) + self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2) + + def test_top_k_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create ramp distribution + ramp_logits = ( + torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1) + ) + ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size + + top_k_warp = TopKLogitsWarper(3) + + scores = top_k_warp(input_ids, ramp_logits) + + # check that correct tokens are filtered + self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False]) + self.assertListEqual(torch.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True]) + + # check special cases + length = 5 + + logits = self._get_uniform_logits(batch_size=batch_size, length=length) + top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3) + + scores = top_k_warp_safety_check(input_ids, logits) + # uniform dist is not changed + self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0]) + + ramp_logits = torch.arange(length, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1) + scores = top_k_warp_safety_check(input_ids, ramp_logits) + + # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified + self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2]) + + def test_top_p_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper) + dist = torch.log( + torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float) + ) + + top_p_warp = TopPLogitsWarper(0.7) + filtered_dist = torch.exp(top_p_warp(input_ids, dist)) + + # dist should be filtered to keep min num values so that sum is >= 0.7 + # exp (-inf) => 0 + EXPECTED_FILTERED_DIST = torch.tensor( + [[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float + ) + self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + + # check edge cases with negative and extreme logits + ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat( + batch_size, 1 + ) - (vocab_size // 2) + + # make ramp_logits more extreme + ramp_logits[1] = ramp_logits[1] * 100.0 + + # make sure at least 2 tokens are kept + top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0) + filtered_dist = top_p_warp(input_ids, ramp_logits) + + # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. 
+ self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2]) + + def test_no_repeat_ngram_dist_processor(self): + vocab_size = 3 + batch_size = 2 + + input_ids = torch.tensor([[1, 1, 2, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long) + scores = self._get_uniform_logits(batch_size, vocab_size) + + no_repeat_proc_2_gram = NoRepeatNGramLogitsProcessor(2) + no_repeat_proc_3_gram = NoRepeatNGramLogitsProcessor(3) + + filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone()) + filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone()) + + # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch + self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]]) + + # 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch + self.assertListEqual( + torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]] + ) + + def test_no_bad_words_dist_processor(self): + vocab_size = 5 + batch_size = 2 + eos_token_id = 4 + + input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long) + bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]] + scores = self._get_uniform_logits(batch_size, vocab_size) + + no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id) + + filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone()) + + # batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden + # batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden + # Note that 5th element cannot be forbidden as it is EOS token + self.assertListEqual( + torch.isinf(filtered_scores).tolist(), [[True, True, False, True, False], [True, True, True, False, False]] + ) + + # check edge case + no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[4]], eos_token_id=eos_token_id) + filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone()) + self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3)) + + def test_processor_list(self): + batch_size = 4 + sequence_length = 10 + vocab_size = 15 + eos_token_id = 0 + + # dummy input_ids and scores + input_ids = ids_tensor((batch_size, sequence_length), vocab_size) + input_ids_comp = input_ids.clone() + + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_comp = scores.clone() + + # instantiate all dist processors + min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id) + temp_dist_warp = TemperatureLogitsWarper(temperature=0.5) + rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) + top_k_warp = TopKLogitsWarper(3) + top_p_warp = TopPLogitsWarper(0.8) + no_repeat_proc = NoRepeatNGramLogitsProcessor(2) + no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id) + + # no processor list + scores = min_dist_proc(input_ids, scores) + scores = temp_dist_warp(input_ids, scores) + scores = rep_penalty_proc(input_ids, scores) + scores = top_k_warp(input_ids, scores) + scores = top_p_warp(input_ids, scores) + scores = no_repeat_proc(input_ids, scores) + scores = no_bad_words_dist_proc(input_ids, scores) + + # with processor list + processor = LogitsProcessorList( + [ + min_dist_proc, + temp_dist_warp, + rep_penalty_proc, + top_k_warp, + top_p_warp, + no_repeat_proc, + no_bad_words_dist_proc, + ] + ) + scores_comp = processor(input_ids, scores_comp) + + # scores should be equal + 
self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3)) + + # input_ids should never be changed + self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist()) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py new file mode 100644 index 00000000000..0cdd80dd741 --- /dev/null +++ b/tests/test_generation_utils.py @@ -0,0 +1,510 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, torch_device + + +if is_torch_available(): + import torch + + from transformers import top_k_top_p_filtering + from transformers.generation_beam_search import BeamSearchScorer + from transformers.generation_logits_process import ( + LogitsProcessorList, + MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) + + +class GenerationTesterMixin: + model_tester = None + all_generative_model_classes = () + + def _get_input_ids_and_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + input_ids = inputs_dict["input_ids"] + attention_mask = torch.ones_like(input_ids) + + # cut to half length & take max batch_size 3 + max_batch_size = 2 + sequence_length = input_ids.shape[-1] // 2 + input_ids = input_ids[:max_batch_size, :sequence_length] + attention_mask = attention_mask[:max_batch_size, :sequence_length] + + # generate max 3 tokens + max_length = input_ids.shape[-1] + 3 + if config.eos_token_id is not None and config.pad_token_id is None: + # hack to allow generate for models such as GPT2 as is done in `generate()` + config.pad_token_id = config.eos_token_id + return config, input_ids, attention_mask, max_length + + @staticmethod + def _get_logits_processor_and_kwargs(input_length, eos_token_id): + process_kwargs = { + "min_length": input_length + 1, + "bad_words_ids": [[1, 0]], + "no_repeat_ngram_size": 2, + "repetition_penalty": 1.2, + } + logits_processor = LogitsProcessorList( + ( + [ + MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id), + ] + if eos_token_id is not None + else [] + ) + + [ + NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id), + NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]), + RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]), + ] + ) + return process_kwargs, logits_processor + + @staticmethod + def _get_warper_and_kwargs(num_beams): + warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7} + logits_warper = LogitsProcessorList( + [ + TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), + TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)), + TemperatureLogitsWarper(warp_kwargs["temperature"]), + ] + ) + return warp_kwargs, logits_warper + 
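+    # NOTE: `_get_logits_processor_and_kwargs`, `_get_warper_and_kwargs` and
+    # `_get_beam_scorer_and_kwargs` (below) build by hand the same objects that
+    # `generate()` creates internally from keyword arguments, so each test can
+    # call `greedy_search` / `sample` / `beam_search` / `beam_sample` directly
+    # and compare the returned ids against `generate()` token for token.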
+ @staticmethod + def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1): + beam_kwargs = { + "early_stopping": False, + "length_penalty": 2.0, + "num_beams": 2, + "num_return_sequences": num_return_sequences, + } + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + max_length=max_length, + num_beams=beam_kwargs["num_beams"], + device=torch_device, + length_penalty=beam_kwargs["length_penalty"], + do_early_stopping=beam_kwargs["early_stopping"], + num_beam_hyps_to_keep=num_return_sequences, + ) + return beam_kwargs, beam_scorer + + @staticmethod + def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1): + encoder = model.get_encoder() + encoder_outputs = encoder(input_ids, attention_mask=attention_mask, return_dict=True) + encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave( + num_interleave, dim=0 + ) + input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id() + attention_mask = None + return encoder_outputs, input_ids, attention_mask + + def test_greedy_generate(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], config.eos_token_id + ) + + model = model_class(config).to(torch_device) + model.eval() + + # check `generate()` and `greedy_search()` are equal + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, input_ids, attention_mask + ) + kwargs["encoder_outputs"] = encoder_outputs + max_length = 4 + + output_ids_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + num_beams=1, + max_length=max_length, + **logits_process_kwargs, + ) + with torch.no_grad(): + output_ids_greedy = model.greedy_search( + input_ids, + max_length=max_length, + attention_mask=attention_mask, + logits_processor=logits_processor, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_greedy.tolist()) + + def test_sample_generate(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], config.eos_token_id + ) + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + + model = model_class(config).to(torch_device) + model.eval() + + # check `generate()` and `sample()` are equal + if model.config.is_encoder_decoder: + max_length = 4 + + torch.manual_seed(0) + output_ids_generate = model.generate( + input_ids, + do_sample=True, + num_beams=1, + max_length=max_length, + attention_mask=attention_mask, + **logits_warper_kwargs, + **process_kwargs, + ) + + torch.manual_seed(0) + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, input_ids, attention_mask + ) + kwargs["encoder_outputs"] = encoder_outputs + else: + attention_mask_clone = attention_mask + input_ids_clone = input_ids + + with torch.no_grad(): + output_ids_sample = model.sample( + input_ids_clone, + attention_mask=attention_mask_clone, + max_length=max_length, + logits_processor=logits_processor, + logits_warper=logits_warper, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), 
output_ids_sample.tolist()) + + # check `generate()` and `sample()` yield equal results for `num_return_sequences` + num_return_sequences = 3 + if model.config.is_encoder_decoder: + max_length = 4 + + torch.manual_seed(0) + output_ids_generate = model.generate( + input_ids, + do_sample=True, + num_beams=1, + max_length=max_length, + num_return_sequences=num_return_sequences, + attention_mask=attention_mask, + **logits_warper_kwargs, + **process_kwargs, + ) + + torch.manual_seed(0) + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, input_ids, attention_mask, num_interleave=num_return_sequences + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0) + input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0) + + with torch.no_grad(): + output_ids_sample = model.sample( + input_ids_clone, + attention_mask=attention_mask_clone, + max_length=max_length, + logits_processor=logits_processor, + logits_warper=logits_warper, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist()) + + def test_beam_search_generate(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs( + input_ids.shape[-1], config.eos_token_id + ) + + model = model_class(config).to(torch_device) + model.eval() + + # check `generate()` and `beam_search()` are equal + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length) + output_ids_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + **beam_kwargs, + **logits_process_kwargs, + ) + + # beam_search does not automatically interleave `batch_size` dim for `num_beams` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_ids_beam_search = model.beam_search( + input_ids_clone, + beam_scorer, + max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist()) + + # check `generate()` and `beam_search()` are equal for `num_return_sequences` + num_return_sequences = 2 + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs( + input_ids.shape[0], max_length, num_return_sequences=num_return_sequences + ) + + output_ids_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=False, + max_length=max_length, + **beam_kwargs, + **logits_process_kwargs, + ) + # beam_search does not automatically interleave `batch_size` dim for `num_beams` + 
kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs( + model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams + ) + kwargs["encoder_outputs"] = encoder_outputs + input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0) + else: + attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0) + input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0) + + with torch.no_grad(): + output_ids_beam_search = model.beam_search( + input_ids_clone, + beam_scorer, + max_length=max_length, + attention_mask=attention_mask_clone, + logits_processor=logits_processor, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist()) + + def test_beam_sample_generate(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1) + + model = model_class(config).to(torch_device) + model.eval() + + # check `generate()` and `beam_search()` are equal + # change `num_return_sequences = 2` but not for `beam_scorer` + num_return_sequences = 2 + if model.config.is_encoder_decoder: + max_length = 4 + beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs( + input_ids.shape[0] * num_return_sequences, max_length + ) + beam_kwargs["num_return_sequences"] = num_return_sequences + torch.manual_seed(0) + output_ids_generate = model.generate( + input_ids, + attention_mask=attention_mask, + do_sample=True, + max_length=max_length, + **beam_kwargs, + **logits_warper_kwargs, + ) + # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences` + kwargs = {} + if model.config.is_encoder_decoder: + encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs( + model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams * num_return_sequences + ) + kwargs["encoder_outputs"] = encoder_outputs + else: + attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0) + + torch.manual_seed(0) + with torch.no_grad(): + output_ids_beam_sample = model.beam_sample( + input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0), + beam_scorer, + max_length=max_length, + attention_mask=attention_mask, + logits_warper=logits_warper, + **kwargs, + ) + self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_sample.tolist()) + + def test_generate_without_input_ids(self): + config, _, _, max_length = self._get_input_ids_and_config() + + # if no bos token id => cannot generate from None + if config.bos_token_id is None: + return + + for model_class in self.all_generative_model_classes: + model = model_class(config).to(torch_device) + model.eval() + + output_ids_generate = model.generate( + do_sample=False, + max_length=max_length, + ) + + self.assertIsNotNone(output_ids_generate) + + +@require_torch +class UtilsFunctionsTest(unittest.TestCase): + + # tests whether the top_k_top_p function behaves as expected + def test_top_k_top_p_filtering(self): + logits = torch.tensor( + [ + [ + 8.2220991, # 3rd highest value; idx. 0 + -0.5620044, + 5.23229752, + 4.0386393, + -6.8798378, + -0.54785802, + -3.2012153, + 2.92777176, + 1.88171953, + 7.35341276, + 8.43207833, # 2nd highest value; idx. 
10 + -9.85711836, + -5.96209236, + -1.13039161, + -7.1115294, + -0.8369633, + -5.3186408, + 7.06427407, + 0.81369344, + -0.82023817, + -5.9179796, + 0.58813443, + -6.99778438, + 4.71551189, + -0.18771637, + 7.44020759, # 4th highest value; idx. 25 + 9.38450987, # 1st highest value; idx. 26 + 2.12662941, + -9.32562038, + 2.35652522, + ], # cummulative prob of 4 highest values <= 0.6 + [ + 0.58425518, + 4.53139238, + -5.57510464, + -6.28030699, + -7.19529503, + -4.02122551, + 1.39337037, + -6.06707057, + 1.59480517, + -9.643119, + 0.03907799, + 0.67231762, + -8.88206726, + 6.27115922, # 4th highest value; idx. 13 + 2.28520723, + 4.82767506, + 4.30421368, + 8.8275313, # 2nd highest value; idx. 17 + 5.44029958, + -4.4735794, + 7.38579536, # 3rd highest value; idx. 20 + -2.91051663, + 2.61946077, + -2.5674762, + -9.48959302, + -4.02922645, + -1.35416918, + 9.67702323, # 1st highest value; idx. 27 + -5.89478553, + 1.85370467, + ], # cummulative prob of 4 highest values <= 0.6 + ], + dtype=torch.float, + device=torch_device, + ) + + non_inf_expected_idx = torch.tensor( + [[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]], + dtype=torch.long, + device=torch_device, + ) # expected non filtered idx as noted above + + non_inf_expected_output = torch.tensor( + [ + 8.2221, + 8.4321, + 7.4402, + 9.3845, + 6.2712, + 8.8275, + 7.3858, + 9.6770, + ], # expected non filtered values as noted above + dtype=torch.float, + device=torch_device, + ) + + output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4) + non_inf_output = output[output != -float("inf")].to(device=torch_device) + non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device) + + self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) + self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 1b09a814f95..c6f9f65dca9 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -23,6 +23,7 @@ from transformers.file_utils import cached_property from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, ids_tensor @@ -128,7 +129,7 @@ def prepare_bart_inputs_dict( @require_torch -class BARTModelTest(ModelTesterMixin, unittest.TestCase): +class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering) if is_torch_available() diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index a24de563621..c18cf5a7308 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -20,6 +20,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -357,11 +358,12 @@ class BertModelTester: @require_torch -class BertModelTest(ModelTesterMixin, unittest.TestCase): +class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( ( BertModel, + BertLMHeadModel, BertForMaskedLM, 
BertForMultipleChoice, BertForNextSentencePrediction, @@ -373,6 +375,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else () def setUp(self): self.model_tester = BertModelTester(self) diff --git a/tests/test_modeling_bert_generation.py b/tests/test_modeling_bert_generation.py index 0c626bd531b..f5ce360a89f 100755 --- a/tests/test_modeling_bert_generation.py +++ b/tests/test_modeling_bert_generation.py @@ -20,6 +20,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -183,9 +184,10 @@ class BertGenerationEncoderTester: @require_torch -class BertGenerationEncoderTest(ModelTesterMixin, unittest.TestCase): +class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (BertGenerationEncoder, BertGenerationDecoder) if is_torch_available() else () + all_generative_model_classes = (BertGenerationDecoder,) if is_torch_available() else () def setUp(self): self.model_tester = BertGenerationEncoderTester(self) diff --git a/tests/test_modeling_blenderbot.py b/tests/test_modeling_blenderbot.py index e3581783468..19fee17ba08 100644 --- a/tests/test_modeling_blenderbot.py +++ b/tests/test_modeling_blenderbot.py @@ -147,6 +147,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase): src_text = ["Sam"] model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device) + generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS) tgt_text = 'Sam is a great name. It means "sun" in Gaelic.' @@ -156,6 +157,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase): src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?" model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device) + generated_ids = model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0] reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW) @@ -187,6 +189,9 @@ class Blenderbot90MIntegrationTests(unittest.TestCase): ] model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device) + + # model does not have "token_type_ids" + model_inputs.pop("token_type_ids") assert isinstance(self.tokenizer, BlenderbotSmallTokenizer) generated_ids = self.model.generate(**model_inputs)[0] reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW) @@ -198,10 +203,11 @@ class Blenderbot90MIntegrationTests(unittest.TestCase): def test_90_generation_from_short_input(self): model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device) - generated_utterances = self.model.generate(**model_inputs) - # generated_txt = self.tokenizer.decode(generated_utterances[0]) - # assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__" + # model does not have "token_type_ids" + model_inputs.pop("token_type_ids") + generated_utterances = self.model.generate(**model_inputs) + clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW) assert clean_txt in ( "have you ever been to a sam club? 
it's a great club in the south.", diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bbc64044f79..60316c40158 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -44,7 +44,6 @@ if is_torch_available(): BertModel, PretrainedConfig, PreTrainedModel, - top_k_top_p_filtering, ) @@ -882,126 +881,6 @@ class ModelTesterMixin: with torch.no_grad(): model(**inputs)[0] - def test_lm_head_model_random_no_beam_search_generate(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"] - - # make sure that input_ids is at most of size 15 - input_ids = input_ids[..., :15] - - # iterate over all generative models - for model_class in self.all_generative_model_classes: - model = model_class(config).to(torch_device) - model.eval() - - if config.bos_token_id is None: - # if bos token id is not defined, model needs input_ids - with self.assertRaises(AssertionError): - model.generate(do_sample=True, max_length=5) - # num_return_sequences = 1 - self._check_generated_ids(model.generate(input_ids, do_sample=True)) - else: - # num_return_sequences = 1 - self._check_generated_ids(model.generate(do_sample=True, max_length=5)) - - with self.assertRaises(AssertionError): - # generating multiple sequences when no beam search generation - # is not allowed as it would always generate the same sequences - model.generate(input_ids, do_sample=False, num_beams=1, num_return_sequences=2) - - # num_return_sequences > 1, sample - self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2)) - - # check bad words tokens language generation - # create list of 1-seq bad token and list of 2-seq of bad tokens - bad_words_ids = [ - self._generate_random_bad_tokens(1, model.config), - self._generate_random_bad_tokens(2, model.config), - ] - output_tokens = model.generate( - input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2 - ) - # only count generated tokens - generated_ids = output_tokens[:, input_ids.shape[-1] :] - self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids)) - - def test_lm_head_model_random_beam_search_generate(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to( - torch_device - ) - - # make sure that input_ids is at most of size 15 - input_ids = input_ids[..., :15] - - for model_class in self.all_generative_model_classes: - model = model_class(config).to(torch_device) - model.eval() - - if config.bos_token_id is None: - # if bos token id is not defined mobel needs input_ids, num_return_sequences = 1 - self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2)) - else: - # num_return_sequences = 1 - self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2)) - - with self.assertRaises(AssertionError): - # generating more sequences than having beams leads is not possible - model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2) - - # num_return_sequences > 1, sample - self._check_generated_ids( - model.generate( - input_ids, - do_sample=True, - num_beams=2, - num_return_sequences=2, - ) - ) - # num_return_sequences > 1, greedy - self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2)) - - # check bad 
words tokens language generation - # create list of 1-seq bad token and list of 2-seq of bad tokens - bad_words_ids = [ - self._generate_random_bad_tokens(1, model.config), - self._generate_random_bad_tokens(2, model.config), - ] - output_tokens = model.generate( - input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2 - ) - # only count generated tokens - generated_ids = output_tokens[:, input_ids.shape[-1] :] - self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids)) - - def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]: - # special tokens cannot be bad tokens - special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None] - # create random bad tokens that are not special tokens - bad_tokens = [] - while len(bad_tokens) < num_bad_tokens: - token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0] - if token not in special_tokens: - bad_tokens.append(token) - return bad_tokens - - def _check_generated_ids(self, output_ids): - for token_id in output_ids[0].tolist(): - self.assertGreaterEqual(token_id, 0) - self.assertLess(token_id, self.model_tester.vocab_size) - - def _check_match_tokens(self, generated_ids, bad_words_ids): - # for all bad word tokens - for bad_word_ids in bad_words_ids: - # for all slices in batch - for generated_ids_slice in generated_ids: - # for all word idx - for i in range(len(bad_word_ids), len(generated_ids_slice)): - # if tokens match - if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids: - return True - return False - @require_torch_multigpu def test_multigpu_data_parallel_forward(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -1094,110 +973,3 @@ class ModelUtilsTest(unittest.TestCase): model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True) self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config, config) - - -@require_torch -class UtilsFunctionsTest(unittest.TestCase): - - # tests whether the top_k_top_p function behaves as expected - def test_top_k_top_p_filtering(self): - logits = torch.tensor( - [ - [ - 8.2220991, # 3rd highest value; idx. 0 - -0.5620044, - 5.23229752, - 4.0386393, - -6.8798378, - -0.54785802, - -3.2012153, - 2.92777176, - 1.88171953, - 7.35341276, # 5th highest value; idx. 9 - 8.43207833, # 2nd highest value; idx. 10 - -9.85711836, - -5.96209236, - -1.13039161, - -7.1115294, - -0.8369633, - -5.3186408, - 7.06427407, - 0.81369344, - -0.82023817, - -5.9179796, - 0.58813443, - -6.99778438, - 4.71551189, - -0.18771637, - 7.44020759, # 4th highest value; idx. 25 - 9.38450987, # 1st highest value; idx. 26 - 2.12662941, - -9.32562038, - 2.35652522, - ], # cumulative prob of 5 highest values <= 0.6 - [ - 0.58425518, - 4.53139238, - -5.57510464, - -6.28030699, - -7.19529503, - -4.02122551, - 1.39337037, - -6.06707057, - 1.59480517, - -9.643119, - 0.03907799, - 0.67231762, - -8.88206726, - 6.27115922, # 4th highest value; idx. 13 - 2.28520723, - 4.82767506, - 4.30421368, - 8.8275313, # 2nd highest value; idx. 17 - 5.44029958, # 5th highest value; idx. 18 - -4.4735794, - 7.38579536, # 3rd highest value; idx. 20 - -2.91051663, - 2.61946077, - -2.5674762, - -9.48959302, - -4.02922645, - -1.35416918, - 9.67702323, # 1st highest value; idx. 
27 - -5.89478553, - 1.85370467, - ], # cumulative prob of 5 highest values <= 0.6 - ], - dtype=torch.float, - device=torch_device, - ) - - non_inf_expected_idx = torch.tensor( - [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]], - dtype=torch.long, - device=torch_device, - ) # expected non filtered idx as noted above - - non_inf_expected_output = torch.tensor( - [ - 8.2221, - 7.3534, - 8.4321, - 7.4402, - 9.3845, - 6.2712, - 8.8275, - 5.4403, - 7.3858, - 9.6770, - ], # expected non filtered values as noted above - dtype=torch.float, - device=torch_device, - ) - - output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4) - non_inf_output = output[output != -float("inf")].to(device=torch_device) - non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device) - - self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12)) - self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx))) diff --git a/tests/test_modeling_ctrl.py b/tests/test_modeling_ctrl.py index 39598b8ee66..d73c3f9c329 100644 --- a/tests/test_modeling_ctrl.py +++ b/tests/test_modeling_ctrl.py @@ -19,6 +19,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -151,7 +152,7 @@ class CTRLModelTester: @require_torch -class CTRLModelTest(ModelTesterMixin, unittest.TestCase): +class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else () all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else () diff --git a/tests/test_modeling_fsmt.py b/tests/test_modeling_fsmt.py index 7df338c9736..138a2a39154 100644 --- a/tests/test_modeling_fsmt.py +++ b/tests/test_modeling_fsmt.py @@ -24,6 +24,7 @@ from transformers.file_utils import cached_property from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, ids_tensor @@ -120,7 +121,7 @@ def prepare_fsmt_inputs_dict( @require_torch -class FSMTModelTest(ModelTesterMixin, unittest.TestCase): +class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 2fd4256f6b7..97d5bec376e 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -20,6 +20,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -377,7 +378,7 @@ class GPT2ModelTester: @require_torch -class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): +class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): 
all_model_classes = ( (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification) @@ -510,32 +511,17 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): self.assertListEqual(output_ids[0].tolist(), expected_output_ids) @slow - def test_lm_generate_distilgpt2(self): - model = GPT2LMHeadModel.from_pretrained("distilgpt2") + def test_gpt2_sample(self): + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + model = GPT2LMHeadModel.from_pretrained("gpt2") model.to(torch_device) - input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president - expected_output_ids = [ - 464, - 1893, - 286, - 262, - 1578, - 1829, - 11, - 290, - 262, - 1893, - 286, - 262, - 1578, - 7526, - 11, - 423, - 587, - 287, - 262, - 2635, - ] # The president of the United States, and the president of the United Kingdom, have been in the White - output_ids = model.generate(input_ids, do_sample=False) - self.assertListEqual(output_ids[0].tolist(), expected_output_ids) + torch.manual_seed(0) + input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device) + output_ids = model.generate(input_ids, do_sample=True) + output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) + + EXPECTED_OUTPUT_STR = ( + "Today is a nice day and if you don't know anything about the state of play during your holiday" + ) + self.assertEqual(output_str, EXPECTED_OUTPUT_STR) diff --git a/tests/test_modeling_openai.py b/tests/test_modeling_openai.py index e74ce093fa1..eae027e7a0a 100644 --- a/tests/test_modeling_openai.py +++ b/tests/test_modeling_openai.py @@ -20,6 +20,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, ids_tensor @@ -170,7 +171,7 @@ class OpenAIGPTModelTester: @require_torch -class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase): +class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification) diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index f74123f3d13..f9a1cce4a81 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -22,6 +22,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor @@ -853,7 +854,7 @@ class ProphetNetStandaloneEncoderModelTester: @require_torch -class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase): +class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else () test_pruning = False @@ -917,7 +918,7 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase): @require_torch -class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, unittest.TestCase): +class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): 
all_model_classes = (ProphetNetDecoder, ProphetNetForCausalLM) if is_torch_available() else () all_generative_model_classes = (ProphetNetForCausalLM,) if is_torch_available() else () test_pruning = False diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 888d22f30e9..2d5884cd75e 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -26,6 +26,7 @@ from transformers.testing_utils import ( ) from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -196,11 +197,14 @@ class ReformerModelTester: ) def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels): + if not self.is_training: + return + config.is_decoder = False config.lsh_num_chunks_after = 1 model = ReformerForMaskedLM(config=config) model.to(torch_device) - model.eval() + model.train() loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"] loss.backward() @@ -569,7 +573,7 @@ class ReformerTesterMixin: @require_torch -class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase): +class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase): all_model_classes = ( (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering) if is_torch_available() @@ -629,7 +633,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest @require_torch -class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase): +class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( (ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering) if is_torch_available() diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index 5ab097b5ff4..de753266c38 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -20,6 +20,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -267,7 +268,7 @@ class RobertaModelTester: @require_torch -class RobertaModelTest(ModelTesterMixin, unittest.TestCase): +class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( ( @@ -282,6 +283,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase): if is_torch_available() else () ) + all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else () def setUp(self): self.model_tester = RobertaModelTester(self) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 5887ec68020..77b8ee1cae9 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -23,6 +23,7 @@ from transformers.file_utils import cached_property from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from 
.test_modeling_common import ModelTesterMixin, ids_tensor @@ -466,7 +467,7 @@ class T5ModelTester: @require_torch -class T5ModelTest(ModelTesterMixin, unittest.TestCase): +class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () @@ -592,6 +593,7 @@ class T5ModelIntegrationTests(unittest.TestCase): do_sample=False, early_stopping=True, ) + decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False) self.assertListEqual( expected_summaries, diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 24b1787c891..ad1b8aed7e0 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -20,6 +20,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, ids_tensor @@ -156,7 +157,7 @@ class TransfoXLModelTester: @require_torch -class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): +class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else () all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else () test_pruning = False diff --git a/tests/test_modeling_xlm.py b/tests/test_modeling_xlm.py index da1ef130c00..14c3236ef97 100644 --- a/tests/test_modeling_xlm.py +++ b/tests/test_modeling_xlm.py @@ -20,6 +20,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -331,7 +332,7 @@ class XLMModelTester: @require_torch -class XLMModelTest(ModelTesterMixin, unittest.TestCase): +class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( ( diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index 4b2893a05bd..c154874a1eb 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -21,6 +21,7 @@ from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -479,7 +480,7 @@ class XLNetModelTester: @require_torch -class XLNetModelTest(ModelTesterMixin, unittest.TestCase): +class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): all_model_classes = ( ( XLNetModel,