From bbc0c86f9b96b62b95853a18945f855c661a13b9 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Tue, 17 Dec 2019 23:27:02 +0100
Subject: [PATCH] beam search + single beam decoding

---
 transformers/modeling_utils.py | 152 ++++++++++++++++++++++++++-------
 1 file changed, 123 insertions(+), 29 deletions(-)

diff --git a/transformers/modeling_utils.py b/transformers/modeling_utils.py
index 003e17a0d9d..52743d8c2ff 100644
--- a/transformers/modeling_utils.py
+++ b/transformers/modeling_utils.py
@@ -544,29 +544,90 @@ class PreTrainedModel(nn.Module):
         if isinstance(eos_token_ids, int):
             eos_token_ids = [eos_token_ids]
 
-        assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictely positive integer."
+        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
-        assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictely positive integer."
-        assert 0 < temperature, "`temperature` should be positive."
-        assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictely positive integer."
+        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
+        assert temperature > 0, "`temperature` should be positive."
+        assert isinstance(top_k, int) and top_k > 0, "`top_k` should be a strictly positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
-        assert 0 < repetition_penalty, "`repetition_penalty` should be strictely positive."
-        assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a positive integer."
-        assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a positive integer."
-        assert isinstance(eos_token_ids, (list, tuple)) and (0 <= e for e in eos_token_ids), \
+        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
+        assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a non-negative integer."
+        assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a non-negative integer."
+        assert isinstance(eos_token_ids, (list, tuple)) and all(e >= 0 for e in eos_token_ids), \
             "`eos_token_ids` should be a non-negative integer or a list/tuple of non-negative integers."
-        assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictely positive integer."
-        assert 0 < length_penalty, "`length_penalty` should be strictely positive."
+        assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictly positive integer."
+        assert length_penalty > 0, "`length_penalty` should be strictly positive."
 
         if input_ids is None:
             input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long,
                                    device=next(self.parameters()).device)
         else:
-            assert input_ids.dims() == 2
+            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
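Note: for reference, a minimal sketch of how the arguments validated above would be passed in. The public entry point is not shown in this hunk, so the method name `generate`, the GPT-2 classes, and the exact keyword signature are assumptions, not part of this patch:

    # Hypothetical usage sketch -- `generate` as the entry-point name and the
    # GPT-2 model/tokenizer classes are assumptions; only the argument checks
    # appear in the hunk above.
    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    prompt = torch.tensor([tokenizer.encode("The unicorn said")])
    output = model.generate(prompt, max_length=20, do_sample=True, temperature=0.9,
                            top_k=50, top_p=0.95, repetition_penalty=1.2, num_beams=1)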
 
         # current position and vocab size
-        cur_len = 1
+        cur_len = input_ids.shape[1]
         vocab_size = self.config.vocab_size
 
+        if num_beams > 1:
+            return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, length_penalty,
+                                              num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size)
+
+        return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
+                                             temperature, top_k, top_p, repetition_penalty,
+                                             pad_token_id, eos_token_ids, batch_size)
+
+    def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
+                                 temperature, top_k, top_p, repetition_penalty,
+                                 pad_token_id, eos_token_ids, batch_size):
+        """ Generate a sentence without beam search (num_beams == 1). """
+        # unfinished sentences: 1 while a sentence is still being generated, 0 once it has produced an EOS token
+        unfinished_sents = input_ids.new(batch_size).fill_(1)
+
+        # cache compute states
+        pasts = None
+
+        while cur_len < max_length:
+            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            outputs = self(**model_inputs)
+            next_token_logits = outputs[0][:, -1, :]
+
+            # repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858)
+            if repetition_penalty != 1.0:
+                for i in range(batch_size):
+                    for previous_token in set(input_ids[i].tolist()):
+                        # discount the logits of previously generated tokens
+                        next_token_logits[i, previous_token] /= repetition_penalty
+
+            if do_sample:
+                # Temperature (higher temperature => more likely to sample low-probability tokens)
+                if temperature != 1.0:
+                    next_token_logits = next_token_logits / temperature
+                # Top-p/top-k filtering
+                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+                # Sample from the filtered distribution
+                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1).squeeze(1)
+            else:
+                # Greedy decoding
+                next_token = torch.argmax(next_token_logits, dim=-1)
+
+            # update generations and finished sentences (finished ones are extended with the pad token)
+            tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
+            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
+            for eos_token_id in eos_token_ids:
+                unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long())
+            cur_len = cur_len + 1
+
+            # stop when there is an EOS token in each sentence, or when we reach the maximum length
+            if unfinished_sents.max() == 0:
+                break
+
+        # if the maximum length was reached, force an EOS token at the end of unfinished sentences
+        if cur_len == max_length:
+            input_ids[:, -1].masked_fill_(unfinished_sents.bool(), eos_token_ids[0])
+
+        return input_ids
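Note on the masking trick in `_generate_no_beam_search` above: since `unfinished_sents` holds ones for active sentences and zeros for finished ones, `next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)` appends the sampled token to active sentences and the pad token to finished ones, with no branching. A standalone illustration:

    import torch

    pad_token_id = 0
    next_token = torch.tensor([42, 7, 13])      # proposed next tokens, batch of 3
    unfinished_sents = torch.tensor([1, 0, 1])  # sentence 1 already produced EOS
    tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
    print(tokens_to_add)  # tensor([42,  0, 13]): the finished sentence gets padding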
""" # Expand input to num beams input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len) input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len) # (batch_size * num_beams, cur_len) @@ -592,9 +653,11 @@ class PreTrainedModel(nn.Module): scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) assert scores.size() == (batch_size * num_beams, vocab_size) - # select next words with scores - _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) - _scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size) + # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product) + _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size) + + # re-organize to group the beam together (we are keeping top hypothesis accross beams) + _scores = _scores.view(batch_size, num_beams * vocab_size) # (batch_size, num_beams * vocab_size) next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams) @@ -604,11 +667,11 @@ class PreTrainedModel(nn.Module): next_batch_beam = [] # for each sentence - for sent_id in range(batch_size): + for batch_ex in range(batch_size): # if we are done with this sentence - done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item()) - if done[sent_id]: + done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item()) + if done[batch_ex]: next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch continue @@ -616,7 +679,7 @@ class PreTrainedModel(nn.Module): next_sent_beam = [] # next words for this sentence - for idx, value in zip(next_words[sent_id], next_scores[sent_id]): + for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]): # get beam and word IDs beam_id = idx // vocab_size @@ -624,9 +687,9 @@ class PreTrainedModel(nn.Module): # end of sentence, or next word if word_id.item() in eos_token_ids or cur_len + 1 == max_length: - generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item()) + generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item()) else: - next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id)) + next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id)) # the beam for next step is full if len(next_sent_beam) == num_beams: @@ -637,7 +700,7 @@ class PreTrainedModel(nn.Module): if len(next_sent_beam) == 0: next_sent_beam = [(0, pad_token_id, 0)] * num_beams # pad the batch next_batch_beam.extend(next_sent_beam) - assert len(next_batch_beam) == num_beams * (sent_id + 1) + assert len(next_batch_beam) == num_beams * (batch_ex + 1) # sanity check / prepare next batch assert len(next_batch_beam) == batch_size * num_beams @@ -670,7 +733,7 @@ class PreTrainedModel(nn.Module): # print("") # select the best hypotheses - tgt_len = src_len.new(batch_size) + tgt_len = input_ids.new(batch_size) best = [] for i, hypotheses in enumerate(generated_hyps): @@ -679,15 +742,46 @@ class PreTrainedModel(nn.Module): best.append(best_hyp) # generate target batch - decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_index) + decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) for i, hypo in 
@@ -604,11 +667,11 @@ class PreTrainedModel(nn.Module):
             next_batch_beam = []
 
             # for each sentence
-            for sent_id in range(batch_size):
+            for batch_ex in range(batch_size):
 
                 # if we are done with this sentence
-                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
-                if done[sent_id]:
+                done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
+                if done[batch_ex]:
                     next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                     continue
@@ -616,7 +679,7 @@ class PreTrainedModel(nn.Module):
                 next_sent_beam = []
 
                 # next words for this sentence
-                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
+                for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):
 
                     # get beam and word IDs
                     beam_id = idx // vocab_size
@@ -624,9 +687,9 @@ class PreTrainedModel(nn.Module):
 
                     # end of sentence, or next word
                     if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
-                        generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item())
+                        generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item())
                     else:
-                        next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id))
+                        next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))
 
                     # the beam for next step is full
                     if len(next_sent_beam) == num_beams:
@@ -637,7 +700,7 @@ class PreTrainedModel(nn.Module):
                 if len(next_sent_beam) == 0:
                     next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
                 next_batch_beam.extend(next_sent_beam)
-                assert len(next_batch_beam) == num_beams * (sent_id + 1)
+                assert len(next_batch_beam) == num_beams * (batch_ex + 1)
 
             # sanity check / prepare next batch
             assert len(next_batch_beam) == batch_size * num_beams
@@ -670,7 +733,7 @@ class PreTrainedModel(nn.Module):
         #     print("")
 
         # select the best hypotheses
-        tgt_len = src_len.new(batch_size)
+        tgt_len = input_ids.new(batch_size)
         best = []
 
         for i, hypotheses in enumerate(generated_hyps):
@@ -679,15 +742,46 @@ class PreTrainedModel(nn.Module):
             best.append(best_hyp)
 
         # generate target batch
-        decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_index)
+        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
         for i, hypo in enumerate(best):
-            decoded[:tgt_len[i] - 1, i] = hypo
-            decoded[tgt_len[i] - 1, i] = self.eos_index
+            decoded[i, :tgt_len[i] - 1] = hypo
+            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
 
-        # sanity check
-        assert (decoded == self.eos_index).sum() == 2 * batch_size
+        # # sanity check
+        # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size
 
-        return decoded, tgt_len
+        return decoded
+
+
+def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
+        Args:
+            logits: logits distribution of shape (batch size, vocabulary size)
+            top_k > 0: keep only the k tokens with the highest probability (top-k filtering).
+            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    top_k = min(top_k, logits.size(-1))  # Safety check
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the mask to the right to keep the first token above the threshold as well
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter the sorted mask back to the original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+        logits[indices_to_remove] = filter_value
+    return logits
 
 
 class BeamHypotheses(object):
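Note: a quick standalone check of the `top_k_top_p_filtering` behaviour (toy logits; filtered positions are set to -inf, so softmax assigns them zero probability):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
    print(filtered)                     # tensor([[2., 1., -inf, -inf, -inf]])
    print(F.softmax(filtered, dim=-1))  # all probability mass on the two kept tokens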