beam search + single beam decoding
This commit is contained in:
parent b6938916ac
commit bbc0c86f9b
@@ -544,29 +544,90 @@ class PreTrainedModel(nn.Module):
        if isinstance(eos_token_ids, int):
            eos_token_ids = [eos_token_ids]

        assert isinstance(max_length, int) and 0 < max_length, "`max_length` should be a strictly positive integer."
        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(num_beams, int) and 0 < num_beams, "`num_beams` should be a strictly positive integer."
        assert 0 < temperature, "`temperature` should be positive."
        assert isinstance(top_k, int) and 0 < top_k, "`top_k` should be a strictly positive integer."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be positive."
        assert isinstance(top_k, int) and top_k > 0, "`top_k` should be a strictly positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert 0 < repetition_penalty, "`repetition_penalty` should be strictly positive."
        assert isinstance(bos_token_id, int) and 0 <= bos_token_id, "`bos_token_id` should be a positive integer."
        assert isinstance(pad_token_id, int) and 0 <= pad_token_id, "`pad_token_id` should be a positive integer."
        assert isinstance(eos_token_ids, (list, tuple)) and (0 <= e for e in eos_token_ids), \
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
        assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
        assert isinstance(eos_token_ids, (list, tuple)) and (e >= 0 for e in eos_token_ids), \
            "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
        assert isinstance(batch_size, int) and 0 < batch_size, "`batch_size` should be a strictly positive integer."
        assert 0 < length_penalty, "`length_penalty` should be strictly positive."
        assert isinstance(batch_size, int) and batch_size > 0, "`batch_size` should be a strictly positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."

        if input_ids is None:
            input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device)
        else:
            assert input_ids.dims() == 2
            assert input_ids.dims() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # current position and vocab size
        cur_len = 1
        cur_len = input_ids.shape[1]
        vocab_size = self.config.vocab_size

        if num_beams > 1:
            return self._generate_beam_search(input_ids, cur_len, max_length, do_sample, length_penalty,
                                              num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size)

        return self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
                                             temperature, top_k, top_p, repetition_penalty,
                                             pad_token_id, eos_token_ids, batch_size)
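For illustration only, a minimal sketch of how this dispatch might be driven once a model is loaded; the model/tokenizer names and the exact keyword arguments below are assumptions for the sketch, not taken from this commit:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer  # assumed model/tokenizer for the sketch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = torch.tensor([tokenizer.encode("The meaning of life is")])  # (batch_size=1, seq_len)

# num_beams == 1 takes the sampling/greedy path (_generate_no_beam_search) ...
sampled = model.generate(input_ids=prompt, max_length=20, do_sample=True,
                         temperature=0.9, top_k=50, top_p=0.95, num_beams=1)
# ... while num_beams > 1 dispatches to _generate_beam_search
beamed = model.generate(input_ids=prompt, max_length=20, do_sample=False, num_beams=3)
print(tokenizer.decode(sampled[0].tolist()))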

    def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
                                 temperature, top_k, top_p, repetition_penalty,
                                 pad_token_id, eos_token_ids, batch_size):
        """ Generate a sentence without beam search (num_beams == 1). """
        # current position / max lengths / length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)

        # cache compute states
        pasts = None

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for _ in set(input_ids[i].tolist()):
                        next_token_logits[i, _] /= repetition_penalty
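As a standalone sketch with made-up numbers, this is what the CTRL-style penalty above does: every token id already present in a sequence gets its logit divided by repetition_penalty, making repeats less likely (note that division only dampens positive logits).

import torch

next_token_logits = torch.tensor([[2.0, 3.0, 0.5, 1.0, -1.0]])  # (batch_size=1, vocab_size=5)
generated_so_far = torch.tensor([[1, 3, 3]])                     # tokens 1 and 3 were already produced
repetition_penalty = 1.3

for i in range(next_token_logits.size(0)):
    for previous_token in set(generated_so_far[i].tolist()):
        next_token_logits[i, previous_token] /= repetition_penalty

print(next_token_logits)  # tokens 1 and 3 are penalised: [[2.0, ~2.31, 0.5, ~0.77, -1.0]]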

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
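Pulled out of the loop as a self-contained sketch on dummy logits, the same next-token step: temperature rescaling, top-k/top-p filtering with the helper defined further down in this diff, then either sampling or a plain argmax.

import torch
import torch.nn.functional as F

next_token_logits = torch.tensor([[4.0, 3.5, 1.0, 0.2, -2.0]])  # (batch_size=1, vocab_size=5)
do_sample, temperature, top_k, top_p = True, 0.7, 3, 0.9

if do_sample:
    logits = next_token_logits / temperature                           # <1 sharpens, >1 flattens the distribution
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)   # unlikely tokens are pushed to -inf
    next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)   # stochastic pick
else:
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)         # greedy pick
print(next_token.shape)  # (1, 1), ready to be appended to input_ids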

            # update generations and finished sentences
            tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
            input_ids = torch.cat([input_ids, tokens_to_add], dim=-1)
            for eos_token_id in eos_token_ids:
                unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

        # add eos_token_ids to unfinished sentences
        if cur_len == max_length:
            input_ids[:, -1].masked_fill_(unfinished_sents.byte(), eos_token_ids[0])

        return input_ids
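A small worked example (made-up ids) of the finished-sentence masking used in the loop above: once a sentence emits an EOS token its entry in the 0/1 mask unfinished_sents is cleared, and from then on it only receives pad_token_id. The mask is reshaped to (batch_size, 1) here so it broadcasts cleanly against next_token.

import torch

pad_token_id, eos_token_ids = 0, [50256]
next_token = torch.tensor([[42], [50256]])   # batch of 2: sentence 0 continues, sentence 1 emits EOS
unfinished_sents = torch.tensor([1, 1])

mask = unfinished_sents.unsqueeze(-1)        # (batch_size, 1)
tokens_to_add = next_token * mask + pad_token_id * (1 - mask)
for eos_token_id in eos_token_ids:
    unfinished_sents.mul_(tokens_to_add.squeeze(-1).ne(eos_token_id).long())

print(tokens_to_add.squeeze(-1))  # tensor([42, 50256])
print(unfinished_sents)           # tensor([1, 0]) -> sentence 1 only gets pad tokens from now on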

    def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample, length_penalty,
                              num_beams, pad_token_id, eos_token_ids, vocab_size, batch_size):
        """ Generate a sentence with beam search. """
        # Expand input to num beams
        input_ids = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)
        input_ids = input_ids.contiguous().view(batch_size * num_beams, cur_len)  # (batch_size * num_beams, cur_len)
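A shape-only sketch with toy sizes of the expansion above: every input sequence is replicated num_beams times so that all beams of all sentences go through the model in one batched forward pass.

import torch

batch_size, num_beams, cur_len = 2, 3, 4
input_ids = torch.arange(batch_size * cur_len).view(batch_size, cur_len)  # (2, 4)

expanded = input_ids.unsqueeze(1).expand(batch_size, num_beams, cur_len)  # (2, 3, 4) view, no copy yet
flat = expanded.contiguous().view(batch_size * num_beams, cur_len)        # (6, 4)

print(flat)  # rows 0-2 are the beams of sentence 0, rows 3-5 the beams of sentence 1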
@@ -592,9 +653,11 @@ class PreTrainedModel(nn.Module):
            scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
            assert scores.size() == (batch_size * num_beams, vocab_size)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
            _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)
            # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

            # re-organize to group the beam together (we are keeping top hypothesis across beams)
            _scores = _scores.view(batch_size, num_beams * vocab_size)  # (batch_size, num_beams * vocab_size)

            next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (batch_size, 2 * num_beams)
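A toy re-run of the selection above. Keeping 2 * num_beams candidates per sentence (instead of num_beams) leaves enough spare hypotheses to refill the beam even if several of the best candidates end with an EOS token and are moved into the finished hypotheses.

import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size = 1, 2, 5
scores = F.log_softmax(torch.randn(batch_size * num_beams, vocab_size), dim=-1)  # per-beam next-word log-probs
beam_scores = torch.zeros(batch_size * num_beams)                                # running log-prob of each beam

_scores = scores + beam_scores[:, None].expand_as(scores)   # log P(sequence + word) = log P(sequence) + log P(word)
_scores = _scores.view(batch_size, num_beams * vocab_size)  # all candidates of a sentence side by side

next_scores, next_words = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
print(next_scores.shape, next_words.shape)  # both (1, 4) == (batch_size, 2 * num_beams)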
@@ -604,11 +667,11 @@ class PreTrainedModel(nn.Module):
            next_batch_beam = []

            # for each sentence
            for sent_id in range(batch_size):
            for batch_ex in range(batch_size):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
                if done[batch_ex]:
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

@@ -616,7 +679,7 @@ class PreTrainedModel(nn.Module):
                next_sent_beam = []

                # next words for this sentence
                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
                for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):

                    # get beam and word IDs
                    beam_id = idx // vocab_size
@@ -624,9 +687,9 @@ class PreTrainedModel(nn.Module):

                    # end of sentence, or next word
                    if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
                        generated_hyps[sent_id].add(input_ids[sent_id * num_beams + beam_id, :cur_len].clone(), value.item())
                        generated_hyps[batch_ex].add(input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item())
                    else:
                        next_sent_beam.append((value, word_id, sent_id * num_beams + beam_id))
                        next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
@@ -637,7 +700,7 @@ class PreTrainedModel(nn.Module):
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (sent_id + 1)
                assert len(next_batch_beam) == num_beams * (batch_ex + 1)
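An index-arithmetic sketch (made-up numbers) of the bookkeeping in the loop above: a flat candidate index from topk encodes both the beam it extends and the word it proposes, and surviving candidates are stored as (score, word_id, global beam row) triples so the beams can be re-ordered for the next step.

num_beams, vocab_size = 2, 5
batch_ex = 0                      # first sentence of the batch
idx, score = 7, -1.2              # one flat candidate index and its cumulative log-prob

beam_id = idx // vocab_size       # 7 // 5 = 1 -> the candidate extends beam 1 of this sentence
word_id = idx % vocab_size        # 7 %  5 = 2 -> it proposes vocabulary token 2

# third element: row of that beam in the flattened (batch_size * num_beams) input_ids
candidate = (score, word_id, batch_ex * num_beams + beam_id)
print(candidate)  # (-1.2, 2, 1)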

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
@@ -670,7 +733,7 @@ class PreTrainedModel(nn.Module):
            # print("")

        # select the best hypotheses
        tgt_len = src_len.new(batch_size)
        tgt_len = input_ids.new(batch_size)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
@@ -679,15 +742,46 @@ class PreTrainedModel(nn.Module):
            best.append(best_hyp)

        # generate target batch
        decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_index)
        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = self.eos_index
            decoded[i, :tgt_len[i] - 1] = hypo
            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]

        # sanity check
        assert (decoded == self.eos_index).sum() == 2 * batch_size
        # # sanity check
        # assert (decoded == eos_token_ids[0]).sum() == 2 * batch_size

        return decoded, tgt_len
        return decoded


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size x vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
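A quick usage sketch of this filter on made-up logits: values outside the kept set are pushed to -inf, so after the softmax they have probability zero and can never be sampled.

import torch
import torch.nn.functional as F

logits = torch.tensor([[5.0, 4.0, 3.0, 1.0, -1.0]])  # (batch_size=1, vocab_size=5)

filtered = top_k_top_p_filtering(logits.clone(), top_k=3, top_p=0.0)  # clone: the filter edits logits in place
print(filtered)                      # [[5., 4., 3., -inf, -inf]] -> only the top-3 tokens survive
probs = F.softmax(filtered, dim=-1)  # masked tokens get probability 0
next_token = torch.multinomial(probs, num_samples=1)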


class BeamHypotheses(object):