adding beam search

This commit is contained in:
parent a468870fd2
commit b6938916ac
@@ -62,13 +62,18 @@ class PretrainedConfig(object):
        self.is_decoder = kwargs.pop('is_decoder', False)

        # Parameters for sequence generation
        self.generate_length = kwargs.pop('generate_length', 10)
        self.generate_max_length = kwargs.pop('generate_max_length', 20)
        self.generate_do_sample = kwargs.pop('generate_do_sample', False)
        self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
        self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
        self.generate_top_k = kwargs.pop('generate_top_k', 50)
        self.generate_top_p = kwargs.pop('generate_top_p', 0.0)
        self.generate_top_p = kwargs.pop('generate_top_p', 1.0)
        self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
        self.generate_bos_token_id = kwargs.pop('generate_bos_token_id', 0)
        self.generate_pad_token_id = kwargs.pop('generate_pad_token_id', 0)
        self.generate_eos_token_ids = kwargs.pop('generate_eos_token_ids', 0)
        self.generate_batch_size = kwargs.pop('generate_batch_size', 1)
        self.generate_length_penalty = kwargs.pop('generate_length_penalty', 1.)

    def save_pretrained(self, save_directory):
        """ Save a configuration object to the directory `save_directory`, so that it
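For orientation, a minimal sketch (not part of the commit) of how these defaults are meant to be used: they become plain attributes on the config and are read back as fallbacks by `generate()` below. `GPT2Config` is only an assumed example of a `PretrainedConfig` subclass that forwards extra keyword arguments.

# illustrative only: set generation defaults on a config, then read them back
from transformers import GPT2Config  # assumed example; any PretrainedConfig subclass forwarding **kwargs works

config = GPT2Config(generate_max_length=40, generate_num_beams=5, generate_top_p=0.9)
assert config.generate_num_beams == 5  # picked up as the default for `num_beams` in generate()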
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -488,63 +488,252 @@ class PreTrainedModel(nn.Module):

        return model

    def generate(self, input_ids=None, length=None, do_sample=False, num_beams=None,
                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
                 **model_kwargs):
        """ Generic sequence generator for single-stack models with a LM head.

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

        The method currently supports greedy decoding and sampling. See the
        documentation of the `Sampler` class for more information about the
        parameters related to sampling.

    def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
                 temperature=None, top_k=None, top_p=None, repetition_penalty=None,
                 bos_token_id=None, pad_token_id=None, eos_token_ids=None, batch_size=None,
                 length_penalty=None, **kwargs):
        """ Sequence generator for models with a LM head.
        The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
        and beam-search.

        Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM

        Params:
            **input_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape (1,)
            **length**: (`optional`) int
                The length of the sequence to be generated.
            **max_length**: (`optional`) int
                The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
            **do_sample**: (`optional`) bool
                If set to `False` we use greedy decoding; otherwise sampling.
                If set to `False` we use greedy decoding; otherwise sampling. Defaults to greedy decoding.
            **num_beams**: (`optional`) int
                Number of beams for beam search. 1 means no beam search. Default to 1.
            **temperature**: (`optional`) float
                The value used to modulate the next token probabilities.
            **k**: (`optional`) int
                The parameter used for top-k filtering.
            **p**: (`optional`) float
                The parameter for nucleus sampling. Must be between 0 and 1.
            **top_k**: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k filtering. Between 1 and infinity. Default to 50.
            **top_p**: (`optional`) float
                The cumulative probability of the highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
            **repetition_penalty**: (`optional`) float
                The parameter for repetition penalty.
                The parameter for repetition penalty. Between 1.0 and +infinity. 1.0 means no penalty. Default to 1.
        """
        if input_ids is None:
            input_ids = torch.tensor([[]], dtype=torch.long, device=next(self.parameters()).device)

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError("You tried to generate sequences with a model that does not have a LM Head.")

        sampler_config = {
            "k": k,
            "p": p,
            "do_sample": do_sample,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }
        max_length = max_length if max_length is not None else self.config.generate_max_length
        do_sample = do_sample if do_sample is not None else self.config.generate_do_sample
        num_beams = num_beams if num_beams is not None else self.config.generate_num_beams
        temperature = temperature if temperature is not None else self.config.generate_temperature
        top_k = top_k if top_k is not None else self.config.generate_top_k
        top_p = top_p if top_p is not None else self.config.generate_top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.generate_repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generate_bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generate_pad_token_id
        eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.generate_eos_token_ids
        batch_size = batch_size if batch_size is not None else self.config.generate_batch_size
        length_penalty = length_penalty if length_penalty is not None else self.config.generate_length_penalty
        sampler = Sampler(**sampler_config)
        generated_sequence = input_ids
        for _ in trange(length):
            arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
            outputs = self(**arguments)
            next_tokens_logits = outputs[0][:, -1, :]
            next_tokens = sampler.get_one_token(
                next_tokens_logits, generated_sequence
            )
            generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overridden by the input batch_size
        if isinstance(eos_token_ids, int):
            eos_token_ids = [eos_token_ids]

        return generated_sequence.squeeze(0)
        assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictly positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictly positive integer."
        assert 0 < temperature, "`temperature` should be positive."
        assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictly positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert 0 < repetition_penalty, "`repetition_penalty` should be strictly positive."
        assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a positive integer."
        assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a positive integer."
        assert isinstance(eos_token_ids, (list, tuple)) and all(0 <= e for e in eos_token_ids), \
            "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
        assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictly positive integer."
        assert 0 < length_penalty, "`length_penalty` should be strictly positive."
    def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
        return model_kwargs.update({"input_ids": input_ids})

        if input_ids is None:
            input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
        else:
            assert input_ids.dim() == 2

        # current position and vocab size
        cur_len = 1
        vocab_size = self.config.vocab_size

        # Expand input to num beams
        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
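        # e.g. with batch_size=2 and num_beams=3, each prompt row is repeated 3 times,
        # giving 6 rows that are decoded in parallel from here on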
        # generated hypotheses
        generated_hyps = [BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)
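        # all beams start from the same prompt, so only the first beam of each sentence keeps score 0;
        # the -1e9 on the other beams prevents the first top-k from returning num_beams identical hypotheses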
        # cache compute states
        pasts = None  # self.prepare_pasts()

        # done sentences
        done = [False for _ in range(batch_size)]
        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
            scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
            scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)
            scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
            assert scores.size() == (batch_size * num_beams, vocab_size)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
            _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)

            next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
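            # 2 * num_beams candidates are kept so that even if up to num_beams of them end in an EOS token
            # (and are moved to the finished hypotheses below), enough candidates remain to refill the beam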
            # next batch beam content
            # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(batch_size):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue
                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):

                    # get beam and word IDs
                    beam_id = idx // vocab_size
                    word_id = idx % vocab_size

                    # end of sentence, or next word
                    if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
                        generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item())
                    else:
                        next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # update next beam content
                assert len(next_sent_beam) == (0 if cur_len + 1 == max_length else num_beams)
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (sent_id + 1)
            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
            # TODO: Activate cache
            # for k in cache.keys():
            #     if k != 'slen':
            #         cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break
        # visualize hypotheses
        # print([len(x) for x in generated_hyps], cur_len)
        # globals().update( locals() );
        # !import code; code.interact(local=vars())
        # for ii in range(batch_size):
        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
        # print("")
        # select the best hypotheses
        tgt_len = input_ids.new(batch_size)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = input_ids.new(tgt_len.max().item(), batch_size).fill_(pad_token_id)
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = eos_token_ids[0]

        # sanity check
        assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size

        return decoded, tgt_len
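For orientation, a minimal usage sketch of the new method (not part of the commit; the model, tokenizer and parameter values are assumptions chosen for illustration, and since this is an intermediate commit the end-to-end path may still be rough):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = torch.tensor([tokenizer.encode("The unicorns said")])
# beam search over 3 beams; returns the decoded batch and the per-sentence target lengths
decoded, tgt_len = model.generate(input_ids=prompt, max_length=20, num_beams=3,
                                  eos_token_ids=[tokenizer.eos_token_id],
                                  pad_token_id=0, length_penalty=1.0)
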
class BeamHypotheses(object):

    def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.n_hyp = n_hyp
        self.hyp = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)
    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.n_hyp or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.n_hyp:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)
    def is_done(self, best_sum_logprobs):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.n_hyp:
            return False
        elif self.early_stopping:
            return True
        else:
            return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
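To make the pruning rule above concrete, a small worked example (not part of the commit; the numbers are arbitrary): each hypothesis is scored as `sum_logprobs / len(hyp) ** length_penalty`, and a finished hypothesis only enters the list if the list is not full yet or the new score beats `worst_score`.

# illustrative only
hyps = BeamHypotheses(n_hyp=2, max_length=10, length_penalty=1.0, early_stopping=False)
hyps.add([5, 7, 9], sum_logprobs=-3.0)   # score -3.0 / 3 = -1.0
hyps.add([5, 7], sum_logprobs=-4.0)      # score -4.0 / 2 = -2.0
hyps.add([5, 8, 11], sum_logprobs=-4.5)  # score -1.5 beats the worst (-2.0), which gets evicted
assert len(hyps) == 2 and hyps.worst_score == -1.5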

class Sampler(object):