mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 18:22:34 +06:00
remove beam search
This commit is contained in:
parent 2403a66598, commit c0443df593
@@ -1,376 +0,0 @@
# coding=utf-8
# MIT License

# Copyright (c) 2017-Present OpenNMT

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Use Beam Search to generate sequences using encoder-decoder models.
"""
import logging

import torch


logger = logging.getLogger(__name__)


class BeamSearch(object):
    def __init__(
        self,
        model,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        batch_size,
        beam_size,
        min_length,
        max_length,
        alpha=0,
        block_repeating_trigrams=True,
    ):
r"""
|
|
||||||
Inputs:
|
|
||||||
**model**: instance of ``transformers.PreTrainedEncoderDecoder``
|
|
||||||
The pretrained encoder-decoder model that will be used to generate the sequences.
|
|
||||||
**bos_token_id**: int
|
|
||||||
Id that is used by the tokenizer to represent the beggining of a sentence.
|
|
||||||
**pad_token_id**: int
|
|
||||||
Id that is used by the tokenizer for padding.
|
|
||||||
**eos_token_id**: int
|
|
||||||
Id that is used by the tokenizer to represent the end of a sentence.
|
|
||||||
**batch_size**: (`optional`) int
|
|
||||||
Batch size of the inputs. The value is set automatically when calling `forward`.
|
|
||||||
**beam_size**: int
|
|
||||||
Number of beams that are used for each element on the batch.
|
|
||||||
**min_length**: int
|
|
||||||
Minimum number of steps performed by the beam search before terminating.
|
|
||||||
**max_length**: int
|
|
||||||
Maximum number of steps performed by the beam search. Any beam that has not finished
|
|
||||||
will return its current solution with the highest probability. The sequence that is
|
|
||||||
returned has a length of max_length-1 to account for the end token that is subsequently added.
|
|
||||||
**alpha**: float
|
|
||||||
Parameter of the length penalty. Read the documentation of the `_length_penalty` method for mode details.
|
|
||||||
**block_repeating_trigrams**: bool
|
|
||||||
Whether to block sequences that have repeating 3-grams.
|
|
||||||
"""
|
|
||||||
        super(BeamSearch, self).__init__()
        self.model = model
        # Only correct when all of the model's parameters are stored on a single device
        self.device = next(model.parameters()).device

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

        self.batch_size = batch_size
        self.beam_size = beam_size
        self.min_length = min_length
        self.max_length = max_length

        self.block_repeating_trigram = block_repeating_trigrams
        self.apply_length_penalty = alpha != 0
        self.alpha = alpha

        self._init_beam_state(batch_size)

    def __len__(self):
        return self.growing_beams.size(1)

    def _init_beam_state(self, batch_size):
        """ (Re-)initialize the state of the beams. """
        self.hypotheses = [[] for _ in range(batch_size)]
        self.batch_offset = torch.arange(batch_size, dtype=torch.long, device=self.device)
        self.beam_offset = torch.arange(
            0,
            batch_size * self.beam_size,
            step=self.beam_size,
            dtype=torch.long,
            device=self.device,
        )
        self.growing_beams = torch.full(
            (batch_size * self.beam_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=self.device,
        )
        self.topk_log_probabilities = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1),
            dtype=torch.float,
            device=self.device,
        ).repeat(batch_size)
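        # Only the first beam of each batch element starts alive (score 0.0);
        # the -inf entries prevent topk from initially selecting the same
        # continuation from several identical copies of that beam.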
        self.results = {
            "predictions": [[] for _ in range(batch_size)],
            "scores": [[] for _ in range(batch_size)],
        }
        self._step = 0
        self.is_done = False

    def __call__(self, encoder_input_ids, **model_kwargs):
        """ Generate a sequence using Beam Search. """
        # Keyword arguments come in 3 flavors: encoder-specific (prefixed by
        # `encoder_`), decoder-specific (prefixed by `decoder_`), and those
        # that apply to the model as a whole.
        # We let the specific kwargs override the common ones in case of conflict.
        kwargs_common = {
            argument: value
            for argument, value in model_kwargs.items()
            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
        }
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder = kwargs_common.copy()
        kwargs_encoder.update(
            {
                argument[len("encoder_") :]: value
                for argument, value in model_kwargs.items()
                if argument.startswith("encoder_")
            }
        )
        kwargs_decoder.update(
            {
                argument[len("decoder_") :]: value
                for argument, value in model_kwargs.items()
                if argument.startswith("decoder_")
            }
        )

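        # Illustration with hypothetical kwargs: calling
        #     beam(input_ids, encoder_token_type_ids=t, decoder_state=s, head_mask=m)
        # yields kwargs_encoder == {"token_type_ids": t, "head_mask": m}
        # and kwargs_decoder == {"state": s, "head_mask": m}.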
        # Forward pass on the encoder
        encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
        encoder_hidden_states = encoder_outputs[0]
        kwargs_decoder["encoder_hidden_states"] = tile(
            encoder_hidden_states, self.beam_size, dim=0
        )
        if "attention_mask" in kwargs_encoder:
            kwargs_decoder["encoder_attention_mask"] = tile(
                kwargs_encoder["attention_mask"], self.beam_size, dim=0
            )
        kwargs_decoder["state"].src = tile(
            kwargs_decoder["state"].src, self.beam_size, dim=0
        )

        # Grow the beams iteratively
        batch_size, block_size = encoder_input_ids.size()
        self._init_beam_state(batch_size)
        for step in range(self.max_length):
            decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
            kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)

            outputs, state = self.model.decoder(decoder_input, **kwargs_decoder)

            next_token_scores = outputs[0][:, -1, :]
            # Normalize over the vocabulary dimension
            log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=-1)
            surviving_beams_rows = self.grow(log_probabilities)
            if self.is_done:
                break

            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
                "encoder_hidden_states"
            ].index_select(0, surviving_beams_rows)
            if "encoder_attention_mask" in kwargs_decoder:
                kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
                    "encoder_attention_mask"
                ].index_select(0, surviving_beams_rows)
            kwargs_decoder["state"] = state

        return self.results

    def grow(self, log_probabilities):
        """ Grow the beams by one step. """
        self._step += 1

        # The number of beams changes as some beams finish, so we define _B
        vocab_size = log_probabilities.size(-1)
        _B = log_probabilities.size(0) // self.beam_size

        # Multiply each beam probability by the probability of the
        # next token (conditioned on the words in the beam).
        log_probabilities += self.topk_log_probabilities.view(-1, 1)

        self._enforce_min_length(log_probabilities)
        if self.block_repeating_trigram:
            self._remove_beams_with_repeating_trigrams(log_probabilities, _B)

        # Find the `beam_size` (previous_beam + token) combinations with
        # the highest score
        self.topk_log_probabilities, topk_ids = torch.topk(
            log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
        )

        # Apply the length penalty. The +1 accounts for the [EOS] token
        # that will be added if the beam ends. Note the copy: dividing in
        # place would corrupt the cumulative log-probabilities.
        if self.apply_length_penalty:
            topk_scores = self.topk_log_probabilities / self._length_penalty()
        else:
            topk_scores = self.topk_log_probabilities

        # Retrieve the respective beam and token id of each combination;
        # topk_token_ids[i] will be appended to topk_beam_ids[i].
        topk_beam_ids = topk_ids // vocab_size
        topk_token_ids = topk_ids.fmod(vocab_size)
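        # Worked example: with vocab_size = 5, a flattened id of 7 selects
        # beam 7 // 5 = 1 and token 7 % 5 = 2 within that batch element.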

        # Retrieve the row indices of the surviving beams in the
        # original view of the log_probabilities tensor
        surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1)
        surviving_beams_rows = surviving_beams_per_batch.view(-1)

        # Append the last predictions
        self.growing_beams = torch.cat(
            [
                self.growing_beams.index_select(0, surviving_beams_rows),
                topk_token_ids.view(-1, 1),
            ],
            1,
        )

        # Check whether any beam search ended during this growth step, i.e.
        # whether the top (most probable) beam has ended for some element
        # of the batch.
        is_finished = topk_token_ids.eq(self.eos_token_id)
        self._enforce_max_length(is_finished)
        if is_finished.any():
            non_finished = self._cut_finished(is_finished, topk_scores)
            self.batch_offset = self.batch_offset.index_select(0, non_finished)
            surviving_beams_per_batch = surviving_beams_per_batch.index_select(
                0, non_finished
            )
            self.topk_log_probabilities = self.topk_log_probabilities.index_select(
                0, non_finished
            )

        surviving_beams_rows = surviving_beams_per_batch.view(-1)
        self.growing_beams = self.growing_beams.index_select(0, surviving_beams_rows)

        return surviving_beams_rows

    def _cut_finished(self, is_finished, topk_scores):
        """ Save the finished searches and cut the corresponding sequences off
        the beams. """
        is_top_beam_finished = is_finished[:, 0].eq(True)

        # Save the finished searches
        predictions = self.growing_beams.view(
            -1, self.beam_size, self.growing_beams.size(1)
        )
        for i in range(is_finished.size(0)):
            if is_top_beam_finished[i]:
                is_finished[i].fill_(1)
            finished_hyp = is_finished[i].nonzero().view(-1)

            # Store the finished beams as (score, prediction) hypotheses.
            b = self.batch_offset[i]
            for j in finished_hyp:
                self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))

            # If the batch reached the end, save the best hypotheses
            # in terms of length-penalized score.
            if is_top_beam_finished[i]:
                best_score, best_prediction = max(self.hypotheses[b], key=lambda x: x[0])
                self.results["scores"][b].append(best_score)
                self.results["predictions"][b].append(best_prediction)

        non_finished = is_top_beam_finished.eq(False).nonzero().view(-1)
        if len(non_finished) == 0:
            self.is_done = True

        return non_finished

    def _remove_beams_with_repeating_trigrams(self, log_probabilities, _B):
        if self._step + 1 > 3:  # [BOS] does not count
            for i in range(_B * self.beam_size):
                tokens = self.growing_beams[i].tolist()
                trigrams = [
                    (tokens[j - 1], tokens[j], tokens[j + 1])
                    for j in range(1, len(self) - 1)
                ]
                last_trigram = tuple(trigrams[-1])
                if last_trigram in trigrams[:-1]:
                    log_probabilities[i] = -1e20

    def _enforce_min_length(self, log_probabilities):
        if self._step < self.min_length:
            log_probabilities[:, self.eos_token_id] = -1e20

    def _enforce_max_length(self, is_finished):
        # +1 because we will need to add an [EOS] token
        if self._step + 1 == self.max_length:
            is_finished.fill_(1)

    def _length_penalty(self):
        """ The calculation of the length penalty follows that of [1].

        [1] Wu, Yonghui, et al. "Google's Neural Machine Translation System:
        Bridging the Gap between Human and Machine Translation." arXiv preprint
        arXiv:1609.08144 (2016).
        """
        return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
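    # Worked example (assumed value alpha=0.6, the GNMT default): the penalty
    # is ((5 + 1) / 6) ** 0.6 = 1.0 at the first step and
    # ((5 + 10) / 6) ** 0.6 ≈ 1.73 after ten steps, so the divisor grows much
    # more slowly than the raw sequence length.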


def tile(x, count, dim=0):
    """
    Tiles `x` along dimension `dim` `count` times.

    Example:
        >>> ex = torch.tensor([[1, 2], [3, 4]])
        >>> tile(ex, 2, 0)
        tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = (
        x.view(batch, -1)
        .transpose(0, 1)
        .repeat(count, 1)
        .transpose(0, 1)
        .contiguous()
        .view(*out_size)
    )
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x
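
# Tiling along another dimension interleaves along that axis, e.g.
# tile(torch.tensor([[1, 2], [3, 4]]), 2, dim=1) returns
# tensor([[1, 1, 2, 2], [3, 3, 4, 4]]).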


def fit_to_block_size(sequence, block_size, pad_token_id):
    """ Adapt the source and target sequences' lengths to the block size.
    If the sequence is shorter we append padding tokens to the right.
    """
    padded_sequence = torch.full(
        (sequence.size(0), block_size),
        pad_token_id,
        dtype=torch.long,
        device=sequence.device,
    )
    padded_sequence[:, : sequence.size(1)] = sequence
    return padded_sequence
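
# Example (illustrative values): fit_to_block_size(torch.tensor([[0, 5, 7]]), 4, 2)
# returns tensor([[0, 5, 7, 2]]).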


def build_mask(sequence, pad_token_id):
    """ Builds the mask. The attention mechanism will only attend to positions
    with value 1. """
    mask = torch.ones_like(sequence)
    idx_pad_tokens = sequence == pad_token_id
    mask[idx_pad_tokens] = 0
    return mask
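

# A minimal usage sketch (hedged: assumes a `PreTrainedEncoderDecoder`
# instance and a tokenizer with the usual special tokens; the names below
# are illustrative, not part of this module):
#
#     generator = BeamSearch(
#         model,
#         bos_token_id=tokenizer.cls_token_id,
#         pad_token_id=tokenizer.pad_token_id,
#         eos_token_id=tokenizer.sep_token_id,
#         batch_size=1,
#         beam_size=4,
#         min_length=5,
#         max_length=50,
#         alpha=0.6,
#     )
#     results = generator(encoder_input_ids)
#     # results["predictions"][b] holds the best token ids for batch element b,
#     # results["scores"][b] the corresponding length-penalized scores.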
@@ -1,243 +0,0 @@
import unittest

import numpy as np
import torch
from torch import nn

from transformers.generate import BeamSearch

class StubTransformer(nn.Module):
    """ Minimal stand-in for an encoder-decoder model: it only needs to
    expose `encoder`, `decoder` and at least one parameter. """

    def __init__(self):
        super(StubTransformer, self).__init__()
        self.encoder = None
        self.decoder = None
        # Makes `next(model.parameters())` work in BeamSearch.__init__
        self._parameters = {"dummy": torch.tensor([1])}

    def forward(self):
        pass


class BeamSearchTest(unittest.TestCase):
    def test_beam_search_encoder_decoder_integration(self):
        """ We make sure that no internal change in the PreTrainedEncoderDecoder
        class will break the integration with the beam search.
        """

        model = StubTransformer()
        try:
            _ = BeamSearch(
                model=model,
                bos_token_id=0,
                eos_token_id=1,
                pad_token_id=2,
                batch_size=1,
                beam_size=1,
                min_length=1,
                max_length=1,
                alpha=0,
                block_repeating_trigrams=False,
            )
        except Exception:
            self.fail("Instantiating BeamSearch with a PreTrainedEncoderDecoder failed.")

    def test_beam_search_min_length(self):
        """ We keep predicting the end token for the first beam and check that
        it is not marked as finished until the beam has reached the minimum
        length. """
        eos_idx = 3
        vocab_size = 10

        batch_size = 3
        beam_size = 2
        min_length = 5

        beam = BeamSearch(
            model=StubTransformer(),
            bos_token_id=0,
            eos_token_id=eos_idx,
            pad_token_id=2,
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=min_length,
            max_length=10,
            alpha=0,
            block_repeating_trigrams=False,
        )

        # To test that the minimum length is correctly enforced we constantly
        # assign the highest probability to the [EOS] token (and assign lower
        # probabilities to some other tokens).
        # Since BeamSearch resets the [EOS] probability to -1e20 as long as
        # min_length has not been reached, we need to reset the value between
        # steps.
        non_eos_idxs = [4, 5, 1, 8, 9]
        score_distribution = torch.log_softmax(
            torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
        log_probabilities[0, eos_idx] = score_distribution[0]
        for idx, score in zip(non_eos_idxs, score_distribution[1:]):
            log_probabilities[0, idx] = score

        for step in range(1, min_length + 2):
            log_probabilities[0, eos_idx] = score_distribution[0]

            # Beams #3 and #4 terminate at the first step since the probability
            # of the [EOS] token is -1e20 > -inf, so there are only two beams left.
            # The top (most likely) beam always ends with 4 until we reach min_length.
            surviving_beams_rows = beam.grow(log_probabilities)
            if step < min_length:
                np.testing.assert_array_equal(
                    beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
                )
            elif step == min_length:
                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
                self.assertTrue(beam.is_done)
                break

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_max_length(self):
        """ We keep predicting the same non-EOS token until we reach the
        maximum permitted length. """
        batch_size = 3
        beam_size = 2
        max_length = 5
        vocab_size = 10

        beam = BeamSearch(
            model=StubTransformer(),
            bos_token_id=0,
            eos_token_id=1,
            pad_token_id=2,
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=False,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))

        # To test that beam search enforces the max length constraint we
        # keep giving the highest probability to a token that is not the
        # [EOS] token.
        # The beam search will stop at max_length-1, assuming that one would
        # add the [EOS] token at the end of the returned sequence.
        token_idxs = [3, 4, 5]
        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
        for idx, score in zip(token_idxs, score_distribution):
            log_probabilities[:, idx] = score

        for step in range(1, max_length + 2):
            surviving_beams_rows = beam.grow(log_probabilities)
            if step + 1 < max_length:
                self.assertFalse(beam.is_done)
            elif step + 1 == max_length:  # the beams are forced to finish here
                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
                self.assertTrue(beam.is_done)
                break

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_block_repeating_trigrams(self):
        """ We make sure that beams that contain repeated trigrams are removed. """
        batch_size = 3
        beam_size = 2
        max_length = 10
        vocab_size = 10

        beam = BeamSearch(
            model=StubTransformer(),
            bos_token_id=0,
            eos_token_id=1,
            pad_token_id=2,
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=True,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))

        # To test that BeamSearch enforces the trigram constraint we give the
        # highest probability to the same tokens in a cyclic fashion and make
        # sure they disappear once the cycle has completed.
        token_idxs = [3, 4, 5]
        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
        for idx, score in zip(token_idxs, score_distribution):
            log_probabilities[:, idx] = score

        for step in range(1, max_length + 2):
            # Rotate the probabilities at each step
            for idx in token_idxs:
                score = score_distribution[(idx + step) % 3]
                log_probabilities[::beam_size, idx] = score

            surviving_beams_rows = beam.grow(log_probabilities)

            if step < 7:
                self.assertFalse(
                    np.array_equal(
                        log_probabilities.numpy()[0, :],
                        np.array([-1e20] * vocab_size, dtype="float32"),
                    )
                )
            if step == 7:
                np.testing.assert_array_equal(
                    log_probabilities.numpy()[0, :],
                    np.array([-1e20] * vocab_size, dtype="float32"),
                )

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_example_for_one_step(self):
        """ We test that the predictions for one step of growth are correct. """
        batch_size = 2
        beam_size = 2
        max_length = 10
        vocab_size = 5

        beam = BeamSearch(
            model=StubTransformer(),
            bos_token_id=0,
            eos_token_id=1,
            pad_token_id=2,
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=False,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
        log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
        log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)

        # First pass
        surviving_beams_rows = beam.grow(log_probabilities)
        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
        np.testing.assert_array_equal(
            beam.growing_beams.numpy(), np.array([[0, 3], [0, 4], [0, 4], [0, 3]])
        )
        self.assertFalse(beam.is_done)
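        # Why [0, 0, 2, 2]: both continuations of batch element 0 extend its
        # former beam 0 (flattened row 0), and both continuations of batch
        # element 1 extend its former beam 0, which sits at flattened row 2.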

        # Second pass
        surviving_beams_rows = beam.grow(log_probabilities)
        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
        np.testing.assert_array_equal(
            beam.growing_beams.numpy(),
            np.array([[0, 3, 3], [0, 3, 4], [0, 4, 4], [0, 4, 3]]),
        )
        self.assertFalse(beam.is_done)


if __name__ == "__main__":
    unittest.main()