beam search + single beam decoding

thomwolf 2019-12-17 23:27:02 +01:00
parent b6938916ac
commit bbc0c86f9b


@@ -544,29 +544,90 @@ class PreTrainedModel(nn.Module):
if isinstance(eos_token_ids, int):
eos_token_ids = [eos_token_ids]
assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictly positive integer."
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictly positive integer."
assert 0 < temperature, "`temperature` should be positive."
assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictly positive integer."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be positive."
assert isinstance(top_k, int) and top_k > 0, "`top_k` should be a strictly positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert 0 < repetition_penalty, "`repetition_penalty` should be strictly positive."
assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a non-negative integer."
assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a non-negative integer."
assert isinstance(eos_token_ids, (list, tuple)) and all(0 <= e for e in eos_token_ids), \
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a non-negative integer."
assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a non-negative integer."
assert isinstance(eos_token_ids, (list, tuple)) and all(e >= 0 for e in eos_token_ids), \
"`eos_token_ids` should be a non-negative integer or a list/tuple of non-negative integers."
assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictly positive integer."
assert 0 < length_penalty, "`length_penalty` should be strictly positive."
assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictly positive integer."
assert length_penalty > 0, "`length_penalty` should be strictly positive."
if input_ids is None:
input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
else:
assert input_ids.dim() == 2
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
# current position and vocab size
cur_len = 1
cur_len = input_ids.shape[1]
vocab_size = self.config.vocab_size
if num_beams > 1:
return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, length_penalty,
num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size)
return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size)
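# Toy sketch of the prompt initialization above (illustration only, hypothetical
# ids: bos_token_id = 0, batch_size = 2): when no prompt is given, decoding starts
# from a (batch_size, 1) tensor of BOS tokens, so cur_len starts at 1.
import torch
bos_token_id, batch_size = 0, 2
input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long)
cur_len = input_ids.shape[1]
print(input_ids.shape, cur_len)  # torch.Size([2, 1]) 1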
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size):
""" Generate a sentence without beam search (num_beams == 1). """
# unfinished sentences (1 until a sentence produces an EOS token)
unfinished_sents = input_ids.new(batch_size).fill_(1)
# cache compute states
pasts = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
for i in range(batch_size):
for previous_token in set(input_ids[i].tolist()):
next_token_logits[i, previous_token] /= repetition_penalty
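# Toy illustration of the CTRL-style repetition penalty applied above (hypothetical
# penalty value): logits of tokens already generated are divided by the penalty so
# they become less likely to be picked again. Note that dividing a negative logit
# raises it, a known quirk of this formulation.
import torch
repetition_penalty = 1.2
generated = torch.tensor([[5, 7, 5]])            # one sequence in the batch
next_token_logits = torch.randn(1, 10)           # toy vocabulary of 10 tokens
for i in range(generated.shape[0]):
    for previous_token in set(generated[i].tolist()):
        next_token_logits[i, previous_token] /= repetition_penalty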
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
# Top-p/top-k filtering
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
# Sample
next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
else:
# Greedy decoding
next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
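# Toy contrast of the two branches above (illustration only, hypothetical
# temperature of 0.7): sampling rescales the logits by the temperature and draws
# from the softmax, while greedy decoding simply takes the argmax.
import torch
import torch.nn.functional as F
logits = torch.randn(2, 10)
sampled = torch.multinomial(F.softmax(logits / 0.7, dim=-1), num_samples=1)  # (2, 1)
greedy = torch.argmax(logits, dim=-1).unsqueeze(-1)                          # (2, 1)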
# update generations and finished sentences
tokens_to_add = next_token * unfinished_sents.unsqueeze(-1) + pad_token_id * (1 - unfinished_sents.unsqueeze(-1))
input_ids = torch.cat([input_ids, tokens_to_add], dim=-1)
for eos_token_id in eos_token_ids:
unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long())
cur_len = cur_len + 1
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if unfinished_sents.max() == 0:
break
# if the maximum length is reached, force an EOS token at the end of unfinished sentences
if cur_len == max_length:
input_ids[:, -1].masked_fill_(unfinished_sents.byte(), eos_token_ids[0])
return input_ids
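# Toy illustration of the masking trick used in the loop above (hypothetical ids:
# pad = 0, eos = 2): once a sentence produces an EOS token its unfinished flag
# drops to 0, so every later step appends pad_token_id instead of the sampled token.
import torch
pad_token_id, eos_token_id = 0, 2
next_token = torch.tensor([[4], [2]])                       # second sentence emits EOS
unfinished_sents = torch.tensor([1, 1])
mask = unfinished_sents.unsqueeze(-1)
tokens_to_add = next_token * mask + pad_token_id * (1 - mask)
unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long())
print(unfinished_sents)  # tensor([1, 0])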
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, length_penalty,
num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size):
""" Generate a sentence with beam search. """
# Expand input to num beams
input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len)
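# Quick shape check of the beam expansion above (toy values): each prompt is
# replicated num_beams times and flattened so the model processes
# batch_size * num_beams sequences in parallel.
import torch
batch_size, num_beams, cur_len = 2, 3, 4
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)
expanded = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
expanded = expanded.contiguous().view(batch_size * num_beams, cur_len)
print(expanded.shape)  # torch.Size([6, 4])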
@@ -592,9 +653,11 @@ class PreTrainedModel(nn.Module):
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
assert scores.size() == (batch_size * num_beams, vocab_size)
# select next words with scores
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
# re-organize to group the beams together (we are keeping the top hypotheses across beams)
_scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size)
next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
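# Hedged sketch of the beam-score bookkeeping above with toy numbers: per-token
# log-probs are added to each beam's running log-prob, the beams of a sentence are
# flattened together, and the 2 * num_beams best continuations are kept so that
# hypotheses ending in EOS can be set aside without emptying the beam. Beam and
# token ids are recovered from the flat index.
import torch
import torch.nn.functional as F
batch_size, num_beams, vocab_size = 2, 3, 11
scores = F.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)
beam_scores = torch.zeros(batch_size * num_beams)
_scores = scores + beam_scores[:, None].expand_as(scores)
_scores = _scores.view(batch_size, num_beams * vocab_size)
next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
beam_id, word_id = next_words // vocab_size, next_words % vocab_size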
@@ -604,11 +667,11 @@ class PreTrainedModel(nn.Module):
next_batch_beam = []
# for each sentence
for sent_id in range(batch_size):
for batch_ex in range(batch_size):
# if we are done with this sentence
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
if done[sent_id]:
done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
if done[batch_ex]:
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
continue
@@ -616,7 +679,7 @@ class PreTrainedModel(nn.Module):
next_sent_beam = []
# next words for this sentence
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):
# get beam and word IDs
beam_id = idx // vocab_size
@@ -624,9 +687,9 @@ class PreTrainedModel(nn.Module):
# end of sentence, or next word
if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item())
generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item())
else:
next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id))
next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))
# the beam for next step is full
if len(next_sent_beam) == num_beams:
@@ -637,7 +700,7 @@ class PreTrainedModel(nn.Module):
if len(next_sent_beam) == 0:
next_sent_beam = [(0, pad_token_id, 0)] * num_beams # pad the batch
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == num_beams * (sent_id + 1)
assert len(next_batch_beam) == num_beams * (batch_ex + 1)
# sanity check / prepare next batch
assert len(next_batch_beam) == batch_size * num_beams
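# Sketch of the bookkeeping structure built above (toy values): next_batch_beam is
# a flat list of batch_size * num_beams (score, token id, flat beam index) triplets,
# which is typically unpacked back into per-beam tensors before the next step.
import torch
next_batch_beam = [(0.0, 5, 0), (-0.7, 9, 1), (-1.2, 5, 3), (-0.3, 2, 4)]  # toy triplets
beam_scores = torch.tensor([x[0] for x in next_batch_beam])
beam_words = torch.tensor([x[1] for x in next_batch_beam])
beam_idx = torch.tensor([x[2] for x in next_batch_beam])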
@@ -670,7 +733,7 @@ class PreTrainedModel(nn.Module):
# print("")
# select the best hypotheses
tgt_len = src_len.new(batch_size)
tgt_len = input_ids.new(batch_size)
best = []
for i, hypotheses in enumerate(generated_hyps):
@@ -679,15 +742,46 @@ class PreTrainedModel(nn.Module):
best.append(best_hyp)
# generate target batch
decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_index)
decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best):
decoded[:tgt_len[i] - 1, i] = hypo
decoded[tgt_len[i] - 1, i] = self.eos_index
decoded[i, :tgt_len[i] - 1] = hypo
decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
# sanity check
assert (decoded == self.eos_index).sum() == 2 * batch_size
# # sanity check
# assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size
return decoded, tgt_len
return decoded
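# Toy illustration of the output assembly above (hypothetical ids: pad = 0,
# eos = 2): each best hypothesis is copied into a row of a pad-filled tensor and
# terminated with an EOS token.
import torch
pad_token_id, eos_token_id = 0, 2
best = [torch.tensor([5, 7, 9]), torch.tensor([4, 6])]
tgt_len = torch.tensor([len(h) + 1 for h in best])   # +1 for the final EOS token
decoded = torch.full((len(best), tgt_len.max().item()), pad_token_id, dtype=torch.long)
for i, hypo in enumerate(best):
    decoded[i, :tgt_len[i] - 1] = hypo
    decoded[i, tgt_len[i] - 1] = eos_token_id
print(decoded)  # tensor([[5, 7, 9, 2], [4, 6, 2, 0]])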
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size x vocabulary size)
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
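# Usage sketch for top_k_top_p_filtering on toy logits (hypothetical top_k / top_p
# values): after filtering, only the surviving tokens keep finite logits, so the
# softmax puts zero probability mass everywhere else.
import torch
import torch.nn.functional as F
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=0.95)
probs = F.softmax(filtered, dim=-1)  # non-zero only for the two highest logits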
class BeamHypotheses(object):