transformers/tests/test_generation_utils.py
Patrick von Platen cb966e640b
[Generate Test] fix greedy generate test (#8293)
* fix greedy generate test

* delet ipdb
2020-11-04 15:44:36 +01:00

514 lines
20 KiB
Python

# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
if is_torch_available():
import torch
from transformers import top_k_top_p_filtering
from transformers.generation_beam_search import BeamSearchScorer
from transformers.generation_logits_process import (
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
class GenerationTesterMixin:
model_tester = None
all_generative_model_classes = ()
def _get_input_ids_and_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
attention_mask = torch.ones_like(input_ids)
# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
attention_mask = attention_mask[:max_batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length
@staticmethod
def _get_logits_processor_and_kwargs(input_length, eos_token_id):
process_kwargs = {
"min_length": input_length + 1,
"bad_words_ids": [[1, 0]],
"no_repeat_ngram_size": 2,
"repetition_penalty": 1.2,
}
logits_processor = LogitsProcessorList(
(
[
MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
]
if eos_token_id is not None
else []
)
+ [
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
]
)
return process_kwargs, logits_processor
@staticmethod
def _get_warper_and_kwargs(num_beams):
warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
logits_warper = LogitsProcessorList(
[
TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
TemperatureLogitsWarper(warp_kwargs["temperature"]),
]
)
return warp_kwargs, logits_warper
@staticmethod
def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
beam_kwargs = {
"early_stopping": False,
"length_penalty": 2.0,
"num_beams": 2,
"num_return_sequences": num_return_sequences,
}
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=beam_kwargs["num_beams"],
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
do_early_stopping=beam_kwargs["early_stopping"],
num_beam_hyps_to_keep=num_return_sequences,
)
return beam_kwargs, beam_scorer
@staticmethod
def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1):
encoder = model.get_encoder()
encoder_outputs = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
attention_mask = None
return encoder_outputs, input_ids, attention_mask
def test_greedy_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
)
model = model_class(config).to(torch_device)
model.eval()
# check `generate()` and `greedy_search()` are equal
kwargs = {}
if model.config.is_encoder_decoder:
max_length = 4
output_ids_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
num_beams=1,
max_length=max_length,
**logits_process_kwargs,
)
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
with torch.no_grad():
output_ids_greedy = model.greedy_search(
input_ids,
max_length=max_length,
attention_mask=attention_mask,
logits_processor=logits_processor,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_greedy.tolist())
def test_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
)
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
model = model_class(config).to(torch_device)
model.eval()
# check `generate()` and `sample()` are equal
if model.config.is_encoder_decoder:
max_length = 4
torch.manual_seed(0)
output_ids_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_length=max_length,
attention_mask=attention_mask,
**logits_warper_kwargs,
**process_kwargs,
)
torch.manual_seed(0)
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model, input_ids, attention_mask
)
kwargs["encoder_outputs"] = encoder_outputs
else:
attention_mask_clone = attention_mask
input_ids_clone = input_ids
with torch.no_grad():
output_ids_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist())
# check `generate()` and `sample()` yield equal results for `num_return_sequences`
num_return_sequences = 3
if model.config.is_encoder_decoder:
max_length = 4
torch.manual_seed(0)
output_ids_generate = model.generate(
input_ids,
do_sample=True,
num_beams=1,
max_length=max_length,
num_return_sequences=num_return_sequences,
attention_mask=attention_mask,
**logits_warper_kwargs,
**process_kwargs,
)
torch.manual_seed(0)
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model, input_ids, attention_mask, num_interleave=num_return_sequences
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
with torch.no_grad():
output_ids_sample = model.sample(
input_ids_clone,
attention_mask=attention_mask_clone,
max_length=max_length,
logits_processor=logits_processor,
logits_warper=logits_warper,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist())
def test_beam_search_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
input_ids.shape[-1], config.eos_token_id
)
model = model_class(config).to(torch_device)
model.eval()
# check `generate()` and `beam_search()` are equal
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
output_ids_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
**beam_kwargs,
**logits_process_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
with torch.no_grad():
output_ids_beam_search = model.beam_search(
input_ids_clone,
beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
# check `generate()` and `beam_search()` are equal for `num_return_sequences`
num_return_sequences = 2
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
)
output_ids_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=False,
max_length=max_length,
**beam_kwargs,
**logits_process_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
)
kwargs["encoder_outputs"] = encoder_outputs
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
else:
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
with torch.no_grad():
output_ids_beam_search = model.beam_search(
input_ids_clone,
beam_scorer,
max_length=max_length,
attention_mask=attention_mask_clone,
logits_processor=logits_processor,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
def test_beam_sample_generate(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
model = model_class(config).to(torch_device)
model.eval()
# check `generate()` and `beam_search()` are equal
# change `num_return_sequences = 2` but not for `beam_scorer`
num_return_sequences = 2
if model.config.is_encoder_decoder:
max_length = 4
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
input_ids.shape[0] * num_return_sequences, max_length
)
beam_kwargs["num_return_sequences"] = num_return_sequences
torch.manual_seed(0)
output_ids_generate = model.generate(
input_ids,
attention_mask=attention_mask,
do_sample=True,
max_length=max_length,
**beam_kwargs,
**logits_warper_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams * num_return_sequences
)
kwargs["encoder_outputs"] = encoder_outputs
else:
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
torch.manual_seed(0)
with torch.no_grad():
output_ids_beam_sample = model.beam_sample(
input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
beam_scorer,
max_length=max_length,
attention_mask=attention_mask,
logits_warper=logits_warper,
**kwargs,
)
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_sample.tolist())
def test_generate_without_input_ids(self):
config, _, _, max_length = self._get_input_ids_and_config()
# if no bos token id => cannot generate from None
if config.bos_token_id is None:
return
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(
do_sample=False,
max_length=max_length,
)
self.assertIsNotNone(output_ids_generate)
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
# tests whether the top_k_top_p function behaves as expected
def test_top_k_top_p_filtering(self):
logits = torch.tensor(
[
[
8.2220991, # 3rd highest value; idx. 0
-0.5620044,
5.23229752,
4.0386393,
-6.8798378,
-0.54785802,
-3.2012153,
2.92777176,
1.88171953,
7.35341276,
8.43207833, # 2nd highest value; idx. 10
-9.85711836,
-5.96209236,
-1.13039161,
-7.1115294,
-0.8369633,
-5.3186408,
7.06427407,
0.81369344,
-0.82023817,
-5.9179796,
0.58813443,
-6.99778438,
4.71551189,
-0.18771637,
7.44020759, # 4th highest value; idx. 25
9.38450987, # 1st highest value; idx. 26
2.12662941,
-9.32562038,
2.35652522,
], # cummulative prob of 4 highest values <= 0.6
[
0.58425518,
4.53139238,
-5.57510464,
-6.28030699,
-7.19529503,
-4.02122551,
1.39337037,
-6.06707057,
1.59480517,
-9.643119,
0.03907799,
0.67231762,
-8.88206726,
6.27115922, # 4th highest value; idx. 13
2.28520723,
4.82767506,
4.30421368,
8.8275313, # 2nd highest value; idx. 17
5.44029958,
-4.4735794,
7.38579536, # 3rd highest value; idx. 20
-2.91051663,
2.61946077,
-2.5674762,
-9.48959302,
-4.02922645,
-1.35416918,
9.67702323, # 1st highest value; idx. 27
-5.89478553,
1.85370467,
], # cummulative prob of 4 highest values <= 0.6
],
dtype=torch.float,
device=torch_device,
)
non_inf_expected_idx = torch.tensor(
[[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
dtype=torch.long,
device=torch_device,
) # expected non filtered idx as noted above
non_inf_expected_output = torch.tensor(
[
8.2221,
8.4321,
7.4402,
9.3845,
6.2712,
8.8275,
7.3858,
9.6770,
], # expected non filtered values as noted above
dtype=torch.float,
device=torch_device,
)
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
non_inf_output = output[output != -float("inf")].to(device=torch_device)
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))