remove beam search

Rémi Louf 2019-12-05 18:13:41 +01:00 committed by Julien Chaumond
parent 2403a66598
commit c0443df593
2 changed files with 0 additions and 619 deletions


@@ -1,376 +0,0 @@
# coding=utf-8
# MIT License
# Copyright (c) 2017-Present OpenNMT
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Use Beam Search to generate sequences using encoder-decoder models.
"""
import torch
from torch import nn
import logging
logger = logging.getLogger(__name__)
class BeamSearch(object):
def __init__(
self,
model,
bos_token_id,
pad_token_id,
eos_token_id,
batch_size,
beam_size,
min_length,
max_length,
alpha=0,
block_repeating_trigrams=True,
device=torch.device("cpu"),
):
r"""
Inputs:
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
The pretrained encoder-decoder model that will be used to generate the sequences.
**bos_token_id**: int
                Id that is used by the tokenizer to represent the beginning of a sentence.
**pad_token_id**: int
Id that is used by the tokenizer for padding.
**eos_token_id**: int
Id that is used by the tokenizer to represent the end of a sentence.
**batch_size**: (`optional`) int
                Batch size of the inputs. The value is reset automatically when the beam search is called.
**beam_size**: int
                Number of beams that are used for each element of the batch.
**min_length**: int
Minimum number of steps performed by the beam search before terminating.
**max_length**: int
Maximum number of steps performed by the beam search. Any beam that has not finished
will return its current solution with the highest probability. The sequence that is
returned has a length of max_length-1 to account for the end token that is subsequently added.
**alpha**: float
                Parameter of the length penalty. Read the documentation of the `_length_penalty` method for more details.
**block_repeating_trigrams**: bool
Whether to block sequences that have repeating 3-grams.
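            Example (a minimal, hypothetical usage sketch; ``model``, ``tokenizer``,
            ``encoder_input_ids``, ``attention_mask`` and the model-specific
            ``decoder_state`` are assumed to have been created beforehand)::

                beam_search = BeamSearch(
                    model=model,
                    bos_token_id=tokenizer.bos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    batch_size=encoder_input_ids.size(0),
                    beam_size=4,
                    min_length=5,
                    max_length=50,
                    alpha=0.6,
                )
                results = beam_search(
                    encoder_input_ids,
                    encoder_attention_mask=attention_mask,
                    decoder_state=decoder_state,
                )
                # results["predictions"][i] and results["scores"][i] hold the best
                # hypotheses found for the i-th element of the batch.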
"""
super(BeamSearch, self).__init__()
self.model = model
        self.device = next(model.parameters()).device  # assumes all of the model's parameters live on a single device
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.batch_size = batch_size
self.beam_size = beam_size
self.min_length = min_length
self.max_length = max_length
self.block_repeating_trigram = block_repeating_trigrams
        self.apply_length_penalty = alpha != 0
self.alpha = alpha
self._init_beam_state(batch_size)
def __len__(self):
return self.growing_beams.size(1)
def _init_beam_state(self, batch_size):
""" (re-)Initialize the state of the beams. """
self.hypotheses = [[] for _ in range(batch_size)]
self.batch_offset = torch.arange(batch_size, dtype=torch.long, device=self.device)
self.beam_offset = torch.arange(
0,
batch_size * self.beam_size,
step=self.beam_size,
dtype=torch.long,
device=self.device,
)
self.growing_beams = torch.full(
(batch_size * self.beam_size, 1),
self.bos_token_id,
dtype=torch.long,
device=self.device,
)
self.topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1),
dtype=torch.float,
device=self.device,
).repeat(batch_size)
self.results = {
"predictions": [[] for _ in range(batch_size)],
"scores": [[] for _ in range(batch_size)],
}
self._step = 0
self.is_done = False
def __call__(self, encoder_input_ids, **model_kwargs):
""" Generate a sequence using Beam Search. """
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
        # that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
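        # For example (hypothetical names): {"encoder_token_type_ids": t, "decoder_state": s,
        # "attention_mask": m} is split into encoder kwargs {"token_type_ids": t, "attention_mask": m}
        # and decoder kwargs {"state": s, "attention_mask": m}.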
kwargs_common = {
argument: value
for argument, value in model_kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
kwargs_encoder.update(
{
argument[len("encoder_") :]: value
for argument, value in model_kwargs.items()
if argument.startswith("encoder_")
}
)
kwargs_decoder.update(
{
argument[len("decoder_") :]: value
for argument, value in model_kwargs.items()
if argument.startswith("decoder_")
}
)
# forward pass on the encoder
encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
kwargs_decoder["encoder_hidden_states"] = tile(
encoder_hidden_states, self.beam_size, dim=0
)
try:
kwargs_decoder["encoder_attention_mask"] = tile(
kwargs_encoder["attention_mask"], self.beam_size, dim=0
)
        except KeyError:
            pass
kwargs_decoder["state"].src = tile(
kwargs_decoder["state"].src, self.beam_size, dim=0
)
# grow the beam iteratively
batch_size, block_size = encoder_input_ids.size()
self._init_beam_state(batch_size)
for step in range(self.max_length):
decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
outputs, state = self.model.decoder(decoder_input, **kwargs_decoder)
next_token_scores = outputs[0][:, -1, :].squeeze(1)
            log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=-1)
surviving_beams_rows = self.grow(log_probabilities)
if self.is_done:
break
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
"encoder_hidden_states"
].index_select(0, surviving_beams_rows)
try:
kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
"encoder_attention_mask"
].index_select(0, surviving_beams_rows)
            except KeyError:
                pass
kwargs_decoder["state"] = state
return self.results
def grow(self, log_probabilities):
""" Grow the beams by one step. """
self._step += 1
        # The effective batch size _B shrinks as whole beam searches finish, so we recompute it at every step
vocab_size = log_probabilities.size(-1)
_B = log_probabilities.size(0) // self.beam_size
# Multiply each beam probability with the probability of the
# next token (conditioned on the words in the beam).
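        # In log-space the multiplication is an addition: the running score of each beam
        # (shape (_B * beam_size, 1)) is broadcast over its vocabulary row.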
log_probabilities += self.topk_log_probabilities.view(-1, 1)
self._enforce_min_length(log_probabilities)
if self.block_repeating_trigram:
self._remove_beams_with_repeating_trigrams(log_probabilities, _B)
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
self.topk_log_probabilities, topk_ids = torch.topk(
log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
        topk_scores = self.topk_log_probabilities
        if self.apply_length_penalty:
            # Divide out of place so the running log-probabilities are left untouched.
            topk_scores = topk_scores / self._length_penalty()
# Retrieve the corresponding respective beam and token id
# topk_token_ids[i] will be added to topk_beam_ids[i]
        topk_beam_ids = topk_ids // vocab_size
topk_token_ids = topk_ids.fmod(vocab_size)
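        # Example (hypothetical numbers): with vocab_size=5, a flattened index of 7
        # decomposes into beam 7 // 5 = 1 and token 7 % 5 = 2 of that batch element.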
# Retrieve the row index of the surviving beams in the original
# view of the log_probabilities tensor
surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1)
surviving_beams_rows = surviving_beams_per_batch.view(-1)
# Append the last predictions
self.growing_beams = torch.cat(
[
self.growing_beams.index_select(0, surviving_beams_rows),
topk_token_ids.view(-1, 1),
],
1,
)
# Check if any of the beam searches has ended during this
# growth step. Also if top beam (most probable) has ended
# for one element of the batch.
is_finished = topk_token_ids.eq(self.eos_token_id)
self._enforce_max_length(is_finished)
if is_finished.any():
non_finished = self._cut_finished(is_finished, topk_scores)
self.batch_offset = self.batch_offset.index_select(0, non_finished)
surviving_beams_per_batch = surviving_beams_per_batch.index_select(
0, non_finished
)
self.topk_log_probabilities = self.topk_log_probabilities.index_select(
0, non_finished
)
surviving_beams_rows = surviving_beams_per_batch.view(-1)
self.growing_beams = self.growing_beams.index_select(0, surviving_beams_rows)
return surviving_beams_rows
def _cut_finished(self, is_finished, topk_scores):
""" Save the finished searches and cut the correponding sequences off
the beams. """
is_top_beam_finished = is_finished[:, 0].eq(True)
# Save the finished searches
predictions = self.growing_beams.view(
-1, self.beam_size, self.growing_beams.size(1)
)
for i in range(is_finished.size(0)):
if is_top_beam_finished[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store the finished beams as a (score, prediction) hypothesis.
b = self.batch_offset[i]
for j in finished_hyp:
self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
# If the batch reached the end, save the best hypotheses
# in terms of length-penalized score.
if is_top_beam_finished[i]:
best_score, best_prediction = max(self.hypotheses[b], key=lambda x: x[0])
self.results["scores"][b].append(best_score)
self.results["predictions"][b].append(best_prediction)
non_finished = is_top_beam_finished.eq(False).nonzero().view(-1)
if len(non_finished) == 0:
self.is_done = True
return non_finished
def _remove_beams_with_repeating_trigrams(self, log_probabilities, _B):
if self._step + 1 > 3: # [BOS] does not count
for i in range(_B * self.beam_size):
tokens = self.growing_beams[i]
trigrams = [
(tokens[j - 1], tokens[j], tokens[j + 1])
for j in range(1, len(self) - 1)
]
last_trigram = tuple(trigrams[-1])
if last_trigram in trigrams[:-1]:
log_probabilities[i] = -1e20
def _enforce_min_length(self, log_probabilities):
if self._step < self.min_length:
log_probabilities[:, self.eos_token_id] = -1e20
def _enforce_max_length(self, is_finished):
# +1 because we will need to add an [EOS] token
if self._step + 1 == self.max_length:
is_finished.fill_(1)
def _length_penalty(self):
""" The calculation of the length penalty follows that of [1].
[1] Wu, Yonghui, et al. "Google's neural machine translation system:
Bridging the gap between human and machine translation." arXiv preprint
arXiv:1609.08144 (2016).
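        The penalty computed below is ((5 + step + 1) / 6) ** alpha; as an
        illustration (hypothetical values), alpha=0.6 at step 5 gives
        ((5 + 6) / 6) ** 0.6 ≈ 1.44, by which the cumulative log-probabilities
        are divided in `grow`.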
"""
return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
def tile(x, count, dim=0):
"""
Tiles `x` along dimension `dim` `count` times.
    Example::
        >>> ex = torch.tensor([[1, 2], [3, 4]])
        >>> tile(ex, 2, 0)
        tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
"""
perm = list(range(len(x.size())))
if dim != 0:
perm[0], perm[dim] = perm[dim], perm[0]
x = x.permute(perm).contiguous()
out_size = list(x.size())
out_size[0] *= count
batch = x.size(0)
x = (
x.view(batch, -1)
.transpose(0, 1)
.repeat(count, 1)
.transpose(0, 1)
.contiguous()
.view(*out_size)
)
if dim != 0:
x = x.permute(perm).contiguous()
return x
def fit_to_block_size(sequence, block_size, pad_token_id):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter we append padding tokens to the right.
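    Example (hypothetical values)::
        >>> fit_to_block_size(torch.tensor([[5, 6]]), 4, 0)
        tensor([[5, 6, 0, 0]])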
"""
padded_sequence = torch.full(
(sequence.size(0), block_size),
pad_token_id,
dtype=torch.long,
device=sequence.device,
)
padded_sequence[:, : sequence.size(1)] = sequence
    return padded_sequence
def build_mask(sequence, pad_token_id):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
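    # Example (hypothetical values): a sequence [[5, 6, 0, 0]] with pad_token_id=0
    # yields the mask [[1, 1, 0, 0]].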
mask = torch.ones_like(sequence)
idx_pad_tokens = sequence == pad_token_id
mask[idx_pad_tokens] = 0
return mask


@@ -1,243 +0,0 @@
from collections import namedtuple
import unittest
import numpy as np
import torch
from torch import nn
from transformers.generate import BeamSearch
from transformers import PreTrainedEncoderDecoder
class StubTransformer(nn.Module):
    def __init__(self):
        super(StubTransformer, self).__init__()
        self.encoder = None
        self.decoder = None
        self._parameters = {"dummy": torch.tensor([1])}
def forward(self):
pass
class BeamSearchTest(unittest.TestCase):
def test_beam_search_encoder_decoder_integration(self):
""" We make sure that no internal change in the PreTrainedEncoderDecoder
class will break the integration with the beam search.
"""
model = StubTransformer()
try:
_ = BeamSearch(
model=model,
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
batch_size=1,
beam_size=1,
min_length=1,
max_length=1,
alpha=0,
block_repeating_trigrams=False,
)
        except Exception:
self.fail("Instantiating BeamSearch with a PreTrainedEncoderDecoder failed.")
def test_beam_search_min_length(self):
""" We keep predicting the end_token for the first beam and check that
it is not marked as finished until the beam has reached the minimum
length. """
eos_idx = 3
vocab_size = 10
batch_size = 3
beam_size = 2
min_length = 5
beam = BeamSearch(
model=StubTransformer(),
bos_token_id=0,
eos_token_id=eos_idx,
pad_token_id=2,
batch_size=batch_size,
beam_size=beam_size,
min_length=5,
max_length=10,
alpha=0,
block_repeating_trigrams=False,
)
# To test that the minimum length is correctly enforced we constantly
# assign the highest probability to the [EOS] token (and assign lower
# probabilities to some other tokens).
        # Since BeamSearch resets the [EOS] probability to -1e20 as long as
# min_length has not been reached, we need to reset the value between
# steps.
non_eos_idxs = [4, 5, 1, 8, 9]
score_distribution = torch.log_softmax(
torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0
)
log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
log_probabilities[0, eos_idx] = score_distribution[0]
for idx, score in zip(non_eos_idxs, score_distribution[1:]):
log_probabilities[0, idx] = score
for step in range(1, min_length + 2):
log_probabilities[0, eos_idx] = score_distribution[0]
            # Beams #3 and #4 terminate at the first step since the probability
            # of the [EOS] token is -1e20 > -inf, so there are only two beams left.
# The top beam (most likely) always ends with 4 until we reach min_length.
surviving_beams_rows = beam.grow(log_probabilities)
if step < min_length:
np.testing.assert_array_equal(
beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
)
elif step == min_length:
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
self.assertTrue(beam.is_done)
break
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
def test_beam_search_max_length(self):
""" We keep predicting the same non-EOS token until we reach the
maximum permitted length """
batch_size = 3
beam_size = 2
max_length = 5
vocab_size = 10
beam = BeamSearch(
model=StubTransformer(),
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
batch_size=batch_size,
beam_size=beam_size,
min_length=2,
max_length=max_length,
alpha=0,
block_repeating_trigrams=False,
)
log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
# To test that beam search enforces the max length constraint we
# keep giving the highest probability to a token that is not the
# [EOS] token.
# The beam search will stop at max_length-1, assuming that one would
# add the [EOS] token at the end of the returned sequence.
token_idxs = [3, 4, 5]
score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
for idx, score in zip(token_idxs, score_distribution):
log_probabilities[:, idx] = score
for step in range(1, max_length + 2):
surviving_beams_rows = beam.grow(log_probabilities)
if step + 1 < max_length:
self.assertFalse(beam.is_done)
elif step + 1 == max_length: # Now [EOS] is the most probable token
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
self.assertTrue(beam.is_done)
break
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
def test_beam_search_block_repeating_trigrams(self):
""" We make sure that the beams that contain repeating trigrams are removed. """
batch_size = 3
beam_size = 2
max_length = 10
vocab_size = 10
beam = BeamSearch(
model=StubTransformer(),
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
batch_size=batch_size,
beam_size=beam_size,
min_length=2,
max_length=max_length,
alpha=0,
block_repeating_trigrams=True,
)
log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
# To test that BeamSearch enforces the 3-gram constraint we give the
        # highest probability to the same tokens in a cyclic fashion and make sure
# they disappear once the cycle has completed.
token_idxs = [3, 4, 5]
score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
for idx, score in zip(token_idxs, score_distribution):
log_probabilities[:, idx] = score
for step in range(1, max_length + 2):
# Rotate the probabilities at each step
for idx in token_idxs:
score = score_distribution[(idx + step) % 3]
log_probabilities[::beam_size, idx] = score
surviving_beams_rows = beam.grow(log_probabilities)
if step < 7:
self.assertFalse(
np.array_equal(
log_probabilities.numpy()[0, :],
np.array([-1e20] * vocab_size, dtype="float32"),
)
)
if step == 7:
np.testing.assert_array_equal(
log_probabilities.numpy()[0, :],
np.array([-1e20] * vocab_size, dtype="float32"),
)
log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)
def test_beam_search_example_for_one_step(self):
""" We test that the predictions for one step of growth are correct. """
batch_size = 2
beam_size = 2
max_length = 10
vocab_size = 5
beam = BeamSearch(
model=StubTransformer(),
bos_token_id=0,
eos_token_id=1,
pad_token_id=2,
batch_size=batch_size,
beam_size=beam_size,
min_length=2,
max_length=max_length,
alpha=0,
block_repeating_trigrams=False,
)
log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)
# First pass
surviving_beams_rows = beam.grow(log_probabilities)
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
np.testing.assert_array_equal(
beam.growing_beams.numpy(), np.array([[0, 3], [0, 4], [0, 4], [0, 3]])
)
self.assertFalse(beam.is_done)
# Second pass
surviving_beams_rows = beam.grow(log_probabilities)
np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
np.testing.assert_array_equal(
beam.growing_beams.numpy(),
np.array([[0, 3, 3], [0, 3, 4], [0, 4, 4], [0, 4, 3]]),
)
self.assertFalse(beam.is_done)
if __name__ == "__main__":
unittest.main()