Improve special_token_id logic in run_generation.py and add tests (#2885)

* improving generation

* finalized special token behaviour for no_beam_search generation

* solved modeling_utils merge conflict

* solve merge conflicts in modeling_utils.py

* add run_generation improvements from PR #2749

* adapted language generation to not use hardcoded -1 if no padding token is available

* remove the -1 removal as hardcoded -1s are not necessary anymore

* add lightweight language generation testing for randomly initialized models - just checking whether no errors are thrown

* add slow language generation tests for pretrained models using hardcoded output with pytorch seed

* delete ipdb

* check that all generated tokens are valid

* renaming

* renaming Generation -> Generate

* make style

* updated so that generate_beam_search has the same token behavior as generate_no_beam_search

* consistent return format for run_generation.py

* deleted pretrained lm generate tests -> will be added in another PR

* cleaning of unused if statements and renaming

* run_generate will always return an iterable

* make style

* consistent renaming

* improve naming, make sure generate function always returns the same tensor, add docstring

* add slow tests for all lmhead models

* make style and improve example comments modeling_utils

* better naming and refactoring in modeling_utils

* changed fast random lm generation testing design to a more general one

* delete old testing design in gpt2

* correct old variable name

* temporary fix for encoder_decoder lm generation tests - has to be updated when t5 is fixed

* adapted all fast random generate tests to new design

* better warning description in modeling_utils

* better comment

* better comment and error message

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
Patrick von Platen, 2020-02-21 18:10:00 +01:00, committed by GitHub
parent c749a543fa
commit fc38d4c86f
11 changed files with 231 additions and 75 deletions

View File

@@ -106,6 +106,8 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
             language = None
             while language not in available_languages:
                 language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
+
+        model.config.lang_id = model.config.lang2id[language]
         # kwargs["language"] = tokenizer.lang2id[language]

     # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
@@ -119,12 +121,12 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):

 def prepare_xlnet_input(args, _, tokenizer, prompt_text):
     prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
-    return prompt_text, {}
+    return prompt_text


 def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
     prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
-    return prompt_text, {}
+    return prompt_text


 PREPROCESSING_FUNCTIONS = {
@@ -183,6 +185,7 @@ def main():
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
     parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
+    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
     args = parser.parse_args()

     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -210,28 +213,48 @@ def main():
     requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
     if requires_preprocessing:
         prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
-        prompt_text = prepare_input(args, model, tokenizer, prompt_text)
-    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
+        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
+        encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")
+    else:
+        encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(args.device)

     output_sequences = model.generate(
         input_ids=encoded_prompt,
-        max_length=args.length,
+        max_length=args.length + len(encoded_prompt[0]),
         temperature=args.temperature,
         top_k=args.k,
         top_p=args.p,
         repetition_penalty=args.repetition_penalty,
         do_sample=True,
+        num_return_sequences=args.num_return_sequences,
     )

-    # Batch size == 1. to add more examples please use num_return_sequences > 1
-    generated_sequence = output_sequences[0].tolist()
-    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-    text = text[: text.find(args.stop_token) if args.stop_token else None]
-
-    print(text)
-
-    return text
+    # Remove the batch dimension when returning multiple sequences
+    if len(output_sequences.shape) > 2:
+        output_sequences.squeeze_()
+
+    generated_sequences = []
+
+    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
+        print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
+        generated_sequence = generated_sequence.tolist()
+
+        # Decode text
+        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+
+        # Remove all text after the stop token
+        text = text[: text.find(args.stop_token) if args.stop_token else None]
+
+        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
+        total_sequence = (
+            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
+        )
+
+        generated_sequences.append(total_sequence)
+        print(total_sequence)
+
+    return generated_sequences


 if __name__ == "__main__":
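
Note (not part of the diff): for readers trying the refactored generation path outside of the script, a minimal sketch of the call pattern the new main() relies on. The checkpoint name and prompt below are illustrative assumptions only.

    from transformers import AutoModelWithLMHead, AutoTokenizer

    # Illustrative checkpoint; any model with an LM head works the same way.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelWithLMHead.from_pretrained("gpt2")

    encoded_prompt = tokenizer.encode("The dog", add_special_tokens=False, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=20 + len(encoded_prompt[0]),  # as in the new script, --length counts from the end of the prompt
        do_sample=True,
        num_return_sequences=3,
    )

    # run_generation.main() now always returns an iterable; generate() returns one row per sequence.
    for idx, sequence in enumerate(output_sequences):
        print(idx, tokenizer.decode(sequence.tolist(), clean_up_tokenization_spaces=True))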

View File

@@ -97,4 +97,4 @@ class ExamplesTests(unittest.TestCase):
         model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
         with patch.object(sys, "argv", testargs + [model_type, model_name]):
             result = run_generation.main()
-            self.assertGreaterEqual(len(result), 10)
+            self.assertGreaterEqual(len(result[0]), 10)

View File

@@ -75,9 +75,9 @@ class PretrainedConfig(object):
         self.top_k = kwargs.pop("top_k", 50)
         self.top_p = kwargs.pop("top_p", 1.0)
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
-        self.bos_token_id = kwargs.pop("bos_token_id", 0)
-        self.pad_token_id = kwargs.pop("pad_token_id", 0)
-        self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
+        self.bos_token_id = kwargs.pop("bos_token_id", None)
+        self.pad_token_id = kwargs.pop("pad_token_id", None)
+        self.eos_token_ids = kwargs.pop("eos_token_ids", None)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
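
Note (a sketch, not part of the diff): a quick way to see the effect of the new defaults is to instantiate a bare config. The special token ids are now None unless a model config or the caller provides them, so nothing silently falls back to token id 0.

    from transformers import PretrainedConfig

    config = PretrainedConfig()
    assert config.bos_token_id is None
    assert config.pad_token_id is None
    assert config.eos_token_ids is None

    # Callers now pass real ids explicitly, e.g. eos_token_ids=[tokenizer.eos_token_id],
    # instead of relying on a hardcoded default of 0.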

View File

@@ -645,33 +645,39 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             num_return_sequences: (`optional`) int
                 The number of independently computed returned sequences for each element in the batch. Default to 1.

+        Return:
+
+            output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
+                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
+
         Examples::

             tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
             model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
-            outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id)  # do greedy decoding without beam search
+            outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False)  # do greedy decoding
             print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

             tokenizer = AutoTokenizer.from_pretrained('openai-gpt')  # Initialize tokenizer
             model = AutoModelWithLMHead.from_pretrained('openai-gpt')  # Download model and configuration from S3 and cache.
             input_context = 'The dog'
             input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
-            outputs = model.generate(input_ids=input_ids, do_sample=True, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
+            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
             for i in range(3):  # 3 output sequences were generated
-                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[0][i], skip_special_tokens=True)))
+                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

             tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # Initialize tokenizer
             model = AutoModelWithLMHead.from_pretrained('distilgpt2')  # Download model and configuration from S3 and cache.
             input_context = 'The dog'
             input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
-            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3)  # generate sequences using greedy beam search decoding (3 beams)
-            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
+            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3)  # 3 generate sequences using by sampling
+            for i in range(3):  # 3 output sequences were generated
+                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

             tokenizer = AutoTokenizer.from_pretrained('ctrl')  # Initialize tokenizer
             model = AutoModelWithLMHead.from_pretrained('ctrl')  # Download model and configuration from S3 and cache.
             input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
             input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0)  # encode input context
-            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences using using greedy search
+            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
             print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

         """
@@ -712,10 +718,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
         assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
         assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
-        assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
-        assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
-        assert isinstance(eos_token_ids, (list, tuple)) and (
-            e >= 0 for e in eos_token_ids
+        assert input_ids is not None or (
+            isinstance(bos_token_id, int) and bos_token_id >= 0
+        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
+        assert pad_token_id is None or (
+            isinstance(pad_token_id, int) and (pad_token_id >= 0)
+        ), "`pad_token_id` should be a positive integer."
+        assert (eos_token_ids is None) or (
+            isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
         ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
         assert length_penalty > 0, "`length_penalty` should be strictely positive."
         assert (
@@ -723,12 +733,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         ), "`num_return_sequences` should be a strictely positive integer."

         if input_ids is None:
+            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
+                "you should either supply a context to complete as `input_ids` input "
+                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
+            )
             input_ids = torch.full(
                 (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device
             )
         else:
             assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

+        if pad_token_id is None and eos_token_ids is not None:
+            logger.warning(
+                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
+            )
+            pad_token_id = eos_token_ids[0]
+
         # current position and vocab size
         cur_len = input_ids.shape[1]
         vocab_size = self.config.vocab_size
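
Note (not part of the diff): the pad-token fallback above can be read in isolation. Below is a standalone sketch of that logic; the helper name is hypothetical, since in the diff the assignment happens inline inside generate().

    import logging

    logger = logging.getLogger(__name__)

    def resolve_pad_token_id(pad_token_id, eos_token_ids):
        # Hypothetical helper mirroring the inline logic of generate():
        # if no pad token is configured, pad with the first eos token and warn.
        if pad_token_id is None and eos_token_ids is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
            )
            pad_token_id = eos_token_ids[0]
        return pad_token_id

    assert resolve_pad_token_id(None, [50256]) == 50256  # falls back to the first eos token
    assert resolve_pad_token_id(0, [50256]) == 0         # an explicitly set pad token is kept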
@@ -775,8 +795,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 effective_batch_size,
             )

-        if num_return_sequences != 1:
-            output = output.view(batch_size, num_return_sequences, -1)

         return output

     def _generate_no_beam_search(
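
Note (not part of the diff): with the reshape removed, generate() always returns a flat tensor of shape (batch_size * num_return_sequences, sequence_length). A sketch of what a caller sees; the checkpoint name is an illustrative assumption.

    from transformers import AutoModelWithLMHead, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelWithLMHead.from_pretrained("gpt2")

    input_ids = tokenizer.encode("The dog", return_tensors="pt")
    outputs = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=3, max_length=20)

    # Rows stay flat: batch_size * num_return_sequences, no extra dimension is inserted.
    assert outputs.shape[0] == input_ids.shape[0] * 3

    # Callers that want the grouped view can reshape explicitly.
    grouped = outputs.view(input_ids.shape[0], 3, -1)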
@@ -798,6 +816,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         """
         # current position / max lengths / length of generated sentences / unfinished sentences
         unfinished_sents = input_ids.new(batch_size).fill_(1)
+        sent_lengths = input_ids.new(batch_size).fill_(max_length)

         past = None
@@ -833,21 +852,41 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_token = torch.argmax(next_token_logits, dim=-1)

             # update generations and finished sentences
-            tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
+            if eos_token_ids is not None:
+                # pad finished sentences if eos_token_ids exist
+                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
+            else:
+                tokens_to_add = next_token
+
             input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
-            for eos_token_id in eos_token_ids:
-                unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long())
+
+            if eos_token_ids is not None:
+                for eos_token_id in eos_token_ids:
+                    eos_in_sents = tokens_to_add == eos_token_id
+                    # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
+                    is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
+                    sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
+                    # unfinished_sents is set to zero if eos in sentence
+                    unfinished_sents.mul_((~eos_in_sents).long())
+
             cur_len = cur_len + 1

             # stop when there is a </s> in each sentence, or if we exceed the maximul length
             if unfinished_sents.max() == 0:
                 break

-        # add eos_token_ids to unfinished sentences
-        if cur_len == max_length:
-            input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])
+        # if there are different sentences lengths in the batch, some batches have to be padded
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
+            # finished sents are filled with pad_token
+            decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
+        else:
+            decoded = input_ids
+
+        for hypo_idx, hypo in enumerate(input_ids):
+            decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

-        return input_ids
+        return decoded

     def _generate_beam_search(
         self,
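
Note (not part of the diff): the sentence-length bookkeeping above is the core of the new padding behaviour. A self-contained sketch with made-up token ids shows the idea: lengths are recorded when a sentence hits eos, and shorter sentences are right-padded with pad_token_id at the end.

    import torch

    pad_token_id = 0
    eos_token_id = 2
    # two finished hypotheses; the second one hit eos one step earlier
    input_ids = torch.tensor([[5, 6, 7, 8], [5, 9, eos_token_id, eos_token_id]])
    sent_lengths = torch.tensor([4, 3])

    if sent_lengths.min().item() != sent_lengths.max().item():
        decoded = input_ids.new(input_ids.shape[0], sent_lengths.max().item()).fill_(pad_token_id)
    else:
        decoded = input_ids

    for hypo_idx, hypo in enumerate(input_ids):
        decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

    print(decoded)  # tensor([[5, 6, 7, 8], [5, 9, 2, 0]])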
@@ -941,11 +980,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             next_batch_beam = []

             # for each sentence
-            for batch_ex in range(batch_size):
+            for batch_idx in range(batch_size):

                 # if we are done with this sentence
-                done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
-                if done[batch_ex]:
+                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
+                    next_scores[batch_idx].max().item()
+                )
+                if done[batch_idx]:
+                    assert (
+                        len(generated_hyps[batch_idx]) >= num_beams
+                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
+                    assert (
+                        eos_token_ids is not None and pad_token_id is not None
+                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                     next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                     continue
@@ -953,30 +1000,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_sent_beam = []

                 # next words for this sentence
-                for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):
+                for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):

                     # get beam and word IDs
                     beam_id = idx // vocab_size
                     word_id = idx % vocab_size

-                    # end of sentence, or next word
-                    if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
-                        generated_hyps[batch_ex].add(
-                            input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item()
+                    # add to generated hypotheses if end of sentence or last iteration
+                    if eos_token_ids is not None and word_id.item() in eos_token_ids:
+                        generated_hyps[batch_idx].add(
+                            input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
                         )
                     else:
-                        next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))
+                        # add next predicted word if it is not eos_token
+                        next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))

                     # the beam for next step is full
                     if len(next_sent_beam) == num_beams:
                         break

                 # update next beam content
-                assert len(next_sent_beam) == 0 if cur_len + 1 == max_length else num_beams
-                if len(next_sent_beam) == 0:
-                    next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
+                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                 next_batch_beam.extend(next_sent_beam)
-                assert len(next_batch_beam) == num_beams * (batch_ex + 1)
+                assert len(next_batch_beam) == num_beams * (batch_idx + 1)

             # sanity check / prepare next batch
             assert len(next_batch_beam) == batch_size * num_beams
@@ -1008,29 +1054,42 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if all(done):
                 break

-        # visualize hypotheses
-        # print([len(x) for x in generated_hyps], cur_len)
-        # globals().update( locals() );
-        # !import code; code.interact(local=vars())
-        # for ii in range(batch_size):
-        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
-        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
-        #     print("")
+        for batch_idx in range(batch_size):
+            # Add all open beam hypothesis to generated_hyps
+            if not done[batch_idx]:
+                for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):
+
+                    # get beam and word IDs
+                    beam_id = idx // vocab_size
+                    word_id = idx % vocab_size
+                    generated_hyps[batch_idx].add(
+                        input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
+                    )

         # select the best hypotheses
-        tgt_len = input_ids.new(batch_size)
+        sent_lengths = input_ids.new(batch_size)
         best = []

         for i, hypotheses in enumerate(generated_hyps):
-            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
-            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
+            best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
+            sent_lengths[i] = len(best_hyp)
             best.append(best_hyp)

-        # generate target batch
-        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
-        for i, hypo in enumerate(best):
-            decoded[i, : tgt_len[i] - 1] = hypo
-            decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
+        # shorter batches are filled with pad_token
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            assert pad_token_id is not None, "`Pad_token_id` has to be defined"
+            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
+            decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)
+
+            # fill with hypothesis and eos_token_id if necessary
+            for i, hypo in enumerate(best):
+                decoded[i, : sent_lengths[i]] = hypo
+                if sent_lengths[i] < max_length:
+                    decoded[i, sent_lengths[i]] = eos_token_ids[0]
+        else:
+            # none of the hypotheses have an eos_token
+            assert (len(hypo) == max_length for hypo in best)
+            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

         return decoded
@@ -1071,33 +1130,33 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")

 class BeamHypotheses(object):
-    def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
+    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
         """
         Initialize n-best list of hypotheses.
         """
         self.max_length = max_length - 1  # ignoring bos_token
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
-        self.n_hyp = n_hyp
-        self.hyp = []
+        self.num_beams = num_beams
+        self.beams = []
         self.worst_score = 1e9

     def __len__(self):
         """
         Number of hypotheses in the list.
         """
-        return len(self.hyp)
+        return len(self.beams)

     def add(self, hyp, sum_logprobs):
         """
         Add a new hypothesis to the list.
         """
         score = sum_logprobs / len(hyp) ** self.length_penalty
-        if len(self) < self.n_hyp or score > self.worst_score:
-            self.hyp.append((score, hyp))
-            if len(self) > self.n_hyp:
-                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
-                del self.hyp[sorted_scores[0][1]]
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp))
+            if len(self) > self.num_beams:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                del self.beams[sorted_scores[0][1]]
                 self.worst_score = sorted_scores[1][0]
             else:
                 self.worst_score = min(score, self.worst_score)
@@ -1107,7 +1166,7 @@ class BeamHypotheses(object):
         If there are enough hypotheses and that none of the hypotheses being generated
         can become better than the worst one in the heap, then we are done with this sentence.
         """
-        if len(self) < self.n_hyp:
+        if len(self) < self.num_beams:
             return False
         elif self.early_stopping:
             return True
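
Note (not part of the diff): the renamed BeamHypotheses container still ranks hypotheses by a length-penalized log-probability, score = sum_logprobs / len(hyp) ** length_penalty. A small worked example with made-up numbers:

    sum_logprobs = -6.0  # total log-probability of a 12-token hypothesis
    hyp_len = 12

    print(sum_logprobs / hyp_len ** 1.0)  # -0.5 with the default length_penalty of 1.0
    print(sum_logprobs / hyp_len ** 2.0)  # about -0.042; a larger exponent divides by a bigger factor
    # Since log-probabilities are negative, length_penalty > 1.0 pushes long hypotheses toward 0 and favors them.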

View File

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import copy
 import logging
 import os.path
@@ -53,6 +52,7 @@ class ModelTesterMixin:

     model_tester = None
     all_model_classes = ()
+    all_generative_model_classes = ()
     test_torchscript = True
     test_pruning = True
     test_resize_embeddings = True
@@ -595,6 +595,47 @@ class ModelTesterMixin:
             with torch.no_grad():
                 model(**inputs_dict)

+    def test_lm_head_model_random_generate(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        input_ids = inputs_dict.get(
+            "input_ids", None
+        )  # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
+
+        for model_class in self.all_generative_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            if config.bos_token_id is None:
+                with self.assertRaises(AssertionError):
+                    model.generate(max_length=5)
+                # batch_size = 1
+                self._check_generated_tokens(model.generate(input_ids))
+                # batch_size = 1, num_beams > 1
+                self._check_generated_tokens(model.generate(input_ids, num_beams=3))
+            else:
+                # batch_size = 1
+                self._check_generated_tokens(model.generate(max_length=5))
+                # batch_size = 1, num_beams > 1
+                self._check_generated_tokens(model.generate(max_length=5, num_beams=3))
+
+            # batch_size > 1, sample
+            self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
+            # batch_size > 1, greedy
+            self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
+            # batch_size > 1, num_beams > 1, sample
+            self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
+            # batch_size > 1, num_beams > 1, greedy
+            self._check_generated_tokens(
+                model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
+            )
+
+    def _check_generated_tokens(self, output_ids):
+        for token_id in output_ids[0].tolist():
+            self.assertGreaterEqual(token_id, 0)
+            self.assertLess(token_id, self.model_tester.vocab_size)
+

 global_rng = random.Random()

View File

@@ -30,6 +30,7 @@ if is_torch_available():
 class CTRLModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
+    all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
     test_pruning = False
     test_torchscript = False
     test_resize_embeddings = False

View File

@@ -37,6 +37,9 @@ if is_torch_available():
 class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
+    all_generative_model_classes = (
+        (GPT2LMHeadModel,) if is_torch_available() else ()
+    )  # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly

     class GPT2ModelTester(object):
         def __init__(
@@ -88,6 +91,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
             self.num_labels = num_labels
             self.num_choices = num_choices
             self.scope = scope
+            self.bos_token_id = vocab_size - 1
+            self.eos_token_id = vocab_size - 1

         def prepare_config_and_inputs(self):
             input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -122,9 +127,11 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
                 # hidden_dropout_prob=self.hidden_dropout_prob,
                 # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                 n_positions=self.max_position_embeddings,
-                n_ctx=self.max_position_embeddings
+                n_ctx=self.max_position_embeddings,
                 # type_vocab_size=self.type_vocab_size,
                 # initializer_range=self.initializer_range
+                bos_token_id=self.bos_token_id,
+                eos_token_ids=self.eos_token_id,
             )

             head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

View File

@@ -39,6 +39,9 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (
         (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
     )
+    all_generative_model_classes = (
+        (OpenAIGPTLMHeadModel,) if is_torch_available() else ()
+    )  # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly

     class OpenAIGPTModelTester(object):
         def __init__(

View File

@@ -34,6 +34,7 @@ if is_torch_available():
 class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
+    all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
     test_pruning = False
     test_torchscript = False
     test_resize_embeddings = False
@@ -59,6 +60,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
             num_hidden_layers=5,
             scope=None,
             seed=1,
+            eos_token_id=0,
         ):
             self.parent = parent
             self.batch_size = batch_size
@@ -79,6 +81,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
             self.num_hidden_layers = num_hidden_layers
             self.scope = scope
             self.seed = seed
+            self.eos_token_id = eos_token_id

         def prepare_config_and_inputs(self):
             input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -100,6 +103,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
                 d_inner=self.d_inner,
                 div_val=self.div_val,
                 n_layer=self.num_hidden_layers,
+                eos_token_ids=self.eos_token_id,
             )

             return (config, input_ids_1, input_ids_2, lm_labels)

View File

@@ -49,6 +49,9 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    all_generative_model_classes = (
+        (XLMWithLMHeadModel,) if is_torch_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable

     class XLMModelTester(object):
         def __init__(
@@ -81,6 +84,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
             summary_type="last",
             use_proj=True,
             scope=None,
+            bos_token_id=0,
         ):
             self.parent = parent
             self.batch_size = batch_size
@@ -111,6 +115,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
             self.num_labels = num_labels
             self.num_choices = num_choices
             self.scope = scope
+            self.bos_token_id = bos_token_id

         def prepare_config_and_inputs(self):
             input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -151,6 +156,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
                 initializer_range=self.initializer_range,
                 summary_type=self.summary_type,
                 use_proj=self.use_proj,
+                bos_token_id=self.bos_token_id,
             )

             return (

View File

@@ -52,6 +52,9 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    all_generative_model_classes = (
+        (XLNetLMHeadModel,) if is_torch_available() else ()
+    )  # TODO (PVP): Check other models whether language generation is also applicable
     test_pruning = False

     class XLNetModelTester(object):
@@ -78,6 +81,9 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             initializer_range=0.05,
             seed=1,
             type_vocab_size=2,
+            bos_token_id=1,
+            eos_token_id=2,
+            pad_token_id=5,
         ):
             self.parent = parent
             self.batch_size = batch_size
@@ -101,6 +107,9 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
             self.seed = seed
             self.type_vocab_size = type_vocab_size
             self.type_sequence_label_size = type_sequence_label_size
+            self.bos_token_id = bos_token_id
+            self.pad_token_id = pad_token_id
+            self.eos_token_id = eos_token_id

         def prepare_config_and_inputs(self):
             input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -142,6 +151,9 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
                 bi_data=self.bi_data,
                 initializer_range=self.initializer_range,
                 num_labels=self.type_sequence_label_size,
+                bos_token_id=self.bos_token_id,
+                pad_token_id=self.pad_token_id,
+                eos_token_id=self.eos_token_id,
             )

             return (