diff --git a/transformers/generate/beam_search.py b/transformers/generate/beam_search.py
deleted file mode 100644
index b56ebbabb86..00000000000
--- a/transformers/generate/beam_search.py
+++ /dev/null
@@ -1,376 +0,0 @@
-# coding=utf-8
-# MIT License
-
-# Copyright (c) 2017-Present OpenNMT
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of
-# this software and associated documentation files (the "Software"), to deal in
-# the Software without restriction, including without limitation the rights to
-# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
-# of the Software, and to permit persons to whom the Software is furnished to do
-# so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-"""
-Use Beam Search to generate sequences using encoder-decoder models.
-"""
-import torch
-from torch import nn
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class BeamSearch(object):
-    def __init__(
-        self,
-        model,
-        bos_token_id,
-        pad_token_id,
-        eos_token_id,
-        batch_size,
-        beam_size,
-        min_length,
-        max_length,
-        alpha=0,
-        block_repeating_trigrams=True,
-        device=torch.device("cpu"),
-    ):
-        r"""
-        Inputs:
-            **model**: instance of ``transformers.PreTrainedEncoderDecoder``
-                The pretrained encoder-decoder model that will be used to generate the sequences.
-            **bos_token_id**: int
-                Id that is used by the tokenizer to represent the beginning of a sentence.
-            **pad_token_id**: int
-                Id that is used by the tokenizer for padding.
-            **eos_token_id**: int
-                Id that is used by the tokenizer to represent the end of a sentence.
-            **batch_size**: (`optional`) int
-                Batch size of the inputs. The value is set automatically when calling `forward`.
-            **beam_size**: int
-                Number of beams that are used for each element of the batch.
-            **min_length**: int
-                Minimum number of steps performed by the beam search before terminating.
-            **max_length**: int
-                Maximum number of steps performed by the beam search. Any beam that has not
-                finished will return its current solution with the highest probability. The
-                sequence that is returned has a length of max_length-1 to account for the end
-                token that is subsequently added.
-            **alpha**: float
-                Parameter of the length penalty. Read the documentation of the
-                `_length_penalty` method for more details.
-            **block_repeating_trigrams**: bool
-                Whether to block sequences that have repeating 3-grams.
-        """
-        super(BeamSearch, self).__init__()
-        self.model = model
-        # Only works if all parameters of the model are stored on a single device.
-        self.device = next(model.parameters()).device
-
-        self.bos_token_id = bos_token_id
-        self.eos_token_id = eos_token_id
-        self.pad_token_id = pad_token_id
-
-        self.batch_size = batch_size
-        self.beam_size = beam_size
-        self.min_length = min_length
-        self.max_length = max_length
-
-        self.block_repeating_trigram = block_repeating_trigrams
-        self.apply_length_penalty = False if alpha == 0 else True
-        self.alpha = alpha
-
-        self._init_beam_state(batch_size)
-
-    def __len__(self):
-        return self.growing_beams.size(1)
-
-    def _init_beam_state(self, batch_size):
-        """ (re-)Initialize the state of the beams. """
-        self.hypotheses = [[] for _ in range(batch_size)]
-        self.batch_offset = torch.arange(batch_size, dtype=torch.long, device=self.device)
-        self.beam_offset = torch.arange(
-            0,
-            batch_size * self.beam_size,
-            step=self.beam_size,
-            dtype=torch.long,
-            device=self.device,
-        )
-        self.growing_beams = torch.full(
-            (batch_size * self.beam_size, 1),
-            self.bos_token_id,
-            dtype=torch.long,
-            device=self.device,
-        )
-        self.topk_log_probabilities = torch.tensor(
-            [0.0] + [float("-inf")] * (self.beam_size - 1),
-            dtype=torch.float,
-            device=self.device,
-        ).repeat(batch_size)
-        self.results = {
-            "predictions": [[] for _ in range(batch_size)],
-            "scores": [[] for _ in range(batch_size)],
-        }
-        self._step = 0
-        self.is_done = False
-
-    def __call__(self, encoder_input_ids, **model_kwargs):
-        """ Generate a sequence using Beam Search. """
-        # Keyword arguments come in 3 flavors: encoder-specific (prefixed by
-        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
-        # that apply to the model as a whole. We let the specific kwargs
-        # override the common ones in case of conflict.
-        kwargs_common = {
-            argument: value
-            for argument, value in model_kwargs.items()
-            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
-        }
-        kwargs_decoder = kwargs_common.copy()
-        kwargs_encoder = kwargs_common.copy()
-        kwargs_encoder.update(
-            {
-                argument[len("encoder_") :]: value
-                for argument, value in model_kwargs.items()
-                if argument.startswith("encoder_")
-            }
-        )
-        kwargs_decoder.update(
-            {
-                argument[len("decoder_") :]: value
-                for argument, value in model_kwargs.items()
-                if argument.startswith("decoder_")
-            }
-        )
-
-        # Forward pass on the encoder; the encoder states are tiled so that
-        # each beam gets its own copy.
-        encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
-        encoder_hidden_states = encoder_outputs[0]
-        kwargs_decoder["encoder_hidden_states"] = tile(
-            encoder_hidden_states, self.beam_size, dim=0
-        )
-        try:
-            kwargs_decoder["encoder_attention_mask"] = tile(
-                kwargs_encoder["attention_mask"], self.beam_size, dim=0
-            )
-        except KeyError:
-            pass
-        kwargs_decoder["state"].src = tile(
-            kwargs_decoder["state"].src, self.beam_size, dim=0
-        )
-
-        # Grow the beams iteratively
-        batch_size, block_size = encoder_input_ids.size()
-        self._init_beam_state(batch_size)
-        for step in range(self.max_length):
-            decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
-            kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
-
-            outputs, state = self.model.decoder(decoder_input, **kwargs_decoder)
-
-            next_token_scores = outputs[0][:, -1, :].squeeze(1)
-            # Normalize over the vocabulary dimension.
-            log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=-1)
-            surviving_beams_rows = self.grow(log_probabilities)
-            if self.is_done:
-                break
-
-            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
-                "encoder_hidden_states"
-            ].index_select(0, surviving_beams_rows)
-            try:
-                kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
-                    "encoder_attention_mask"
-                ].index_select(0, surviving_beams_rows)
-            except KeyError:
-                pass
-            kwargs_decoder["state"] = state
-
-        return self.results
-
-    def grow(self, log_probabilities):
-        """ Grow the beams by one step. """
-        self._step += 1
-
-        # The number of beams changes as some beams finish, so we define _B
-        vocab_size = log_probabilities.size(-1)
-        _B = log_probabilities.size(0) // self.beam_size
-
-        # Multiply each beam probability with the probability of the
-        # next token (conditioned on the words in the beam).
-        log_probabilities += self.topk_log_probabilities.view(-1, 1)
-
-        self._enforce_min_length(log_probabilities)
-        if self.block_repeating_trigram:
-            self._remove_beams_with_repeating_trigrams(log_probabilities, _B)
-
-        # Find the `beam_size` (previous_beam + token) combinations with
-        # the highest score.
-        self.topk_log_probabilities, topk_ids = torch.topk(
-            log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
-        )
-
-        # Apply the length penalty. The +1 accounts for the [EOS] token
-        # that will be added if the beam ends.
-        # Divide out of place so the accumulated log-probabilities are not modified.
-        if self.apply_length_penalty:
-            topk_scores = self.topk_log_probabilities / self._length_penalty()
-        else:
-            topk_scores = self.topk_log_probabilities
-
-        # Retrieve the corresponding beam and token ids;
-        # topk_token_ids[i] will be appended to topk_beam_ids[i].
-        topk_beam_ids = topk_ids // vocab_size
-        topk_token_ids = topk_ids.fmod(vocab_size)
-
-        # Retrieve the row index of the surviving beams in the original
-        # view of the log_probabilities tensor.
-        surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1)
-        surviving_beams_rows = surviving_beams_per_batch.view(-1)
-
-        # Append the last predictions
-        self.growing_beams = torch.cat(
-            [
-                self.growing_beams.index_select(0, surviving_beams_rows),
-                topk_token_ids.view(-1, 1),
-            ],
-            1,
-        )
-
-        # Check whether any of the beam searches has ended during this growth
-        # step, and whether the top (most probable) beam has ended for some
-        # element of the batch.
-        is_finished = topk_token_ids.eq(self.eos_token_id)
-        self._enforce_max_length(is_finished)
-        if is_finished.any():
-            non_finished = self._cut_finished(is_finished, topk_scores)
-            self.batch_offset = self.batch_offset.index_select(0, non_finished)
-            surviving_beams_per_batch = surviving_beams_per_batch.index_select(
-                0, non_finished
-            )
-            self.topk_log_probabilities = self.topk_log_probabilities.index_select(
-                0, non_finished
-            )
-
-        surviving_beams_rows = surviving_beams_per_batch.view(-1)
-        self.growing_beams = self.growing_beams.index_select(0, surviving_beams_rows)
-
-        return surviving_beams_rows
-
-    def _cut_finished(self, is_finished, topk_scores):
-        """ Save the finished searches and cut the corresponding sequences off
-        the beams. """
-        is_top_beam_finished = is_finished[:, 0].eq(True)
-
-        # Save the finished searches
-        predictions = self.growing_beams.view(
-            -1, self.beam_size, self.growing_beams.size(1)
-        )
-        for i in range(is_finished.size(0)):
-            if is_top_beam_finished[i]:
-                is_finished[i].fill_(1)
-            finished_hyp = is_finished[i].nonzero().view(-1)
-
-            # Store the finished beams as a (score, prediction) hypothesis.
-            b = self.batch_offset[i]
-            for j in finished_hyp:
-                self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
-
-            # If the batch reached the end, save the best hypotheses
-            # in terms of length-penalized score.
-            if is_top_beam_finished[i]:
-                best_score, best_prediction = max(self.hypotheses[b], key=lambda x: x[0])
-                self.results["scores"][b].append(best_score)
-                self.results["predictions"][b].append(best_prediction)
-
-        non_finished = is_top_beam_finished.eq(False).nonzero().view(-1)
-        if len(non_finished) == 0:
-            self.is_done = True
-
-        return non_finished
-
-    def _remove_beams_with_repeating_trigrams(self, log_probabilities, _B):
-        if self._step + 1 > 3:  # [BOS] does not count
-            for i in range(_B * self.beam_size):
-                tokens = self.growing_beams[i]
-                trigrams = [
-                    (tokens[j - 1], tokens[j], tokens[j + 1])
-                    for j in range(1, len(self) - 1)
-                ]
-                last_trigram = tuple(trigrams[-1])
-                if last_trigram in trigrams[:-1]:
-                    log_probabilities[i] = -1e20
-
-    def _enforce_min_length(self, log_probabilities):
-        if self._step < self.min_length:
-            log_probabilities[:, self.eos_token_id] = -1e20
-
-    def _enforce_max_length(self, is_finished):
-        # +1 because we will need to add an [EOS] token
-        if self._step + 1 == self.max_length:
-            is_finished.fill_(1)
-
-    def _length_penalty(self):
-        """ The calculation of the length penalty follows that of [1].
-
-        [1] Wu, Yonghui, et al. "Google's neural machine translation system:
-        Bridging the gap between human and machine translation." arXiv preprint
-        arXiv:1609.08144 (2016).
-        """
-        return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
-
-
-def tile(x, count, dim=0):
-    """
-    Tiles `x` along dimension `dim` `count` times.
-
-    Example:
-        >>> ex = torch.tensor([[1, 2], [3, 4]])
-        >>> tile(ex, 2, 0)
-        tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
-    """
-    perm = list(range(len(x.size())))
-    if dim != 0:
-        perm[0], perm[dim] = perm[dim], perm[0]
-        x = x.permute(perm).contiguous()
-    out_size = list(x.size())
-    out_size[0] *= count
-    batch = x.size(0)
-    x = (
-        x.view(batch, -1)
-        .transpose(0, 1)
-        .repeat(count, 1)
-        .transpose(0, 1)
-        .contiguous()
-        .view(*out_size)
-    )
-    if dim != 0:
-        x = x.permute(perm).contiguous()
-    return x
-
-
-def fit_to_block_size(sequence, block_size, pad_token_id):
-    """ Adapt the source and target sequences' lengths to the block size.
-    If the sequence is shorter we append padding tokens to the right.
-    """
-    padded_sequence = torch.full(
-        (sequence.size(0), block_size),
-        pad_token_id,
-        dtype=torch.long,
-        device=sequence.device,
-    )
-    padded_sequence[:, : sequence.size(1)] = sequence
-    return padded_sequence
-
-
-def build_mask(sequence, pad_token_id):
-    """ Builds the attention mask. The attention mechanism will only attend
-    to positions with value 1. """
-    mask = torch.ones_like(sequence)
-    idx_pad_tokens = sequence == pad_token_id
-    mask[idx_pad_tokens] = 0
-    return mask
diff --git a/transformers/tests/beam_search_tests.py b/transformers/tests/beam_search_tests.py
deleted file mode 100644
index 6f2a2b9c2f0..00000000000
--- a/transformers/tests/beam_search_tests.py
+++ /dev/null
@@ -1,243 +0,0 @@
-from collections import namedtuple
-import unittest
-
-import numpy as np
-import torch
-from torch import nn
-
-from transformers.generate import BeamSearch
-from transformers import PreTrainedEncoderDecoder
-
-
-class StubTransformer(nn.Module):
-    """ Minimal stand-in for a PreTrainedEncoderDecoder: BeamSearch only needs
-    `encoder`, `decoder` and at least one parameter to infer the device. """
-
-    def __init__(self):
-        super(StubTransformer, self).__init__()
-        self.encoder = None
-        self.decoder = None
-        self._parameters = {"dummy": torch.tensor([1])}
-
-    def forward(self):
-        pass
-
-
-class BeamSearchTest(unittest.TestCase):
-    def test_beam_search_encoder_decoder_integration(self):
-        """ We make sure that no internal change in the PreTrainedEncoderDecoder
-        class will break the integration with the beam search.
-        """
-        model = StubTransformer()
-        try:
-            _ = BeamSearch(
-                model=model,
-                bos_token_id=0,
-                eos_token_id=1,
-                pad_token_id=2,
-                batch_size=1,
-                beam_size=1,
-                min_length=1,
-                max_length=1,
-                alpha=0,
-                block_repeating_trigrams=False,
-            )
-        except Exception:
-            self.fail("Instantiating BeamSearch with a PreTrainedEncoderDecoder failed.")
-
-    def test_beam_search_min_length(self):
-        """ We keep predicting the end_token for the first beam and check that
-        it is not marked as finished until the beam has reached the minimum
-        length. """
-        eos_idx = 3
-        vocab_size = 10
-
-        batch_size = 3
-        beam_size = 2
-        min_length = 5
-
-        beam = BeamSearch(
-            model=StubTransformer(),
-            bos_token_id=0,
-            eos_token_id=eos_idx,
-            pad_token_id=2,
-            batch_size=batch_size,
-            beam_size=beam_size,
-            min_length=min_length,
-            max_length=10,
-            alpha=0,
-            block_repeating_trigrams=False,
-        )
-
-        # To test that the minimum length is correctly enforced we constantly
-        # assign the highest probability to the [EOS] token (and assign lower
-        # probabilities to some other tokens).
-        # Since BeamSearch resets that probability to -1e20 as long as
-        # min_length has not been reached, we need to reset the value between
-        # steps.
-        non_eos_idxs = [4, 5, 1, 8, 9]
-        score_distribution = torch.log_softmax(
-            torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0
-        )
-
-        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
-        log_probabilities[0, eos_idx] = score_distribution[0]
-        for idx, score in zip(non_eos_idxs, score_distribution[1:]):
-            log_probabilities[0, idx] = score
-
-        for step in range(1, min_length + 2):
-            log_probabilities[0, eos_idx] = score_distribution[0]
-
-            # Beams #3 and #4 terminate at the first step since the probability
-            # of the [EOS] token is -1e20 > -inf, so there are only two beams left.
-            # The top beam (most likely) always ends with 4 until we reach min_length.
-            surviving_beams_rows = beam.grow(log_probabilities)
-            if step < min_length:
-                np.testing.assert_array_equal(
-                    beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
-                )
-            elif step == min_length:
-                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
-                self.assertTrue(beam.is_done)
-                break
-
-            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
-
-    def test_beam_search_max_length(self):
-        """ We keep predicting the same non-EOS token until we reach the
-        maximum permitted length. """
-        batch_size = 3
-        beam_size = 2
-        max_length = 5
-        vocab_size = 10
-
-        beam = BeamSearch(
-            model=StubTransformer(),
-            bos_token_id=0,
-            eos_token_id=1,
-            pad_token_id=2,
-            batch_size=batch_size,
-            beam_size=beam_size,
-            min_length=2,
-            max_length=max_length,
-            alpha=0,
-            block_repeating_trigrams=False,
-        )
-
-        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
-
-        # To test that beam search enforces the max length constraint we
-        # keep giving the highest probability to a token that is not the
-        # [EOS] token.
-        # The beam search will stop at max_length-1, assuming that one would
-        # add the [EOS] token at the end of the returned sequence.
-        token_idxs = [3, 4, 5]
-        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
-        for idx, score in zip(token_idxs, score_distribution):
-            log_probabilities[:, idx] = score
-
-        for step in range(1, max_length + 2):
-            surviving_beams_rows = beam.grow(log_probabilities)
-            if step + 1 < max_length:
-                self.assertFalse(beam.is_done)
-            elif step + 1 == max_length:  # every beam is now forced to finish
-                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
-                self.assertTrue(beam.is_done)
-                break
-
-            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
-
-    def test_beam_search_block_repeating_trigrams(self):
-        """ We make sure that the beams that contain repeating trigrams are removed. """
-        batch_size = 3
-        beam_size = 2
-        max_length = 10
-        vocab_size = 10
-
-        beam = BeamSearch(
-            model=StubTransformer(),
-            bos_token_id=0,
-            eos_token_id=1,
-            pad_token_id=2,
-            batch_size=batch_size,
-            beam_size=beam_size,
-            min_length=2,
-            max_length=max_length,
-            alpha=0,
-            block_repeating_trigrams=True,
-        )
-
-        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
-
-        # To test that BeamSearch enforces the 3-gram constraint we give the
-        # highest probability to the same tokens in a cyclic fashion and make
-        # sure they disappear once the cycle has completed.
-        token_idxs = [3, 4, 5]
-        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
-        for idx, score in zip(token_idxs, score_distribution):
-            log_probabilities[:, idx] = score
-
-        for step in range(1, max_length + 2):
-            # Rotate the probabilities at each step
-            for idx in token_idxs:
-                score = score_distribution[(idx + step) % 3]
-                log_probabilities[::beam_size, idx] = score
-
-            surviving_beams_rows = beam.grow(log_probabilities)
-
-            if step < 7:
-                self.assertFalse(
-                    np.array_equal(
-                        log_probabilities.numpy()[0, :],
-                        np.array([-1e20] * vocab_size, dtype="float32"),
-                    )
-                )
-            if step == 7:
-                np.testing.assert_array_equal(
-                    log_probabilities.numpy()[0, :],
-                    np.array([-1e20] * vocab_size, dtype="float32"),
-                )
-
-            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
-
-    def test_beam_search_example_for_one_step(self):
-        """ We test that the predictions for one step of growth are correct. """
-        batch_size = 2
-        beam_size = 2
-        max_length = 10
-        vocab_size = 5
-
-        beam = BeamSearch(
-            model=StubTransformer(),
-            bos_token_id=0,
-            eos_token_id=1,
-            pad_token_id=2,
-            batch_size=batch_size,
-            beam_size=beam_size,
-            min_length=2,
-            max_length=max_length,
-            alpha=0,
-            block_repeating_trigrams=False,
-        )
-
-        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
-        log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
-        log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)
-
-        # First pass
-        surviving_beams_rows = beam.grow(log_probabilities)
-        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
-        np.testing.assert_array_equal(
-            beam.growing_beams.numpy(), np.array([[0, 3], [0, 4], [0, 4], [0, 3]])
-        )
-        self.assertFalse(beam.is_done)
-
-        # Second pass
-        surviving_beams_rows = beam.grow(log_probabilities)
-        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
-        np.testing.assert_array_equal(
-            beam.growing_beams.numpy(),
-            np.array([[0, 3, 3], [0, 3, 4], [0, 4, 4], [0, 4, 3]]),
-        )
-        self.assertFalse(beam.is_done)
-
-
-if __name__ == "__main__":
-    unittest.main()
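For context, a minimal, hypothetical sketch of how the removed BeamSearch.grow() API was driven in the tests above: synthetic next-token scores are fed to grow() step by step until is_done is set, and the rows returned by grow() are used to reindex the scores. The _Stub module, token ids, and score values below are made up for illustration; only BeamSearch, grow(), is_done, results, and the constructor arguments come from the deleted code.

import torch
from torch import nn

from transformers.generate import BeamSearch  # module removed by this diff


class _Stub(nn.Module):
    # Hypothetical placeholder: BeamSearch only reads the device of its parameters
    # here; encoder/decoder are never called when grow() is driven directly.
    def __init__(self):
        super(_Stub, self).__init__()
        self.encoder = None
        self.decoder = None
        self.dummy = nn.Parameter(torch.zeros(1))


beam = BeamSearch(
    model=_Stub(),
    bos_token_id=0,
    eos_token_id=1,
    pad_token_id=2,
    batch_size=2,
    beam_size=2,
    min_length=2,
    max_length=10,
    alpha=0,
    block_repeating_trigrams=False,
)

# Synthetic scores: one log-probability distribution per (batch element, beam) row.
vocab_size = 5
log_probabilities = torch.full((2 * 2, vocab_size), float("-inf"))
log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)

while not beam.is_done:
    surviving_rows = beam.grow(log_probabilities)  # one decoding step
    if beam.is_done:
        break
    # Keep only the rows that belong to beams that survived this step.
    log_probabilities = log_probabilities.index_select(0, surviving_rows)

print(beam.results["predictions"])  # best finished hypothesis per batch element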