Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 02:31:11 +06:00)
Improve special_token_id logic in run_generation.py and add tests (#2885)
* improve generation
* finalize special token behaviour for no_beam_search generation
* solve merge conflicts in modeling_utils.py
* add run_generation improvements from PR #2749
* adapt language generation so it does not use a hardcoded -1 if no padding token is available
* remove the -1 removal, as hardcoded -1s are no longer necessary
* add lightweight language generation tests for randomly initialized models - just checking that no errors are thrown
* add slow language generation tests for pretrained models using hardcoded output with a PyTorch seed
* delete ipdb
* check that all generated tokens are valid
* rename Generation -> Generate
* make style
* update generate_beam_search so it has the same special token behavior as generate_no_beam_search
* consistent return format for run_generation.py
* delete pretrained lm generate tests -> will be added in another PR
* clean up unused if statements and rename
* run_generate will always return an iterable
* consistent renaming
* improve naming, make sure the generate function always returns the same tensor, add docstring
* add slow tests for all lm-head models
* make style and improve example comments in modeling_utils
* better naming and refactoring in modeling_utils
* change the fast random lm generation testing design to a more general one
* delete the old testing design in gpt2
* correct old variable name
* temporary fix for encoder_decoder lm generation tests - has to be updated when t5 is fixed
* adapt all fast random generate tests to the new design
* better warning description in modeling_utils
* better comments and error messages

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
parent c749a543fa
commit fc38d4c86f
@ -106,6 +106,8 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
    language = None
    while language not in available_languages:
        language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

    model.config.lang_id = model.config.lang2id[language]
    # kwargs["language"] = tokenizer.lang2id[language]

    # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
@ -119,12 +121,12 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):

def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text, {}
    return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text, {}
    return prompt_text


PREPROCESSING_FUNCTIONS = {
@ -183,6 +185,7 @@ def main():

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@ -210,28 +213,48 @@ def main():
    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
    if requires_preprocessing:
        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
        prompt_text = prepare_input(args, model, tokenizer, prompt_text)
        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
        encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")
    else:
        encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(args.device)

    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=args.length,
        max_length=args.length + len(encoded_prompt[0]),
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
        do_sample=True,
        num_return_sequences=args.num_return_sequences,
    )

    # Batch size == 1. to add more examples please use num_return_sequences > 1
    generated_sequence = output_sequences[0].tolist()
    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        print(text)
        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        return text
        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
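For reference, a condensed sketch of what the reworked main() now does with each returned sequence: decode, cut at the stop token, and re-attach the original prompt. The function and parameter names below are illustrative stand-ins for the script's locals, not part of the commit itself.

def postprocess(tokenizer, output_sequences, encoded_prompt, prompt_text, stop_token=None):
    # Mirrors the new loop above: decode, truncate at the stop token, re-attach the prompt.
    generated_sequences = []
    decoded_prompt = tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
    for sequence in output_sequences:
        text = tokenizer.decode(sequence.tolist(), clean_up_tokenization_spaces=True)
        text = text[: text.find(stop_token) if stop_token else None]
        generated_sequences.append(prompt_text + text[len(decoded_prompt):])
    return generated_sequences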
@ -97,4 +97,4 @@ class ExamplesTests(unittest.TestCase):
        model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            result = run_generation.main()
            self.assertGreaterEqual(len(result), 10)
            self.assertGreaterEqual(len(result[0]), 10)
@ -75,9 +75,9 @@ class PretrainedConfig(object):
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.bos_token_id = kwargs.pop("bos_token_id", 0)
        self.pad_token_id = kwargs.pop("pad_token_id", 0)
        self.eos_token_ids = kwargs.pop("eos_token_ids", 0)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_ids = kwargs.pop("eos_token_ids", None)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
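With the special-token defaults moving from a silent 0 to None, callers are expected to hand real token ids to generate() (or rely on the fallback shown further down). A minimal usage sketch along the lines of the docstring examples in this commit; the model name and reuse of EOS as pad are illustrative only, not prescribed by the change.

import torch
from transformers import AutoModelWithLMHead, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelWithLMHead.from_pretrained("distilgpt2")

input_ids = torch.tensor(tokenizer.encode("The dog")).unsqueeze(0)  # shape (1, sequence_length)
output = model.generate(
    input_ids=input_ids,
    max_length=20,
    bos_token_id=tokenizer.bos_token_id,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; pass EOS explicitly instead of relying on 0
    eos_token_ids=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))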
@ -645,33 +645,39 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Default to 1.

        Return:

            output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id) # do greedy decoding without beam search
            outputs = model.generate(max_length=40, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, do_sample=False) # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
            outputs = model.generate(input_ids=input_ids, do_sample=True, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
            for i in range(3): # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[0][i], skip_special_tokens=True)))
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using greedy beam search decoding (3 beams)
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.pad_token_id, eos_token_ids=tokenizer.eos_token_id, num_return_sequences=3) # 3 generate sequences using by sampling
            for i in range(3): # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
            input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences using using greedy search
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

        """
@ -712,10 +718,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, "`bos_token_id` should be a positive integer."
        assert isinstance(pad_token_id, int) and pad_token_id >= 0, "`pad_token_id` should be a positive integer."
        assert isinstance(eos_token_ids, (list, tuple)) and (
            e >= 0 for e in eos_token_ids
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_ids is None) or (
            isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
        ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
        assert length_penalty > 0, "`length_penalty` should be strictely positive."
        assert (
@ -723,12 +733,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        ), "`num_return_sequences` should be a strictely positive integer."

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = torch.full(
                (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device
            )
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        if pad_token_id is None and eos_token_ids is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
            )
            pad_token_id = eos_token_ids[0]

        # current position and vocab size
        cur_len = input_ids.shape[1]
        vocab_size = self.config.vocab_size
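The warning above encodes a simple fallback rule. A self-contained sketch of just that rule, under the same assumptions as the diff (the helper name is made up for illustration):

import logging

logger = logging.getLogger(__name__)

def resolve_pad_token_id(pad_token_id, eos_token_ids):
    # Same rule as in generate(): reuse the first EOS id for padding when no pad id is set.
    if pad_token_id is None and eos_token_ids is not None:
        logger.warning(
            "Setting `pad_token_id` to %s (first `eos_token_id`) to generate sequence", eos_token_ids[0]
        )
        pad_token_id = eos_token_ids[0]
    return pad_token_id

assert resolve_pad_token_id(None, [50256]) == 50256
assert resolve_pad_token_id(0, [50256]) == 0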
@ -775,8 +795,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            effective_batch_size,
        )

        if num_return_sequences != 1:
            output = output.view(batch_size, num_return_sequences, -1)
        return output

    def _generate_no_beam_search(
@ -798,6 +816,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
        """
        # current position / max lengths / length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = None
@ -833,21 +852,41 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
            if eos_token_ids is not None:
                # pad finished sentences if eos_token_ids exist
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

            if eos_token_ids is not None:
                for eos_token_id in eos_token_ids:
                    unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long())
                    eos_in_sents = tokens_to_add == eos_token_id
                    # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                    is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                    sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
                    # unfinished_sents is set to zero if eos in sentence
                    unfinished_sents.mul_((~eos_in_sents).long())

            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximul length
            if unfinished_sents.max() == 0:
                break

        # add eos_token_ids to unfinished sentences
        if cur_len == max_length:
            input_ids[:, -1].masked_fill_(unfinished_sents.to(dtype=torch.bool), eos_token_ids[0])
        # if there are different sentences lengths in the batch, some batches have to be padded
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
            # finished sents are filled with pad_token
            decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
        else:
            decoded = input_ids

        return input_ids
        for hypo_idx, hypo in enumerate(input_ids):
            decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

        return decoded

    def _generate_beam_search(
        self,
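A toy, runnable illustration of the bookkeeping introduced above, with arbitrarily chosen values: finished sentences keep receiving the pad token and sent_lengths records the step at which each sentence emitted EOS.

import torch

pad_token_id, eos_token_id, cur_len = 0, 2, 4
next_token = torch.tensor([5, 2, 7])        # the middle sentence just produced EOS
unfinished_sents = torch.tensor([1, 1, 0])  # the last sentence finished earlier
sent_lengths = torch.tensor([10, 10, 3])    # max_length = 10

tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
eos_in_sents = tokens_to_add == eos_token_id
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
unfinished_sents.mul_((~eos_in_sents).long())

print(tokens_to_add)     # tensor([5, 2, 0]) - the already-finished sentence is padded
print(sent_lengths)      # tensor([10, 5, 3])
print(unfinished_sents)  # tensor([1, 0, 0])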
@ -941,11 +980,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            next_batch_beam = []

            # for each sentence
            for batch_ex in range(batch_size):
            for batch_idx in range(batch_size):

                # if we are done with this sentence
                done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
                if done[batch_ex]:
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                    next_scores[batch_idx].max().item()
                )
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
                        eos_token_ids is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue
@ -953,30 +1000,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                next_sent_beam = []

                # next words for this sentence
                for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):
                for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):

                    # get beam and word IDs
                    beam_id = idx // vocab_size
                    word_id = idx % vocab_size

                    # end of sentence, or next word
                    if word_id.item() in eos_token_ids or cur_len + 1 == max_length:
                        generated_hyps[batch_ex].add(
                            input_ids[batch_ex * num_beams + beam_id, :cur_len].clone(), score.item()
                    # add to generated hypotheses if end of sentence or last iteration
                    if eos_token_ids is not None and word_id.item() in eos_token_ids:
                        generated_hyps[batch_idx].add(
                            input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
                        )
                    else:
                        next_sent_beam.append((score, word_id, batch_ex * num_beams + beam_id))
                        # add next predicted word if it is not eos_token
                        next_sent_beam.append((score, word_id, batch_idx * num_beams + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == num_beams:
                        break

                # update next beam content
                assert len(next_sent_beam) == 0 if cur_len + 1 == max_length else num_beams
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, pad_token_id, 0)] * num_beams  # pad the batch
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_ex + 1)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
@ -1008,29 +1054,42 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            if all(done):
                break

        # visualize hypotheses
        # print([len(x) for x in generated_hyps], cur_len)
        # globals().update( locals() );
        # !import code; code.interact(local=vars())
        # for ii in range(batch_size):
        # for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
        # print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
        # print("")
        for batch_idx in range(batch_size):
            # Add all open beam hypothesis to generated_hyps
            if not done[batch_idx]:
                for idx, score in zip(next_words[batch_idx], next_scores[batch_idx]):

                    # get beam and word IDs
                    beam_id = idx // vocab_size
                    word_id = idx % vocab_size
                    generated_hyps[batch_idx].add(
                        input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
                    )

        # select the best hypotheses
        tgt_len = input_ids.new(batch_size)
        sent_lengths = input_ids.new(batch_size)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best_hyp = max(hypotheses.beams, key=lambda x: x[0])[1]
            sent_lengths[i] = len(best_hyp)
            best.append(best_hyp)

        # generate target batch
        decoded = input_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id)
        # shorter batches are filled with pad_token
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`Pad_token_id` has to be defined"
            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
            decoded = input_ids.new(batch_size, sent_max_len).fill_(pad_token_id)

            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                decoded[i, : tgt_len[i] - 1] = hypo
                decoded[i, tgt_len[i] - 1] = eos_token_ids[0]
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
                    decoded[i, sent_lengths[i]] = eos_token_ids[0]
        else:
            # none of the hypotheses have an eos_token
            assert (len(hypo) == max_length for hypo in best)
            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

        return decoded
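The padding and EOS handling at the end of beam search can be exercised in isolation. A small sketch with made-up hypotheses and token ids, following the same steps as the diff:

import torch

pad_token_id, eos_token_id, max_length = 0, 2, 6
best = [torch.tensor([11, 12, 13]), torch.tensor([21, 22, 23, 24, 25])]
sent_lengths = torch.tensor([len(hypo) for hypo in best])

if sent_lengths.min().item() != sent_lengths.max().item():
    # shorter hypotheses are padded and closed with the first EOS id, as in the diff
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    decoded = best[0].new(len(best), sent_max_len).fill_(pad_token_id)
    for i, hypo in enumerate(best):
        decoded[i, : sent_lengths[i]] = hypo
        if sent_lengths[i] < max_length:
            decoded[i, sent_lengths[i]] = eos_token_id
else:
    decoded = torch.stack(best)

print(decoded)
# tensor([[11, 12, 13,  2,  0,  0],
#         [21, 22, 23, 24, 25,  2]])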
@ -1071,33 +1130,33 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")


class BeamHypotheses(object):
    def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.n_hyp = n_hyp
        self.hyp = []
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.n_hyp or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.n_hyp:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)
@ -1107,7 +1166,7 @@ class BeamHypotheses(object):
        If there are enough hypotheses and that none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.n_hyp:
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import logging
import os.path
@ -53,6 +52,7 @@ class ModelTesterMixin:

    model_tester = None
    all_model_classes = ()
    all_generative_model_classes = ()
    test_torchscript = True
    test_pruning = True
    test_resize_embeddings = True
@ -595,6 +595,47 @@ class ModelTesterMixin:
        with torch.no_grad():
            model(**inputs_dict)

    def test_lm_head_model_random_generate(self):

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict.get(
            "input_ids", None
        ) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.

        for model_class in self.all_generative_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            if config.bos_token_id is None:
                with self.assertRaises(AssertionError):
                    model.generate(max_length=5)
                # batch_size = 1
                self._check_generated_tokens(model.generate(input_ids))
                # batch_size = 1, num_beams > 1
                self._check_generated_tokens(model.generate(input_ids, num_beams=3))
            else:
                # batch_size = 1
                self._check_generated_tokens(model.generate(max_length=5))
                # batch_size = 1, num_beams > 1
                self._check_generated_tokens(model.generate(max_length=5, num_beams=3))

            # batch_size > 1, sample
            self._check_generated_tokens(model.generate(input_ids, num_return_sequences=3))
            # batch_size > 1, greedy
            self._check_generated_tokens(model.generate(input_ids, do_sample=False, num_return_sequences=3))
            # batch_size > 1, num_beams > 1, sample
            self._check_generated_tokens(model.generate(input_ids, num_beams=3, num_return_sequences=3,))
            # batch_size > 1, num_beams > 1, greedy
            self._check_generated_tokens(
                model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
            )

    def _check_generated_tokens(self, output_ids):
        for token_id in output_ids[0].tolist():
            self.assertGreaterEqual(token_id, 0)
            self.assertLess(token_id, self.model_tester.vocab_size)


global_rng = random.Random()
@ -30,6 +30,7 @@ if is_torch_available():
class CTRLModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
    all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
    test_pruning = False
    test_torchscript = False
    test_resize_embeddings = False
@ -37,6 +37,9 @@ if is_torch_available():
class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
    all_generative_model_classes = (
        (GPT2LMHeadModel,) if is_torch_available() else ()
    ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly

    class GPT2ModelTester(object):
        def __init__(
@ -88,6 +91,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope
            self.bos_token_id = vocab_size - 1
            self.eos_token_id = vocab_size - 1

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@ -122,9 +127,11 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
                # hidden_dropout_prob=self.hidden_dropout_prob,
                # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                n_positions=self.max_position_embeddings,
                n_ctx=self.max_position_embeddings
                n_ctx=self.max_position_embeddings,
                # type_vocab_size=self.type_vocab_size,
                # initializer_range=self.initializer_range
                bos_token_id=self.bos_token_id,
                eos_token_ids=self.eos_token_id,
            )

            head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@ -39,6 +39,9 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) if is_torch_available() else ()
    )
    all_generative_model_classes = (
        (OpenAIGPTLMHeadModel,) if is_torch_available() else ()
    ) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly

    class OpenAIGPTModelTester(object):
        def __init__(
@ -34,6 +34,7 @@ if is_torch_available():
class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
    all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
    test_pruning = False
    test_torchscript = False
    test_resize_embeddings = False
@ -59,6 +60,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
            num_hidden_layers=5,
            scope=None,
            seed=1,
            eos_token_id=0,
        ):
            self.parent = parent
            self.batch_size = batch_size
@ -79,6 +81,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
            self.num_hidden_layers = num_hidden_layers
            self.scope = scope
            self.seed = seed
            self.eos_token_id = eos_token_id

        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@ -100,6 +103,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
                d_inner=self.d_inner,
                div_val=self.div_val,
                n_layer=self.num_hidden_layers,
                eos_token_ids=self.eos_token_id,
            )

            return (config, input_ids_1, input_ids_2, lm_labels)
@ -49,6 +49,9 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (XLMWithLMHeadModel,) if is_torch_available() else ()
    ) # TODO (PVP): Check other models whether language generation is also applicable

    class XLMModelTester(object):
        def __init__(
@ -81,6 +84,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
            summary_type="last",
            use_proj=True,
            scope=None,
            bos_token_id=0,
        ):
            self.parent = parent
            self.batch_size = batch_size
@ -111,6 +115,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope
            self.bos_token_id = bos_token_id

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@ -151,6 +156,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
                initializer_range=self.initializer_range,
                summary_type=self.summary_type,
                use_proj=self.use_proj,
                bos_token_id=self.bos_token_id,
            )

            return (
@ -52,6 +52,9 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (XLNetLMHeadModel,) if is_torch_available() else ()
    ) # TODO (PVP): Check other models whether language generation is also applicable
    test_pruning = False

    class XLNetModelTester(object):
@ -78,6 +81,9 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
            initializer_range=0.05,
            seed=1,
            type_vocab_size=2,
            bos_token_id=1,
            eos_token_id=2,
            pad_token_id=5,
        ):
            self.parent = parent
            self.batch_size = batch_size
@ -101,6 +107,9 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
            self.seed = seed
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.bos_token_id = bos_token_id
            self.pad_token_id = pad_token_id
            self.eos_token_id = eos_token_id

        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@ -142,6 +151,9 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
                bi_data=self.bi_data,
                initializer_range=self.initializer_range,
                num_labels=self.type_sequence_label_size,
                bos_token_id=self.bos_token_id,
                pad_token_id=self.pad_token_id,
                eos_token_id=self.eos_token_id,
            )

            return (