adding beam search

This commit is contained in:
thomwolf 2019-12-17 17:23:36 +01:00
parent a468870fd2
commit b6938916ac
2 changed files with 235 additions and 41 deletions

View File

@@ -62,13 +62,18 @@ class PretrainedConfig(object):
self.is_decoder = kwargs.pop('is_decoder', False)
# Parameters for sequence generation
self.generate_length = kwargs.pop('generate_length', 10)
self.generate_max_length = kwargs.pop('generate_max_length', 20)
self.generate_do_sample = kwargs.pop('generate_do_sample', False)
self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
self.generate_top_k = kwargs.pop('generate_top_k', 50)
self.generate_top_p = kwargs.pop('generate_top_p', 0.0)
self.generate_top_p = kwargs.pop('generate_top_p', 1.0)
self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
self.generate_bos_token_id = kwargs.pop('generate_bos_token_id', 0)
self.generate_pad_token_id = kwargs.pop('generate_pad_token_id', 0)
self.generate_eos_token_ids = kwargs.pop('generate_eos_token_ids', 0)
self.generate_batch_size = kwargs.pop('generate_batch_size', 1)
self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.)
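These defaults are read with `kwargs.pop`, so they can be overridden per model when the config is built, and `generate()` falls back to them whenever the corresponding argument is left as `None`. A minimal sketch (editorial, not part of this commit; it assumes a concrete subclass such as `GPT2Config`, which forwards unknown keyword arguments to `PretrainedConfig`):

from transformers import GPT2Config

# Override two generation defaults; the others keep the values set above.
config = GPT2Config(generate_max_length=40, generate_num_beams=5)
assert config.generate_num_beams == 5
assert config.generate_top_p == 1.0  # untouched default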
def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -488,63 +488,252 @@ class PreTrainedModel(nn.Module):
return model
def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None,
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
**model_kwargs):
""" Generic sequence generator for single-stack models with a LM head.
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}
The method currently supports greedy decoding and sampling. See the
documentation of the `Sampler` class for more information about the
parameters related to sampling.
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
length_penalty=None, **kwargs):
""" Sequence generator for models with a LM head.
The method currently supports greedy decoding, greedy decoding with a repetition penalty, top-k or nucleus sampling,
and beam search.
Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM
Params:
**input_ids**: (`optional`) `torch.LongTensor` of shape (batch_size, sequence_length)
The sequence used as a prompt for the generation. If `None` the method initializes
it with `bos_token_id` and a batch of size `batch_size`.
**length**: (`optional`) int
The length of the sequence to be generated.
**max_length**: (`optional`) int
The max length of the sequence to be generated. Between 1 and infinity. Defaults to 20.
**do_sample**: (`optional`) bool
If set to `False` we use greedy decoding; otherwise sampling.
If set to `False` greedy decoding is used; otherwise sampling. Defaults to `False` (greedy decoding).
**num_beams**: (`optional`) int
Number of beams for beam search. 1 means no beam search. Defaults to 1.
**temperature**: (`optional`) float
The value used to module the next token probabilities.
**k**: (`optional`) int
The parameter used for k-filtering.
**p**: (`optional`) float
The parameter for nucleus sampling. Must be between 0 and 1.
**top_k**: (`optional`) int
The number of highest probability vocabulary tokens to keep for top-k filtering. Between 1 and infinity. Defaults to 50.
**top_p**: (`optional`) float
The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Defaults to 1.
**repetition_penalty**: (`optional`) float
The parameter for repetition penalty.
The parameter for repetition penalty. Between 1.0 and +infinity. 1.0 means no penalty. Defaults to 1.0.
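**bos_token_id**: (`optional`) int
Beginning-of-sequence token used to initialize `input_ids` when no prompt is provided. Defaults to 0.
**pad_token_id**: (`optional`) int
Padding token used for finished sentences in the batch. Defaults to 0.
**eos_token_ids**: (`optional`) int or list of int
End-of-sequence token id(s); generating one of them finalizes a hypothesis. Defaults to 0.
**batch_size**: (`optional`) int
Number of sequences to generate when `input_ids` is not provided. Defaults to 1.
**length_penalty**: (`optional`) float
Exponential penalty applied to the hypothesis length when scoring beam-search candidates. Defaults to 1.0.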
"""
if input_ids is None:
input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device)
# We cannot generate if the model does not have a LM head
if self.get_output_embeddings() is None:
raise AttributeError("You tried to generate sequences with a model that does not have a LM head.")
sampler_config = {
"k": k,
"p": p,
"do_sample": do_sample,
"temperature": temperature,
"repetition_penalty": repetition_penalty,
}
max_length = max_length if max_length is not None else self.config.generate_max_length
do_sample = do_sample if do_sample is not None else self.config.generate_do_sample
num_beams = num_beams if num_beams is not None else self.config.generate_num_beams
temperature = temperature if temperature is not None else self.config.generate_temperature
top_k = top_k if top_k is not None else self.config.generate_top_k
top_p = top_p if top_p is not None else self.config.generate_top_p
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generate_bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generate_pad_token_id
eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.generate_eos_token_ids
batch_size = batch_size if batch_size is not None else self.config.generate_batch_size
length_penalty = length_penalty if length_penalty is not None else self.config.generate_length_penalty
sampler = Sampler(**sampler_config)
generated_sequence = input_ids
for _ in trange(length):
arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
outputs = self(**arguments)
next_tokens_logits = outputs[0][:, -1, :]
next_tokens = sampler.get_one_token(
next_tokens_logits, generated_sequence
)
generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
if input_ids is not None:
batch_size = input_ids.shape[0] # overridden by the input batch_size
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
return generated_sequence.squeeze(0)
assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictly positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictly positive integer."
assert 0 < temperature, "`temperature` should be positive."
assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictly positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert 0 < repetition_penalty, "`repetition_penalty` should be strictly positive."
assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a positive integer."
assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a positive integer."
assert isinstance(eos_token_ids, (list, tuple)) and all(0 <= e for e in eos_token_ids), \
"`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictly positive integer."
assert 0 < length_penalty, "`length_penalty` should be strictly positive."
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
return model_kwargs.update({"input_ids": input_ids})
if input_ids is None:
input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
else:
assert input_ids.dim() == 2
# current position and vocab size
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
# generated hypotheses
generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)]
# scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)
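Editorial aside: the `-1e9` offset above matters because all `num_beams` rows of a sentence start out as identical copies of the prompt; without it, the top-k step in the loop below would select the same continuations `num_beams` times over. A standalone sketch of the effect (not part of this commit; the sizes are made up):

import torch
import torch.nn.functional as F

num_beams, vocab_size = 3, 10
beam_scores = torch.zeros(1, num_beams)
beam_scores[:, 1:] = -1e9  # only beam 0 can contribute candidates at the first step
beam_scores = beam_scores.view(-1)
log_probs = F.log_softmax(torch.randn(num_beams, vocab_size), dim=-1)
scores = (log_probs + beam_scores[:, None]).view(1, num_beams * vocab_size)
next_scores, next_words = torch.topk(scores, 2 * num_beams, dim=1)
print(next_words // vocab_size)  # all beam ids are 0: distinct words from one beam, not one word from three identical beams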
# cache compute states
pasts = None # self.prepare_pasts()
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
assert scores.size() == (batch_size * num_beams, vocab_size)
# select next words with scores
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
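# Editorial note: 2 * num_beams candidates are kept (rather than num_beams) because a
# candidate that selects an EOS token finalizes a hypothesis below instead of extending
# a beam; with a single EOS id at most one candidate per beam can be an EOS, so at
# least num_beams non-EOS continuations always survive for the next step.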
# next batch beam content
# list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
next_batch_beam = []
# for each sentence
for sent_id in range(batch_size):
# if we are done with this sentence
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
if done[sent_id]:
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
# next words for this sentence
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
# get beam and word IDs
beam_id = idx // vocab_size
word_id = idx % vocab_size
# end of sentence, or next word
if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item())
else:
next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams:
break
# update next beam content
assert len(next_sent_beam) == (0 if cur_len + 1 == max_length else num_beams)
if len(next_sent_beam) == 0:
next_sent_beam = [(0, pad_token_id, 0)] * num_beams # pad the batch
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (sent_id + 1)
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = input_ids.new([x[1] for x in next_batch_beam])
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
# re-order batch and internal states
input_ids = input_ids[beam_idx, :]
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
# TODO: Activate cache
# for k in cache.keys():
# if k != 'slen':
# cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
# update current length
cur_len = cur_len + 1
# stop when we are done with each sentence
if all(done):
break
# visualize hypotheses
# print([len(x) for x in generated_hyps], cur_len)
# globals().update( locals() );
# !import code; code.interact(local=vars())
# for ii in range(batch_size):
# for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
# print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
# print("")
# select the best hypotheses
tgt_len = input_ids.new(batch_size)
best = []
for i, hypotheses in enumerate(generated_hyps):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol
best.append(best_hyp)
# generate target batch (time-major, padded with pad_token_id)
decoded = input_ids.new(tgt_len.max().item(), batch_size).fill_(pad_token_id)
for i, hypo in enumerate(best):
decoded[:tgt_len[i] - 1, i] = hypo
decoded[tgt_len[i] - 1, i] = eos_token_ids[0]
# sanity check: every sentence ends with an appended EOS token
assert (decoded == eos_token_ids[0]).sum() >= batch_size
return decoded, tgt_len
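# Editorial note: callers must now unpack a (decoded, tgt_len) pair, whereas the greedy
# generator removed above returned a single squeezed sequence. `tgt_len[i]` counts the
# appended EOS token, so `decoded[:tgt_len[i], i]` recovers the i-th finished sentence.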
class BeamHypotheses(object):
def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.n_hyp = n_hyp
self.hyp = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.n_hyp or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.n_hyp:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
"""
If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.n_hyp:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
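A self-contained sketch of how this n-best list behaves (editorial illustration, not part of the commit; the token ids and log-probabilities are made up):

import torch

hyps = BeamHypotheses(n_hyp=2, max_length=10, length_penalty=1.0, early_stopping=False)
hyps.add(torch.tensor([5, 8, 3]), sum_logprobs=-3.0)     # score = -3.0 / 3 = -1.0
hyps.add(torch.tensor([5, 8, 9, 2]), sum_logprobs=-2.0)  # score = -2.0 / 4 = -0.5
hyps.add(torch.tensor([5, 7]), sum_logprobs=-4.0)        # score = -4.0 / 2 = -2.0 -> rejected (list full, worse than -1.0)
print(len(hyps), hyps.worst_score)                       # 2 -1.0
# Once the list is full, is_done() compares the best still-achievable normalized score
# against the worst kept score (max_length was reduced by 1 in __init__).
print(hyps.is_done(best_sum_logprobs=-15.0))             # True:  -15 / 9 ~ -1.67 cannot beat -1.0
print(hyps.is_done(best_sum_logprobs=-5.0))              # False: -5 / 9 ~ -0.56 could still beat -1.0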
class Sampler(object):