transformers/tests/test_generation_utils.py
Simon Brandeis c89f1bc92e
Add flags to return scores, hidden states and / or attention weights in GenerationMixin (#9150)
* Define new output dataclasses for greedy generation

* Add output_[...] flags in greedy generation methods

Added output_attentions, output_hidden_states, output_scores flags in
generate and greedy_search methods in GenerationMixin.

* [WIP] Implement logic and tests for output flags in generation

* Update GreedySearchOutput classes & docstring

* Implement greedy search output accumulation logic

Update greedy_search unittests

Fix generate method return value docstring

Properly init flags with the default config

* Update configuration to add output_scores flag

* Fix test_generation_utils

Sort imports and fix isinstance tests for GreedySearchOutputs

* Fix typo in generation_utils

* Add return_dict_in_generate for backwards compatibility

* Add return_dict_in_generate flag in config

* Fix typo in configuration

* Fix handling of attentions and hidden_states flags

* Make style & quality

* first attempt attentions

* some corrections

* improve tests

* special models requires special test

* disable xlm test for now

* clean tests

* fix for tf

* isort

* Add output dataclasses for other generation methods

* Add logic to return dict in sample generation

* Complete test for sample generation

- Pass output_attentions and output_hidden_states flags to encoder in
encoder-decoder models
- Fix import statements order in test_generation_utils file

* Add logic to return dict in sample generation

- Refactor tests to avoid using self.assertTrue, which provides
scarce information when the test fails
- Add tests for the three beam_search methods: vanilla, sample and
grouped

* Style doc

* Fix copy-paste error in generation tests

* Rename logits to scores and refactor

* Refactor group_beam_search for consistency

* make style

* add sequences_scores

* fix all tests

* add docs

* fix beam search finalize test

* correct docstring

* clean some files

* Made suggested changes to the documentation

* Style doc ?

* Style doc using the Python util

* Update src/transformers/generation_utils.py

* fix empty lines

* fix all test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-01-06 17:11:42 +01:00

1206 lines
50 KiB
Python

# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
if is_torch_available():
import torch
from transformers import BartForConditionalGeneration, BartTokenizer, top_k_top_p_filtering
from transformers.generation_beam_search import BeamSearchScorer
from transformers.generation_logits_process import (
HammingDiversityLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
from transformers.generation_utils import (
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
)
class GenerationTesterMixin:
model_tester = None
all_generative_model_classes = ()
def _get_input_ids_and_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"]
attention_mask = torch.ones_like(input_ids)
# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
attention_mask = attention_mask[:max_batch_size, :sequence_length]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length
@staticmethod
def _get_logits_processor_and_kwargs(input_length, eos_token_id, diversity_penalty=None):
    """Create matching `generate()` kwargs and a `LogitsProcessorList`.

    Args:
        input_length: length of the prompt; `min_length` is set to one more than this.
        eos_token_id: EOS id handed to the processors; no `MinLengthLogitsProcessor` is added when `None`.
        diversity_penalty: when not `None`, a `HammingDiversityLogitsProcessor` is put first in the list.

    Returns:
        `(process_kwargs, logits_processor)`.
    """
    process_kwargs = {
        "min_length": input_length + 1,
        "bad_words_ids": [[1, 0]],
        "no_repeat_ngram_size": 2,
        "repetition_penalty": 1.2,
    }

    processors = []
    if diversity_penalty is not None:
        processors.append(HammingDiversityLogitsProcessor(diversity_penalty, num_beams=2, num_beam_groups=2))
    if eos_token_id is not None:
        processors.append(MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id))
    # these three processors are always present
    processors.append(NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id))
    processors.append(NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]))
    processors.append(RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]))

    return process_kwargs, LogitsProcessorList(processors)
@staticmethod
def _get_warper_and_kwargs(num_beams):
    """Create matching `generate()` sampling kwargs and a list of logits warpers.

    Args:
        num_beams: the top-k / top-p warpers keep at least 2 tokens when this is greater than 1, else 1.

    Returns:
        `(warp_kwargs, logits_warper)`.
    """
    warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
    min_tokens_to_keep = 2 if num_beams > 1 else 1

    temperature_warper = TemperatureLogitsWarper(warp_kwargs["temperature"])
    top_k_warper = TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=min_tokens_to_keep)
    top_p_warper = TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=min_tokens_to_keep)

    logits_warper = LogitsProcessorList([temperature_warper, top_k_warper, top_p_warper])
    return warp_kwargs, logits_warper
@staticmethod
def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
    """Create matching `generate()` beam kwargs and a `BeamSearchScorer`.

    Args:
        batch_size: batch size passed to the scorer.
        max_length: maximum length passed to the scorer.
        num_return_sequences: used both as the `generate()` kwarg and as the scorer's `num_beam_hyps_to_keep`.

    Returns:
        `(beam_kwargs, beam_scorer)`.
    """
    num_beams = 2
    length_penalty = 2.0
    early_stopping = False

    beam_kwargs = {
        "early_stopping": early_stopping,
        "length_penalty": length_penalty,
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
    }
    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        max_length=max_length,
        num_beams=num_beams,
        device=torch_device,
        length_penalty=length_penalty,
        do_early_stopping=early_stopping,
        num_beam_hyps_to_keep=num_return_sequences,
    )
    return beam_kwargs, beam_scorer
@staticmethod
def _get_diverse_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
    """Create matching `generate()` kwargs and a `BeamSearchScorer` for group (diverse) beam search.

    Args:
        batch_size: batch size passed to the scorer.
        max_length: maximum length passed to the scorer.
        num_return_sequences: used both as the `generate()` kwarg and as the scorer's `num_beam_hyps_to_keep`.

    Returns:
        `(beam_kwargs, beam_scorer)`.
    """
    num_beams = 2
    num_beam_groups = 2  # one beam per group
    length_penalty = 2.0
    early_stopping = False

    beam_kwargs = {
        "early_stopping": early_stopping,
        "length_penalty": length_penalty,
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "num_beam_groups": num_beam_groups,
        "diversity_penalty": 2.0,
    }
    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        max_length=max_length,
        num_beams=num_beams,
        device=torch_device,
        length_penalty=length_penalty,
        do_early_stopping=early_stopping,
        num_beam_hyps_to_keep=num_return_sequences,
        num_beam_groups=num_beam_groups,
    )
    return beam_kwargs, beam_scorer
@staticmethod
def _get_encoder_outputs(
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
):
encoder = model.get_encoder()
encoder_outputs = encoder(
input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
attention_mask = None
return encoder_outputs, input_ids, attention_mask
def _greedy_generate(
    self,
    model,
    input_ids,
    attention_mask,
    max_length,
    output_scores=False,
    output_attentions=False,
    output_hidden_states=False,
    return_dict_in_generate=False,
):
    """Decode once through `generate()` and once through `greedy_search()` with equivalent settings.

    `generate()` is configured through keyword arguments, `greedy_search()` through the matching
    `LogitsProcessorList`. For encoder-decoder models `max_length` is overridden to 4 and
    `greedy_search()` receives precomputed encoder outputs.

    Returns:
        `(output_greedy, output_generate)`.
    """
    processor_kwargs, processor_list = self._get_logits_processor_and_kwargs(
        input_ids.shape[-1], model.config.eos_token_id
    )

    # flags forwarded unchanged to both entry points
    output_flags = {
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "output_scores": output_scores,
        "return_dict_in_generate": return_dict_in_generate,
    }

    if model.config.is_encoder_decoder:
        max_length = 4

    output_generate = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=False,
        num_beams=1,
        max_length=max_length,
        **output_flags,
        **processor_kwargs,
    )

    search_kwargs = {}
    if model.config.is_encoder_decoder:
        # from here on `input_ids` / `attention_mask` are the decoder-side inputs
        encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
            model,
            input_ids,
            attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        search_kwargs["encoder_outputs"] = encoder_outputs

    with torch.no_grad():
        output_greedy = model.greedy_search(
            input_ids,
            max_length=max_length,
            attention_mask=attention_mask,
            logits_processor=processor_list,
            **output_flags,
            **search_kwargs,
        )

    return output_greedy, output_generate
def _sample_generate(
    self,
    model,
    input_ids,
    attention_mask,
    max_length,
    num_return_sequences,
    logits_processor,
    logits_warper,
    logits_warper_kwargs,
    process_kwargs,
    output_scores=False,
    output_attentions=False,
    output_hidden_states=False,
    return_dict_in_generate=False,
):
    """Sample once through `generate()` and once through `sample()` with equivalent settings.

    The torch RNG is re-seeded with 0 before each of the two runs. `sample()` is given inputs that
    are already repeated `num_return_sequences` times along the batch dimension.

    Returns:
        `(output_sample, output_generate)`.
    """
    # flags forwarded unchanged to both entry points
    output_flags = {
        "output_scores": output_scores,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "return_dict_in_generate": return_dict_in_generate,
    }

    torch.manual_seed(0)
    output_generate = model.generate(
        input_ids,
        do_sample=True,
        num_beams=1,
        max_length=max_length,
        num_return_sequences=num_return_sequences,
        attention_mask=attention_mask,
        **output_flags,
        **logits_warper_kwargs,
        **process_kwargs,
    )

    torch.manual_seed(0)
    sample_kwargs = {}
    if model.config.is_encoder_decoder:
        encoder_outputs, decoder_input_ids, sample_attention_mask = self._get_encoder_outputs(
            model,
            input_ids,
            attention_mask,
            num_interleave=num_return_sequences,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sample_kwargs["encoder_outputs"] = encoder_outputs
        sample_input_ids = decoder_input_ids.repeat_interleave(num_return_sequences, dim=0)
    else:
        sample_attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0)
        sample_input_ids = input_ids.repeat_interleave(num_return_sequences, dim=0)

    with torch.no_grad():
        output_sample = model.sample(
            sample_input_ids,
            attention_mask=sample_attention_mask,
            max_length=max_length,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            **output_flags,
            **sample_kwargs,
        )

    return output_sample, output_generate
def _beam_search_generate(
    self,
    model,
    input_ids,
    attention_mask,
    max_length,
    beam_scorer,
    beam_kwargs,
    logits_processor,
    logits_process_kwargs,
    output_scores=False,
    output_attentions=False,
    output_hidden_states=False,
    return_dict_in_generate=False,
):
    """Decode once through `generate()` and once through `beam_search()` with equivalent settings.

    Returns:
        `(output_generate, output_beam_search)`.
    """
    # flags forwarded unchanged to both entry points
    output_flags = {
        "output_scores": output_scores,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "return_dict_in_generate": return_dict_in_generate,
    }

    output_generate = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=False,
        max_length=max_length,
        **output_flags,
        **beam_kwargs,
        **logits_process_kwargs,
    )

    # beam_search does not automatically interleave `batch_size` dim for `num_beams`
    num_beams = beam_scorer.num_beams
    search_kwargs = {}
    if model.config.is_encoder_decoder:
        encoder_outputs, decoder_input_ids, beam_attention_mask = self._get_encoder_outputs(
            model,
            input_ids,
            attention_mask,
            num_interleave=num_beams,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        search_kwargs["encoder_outputs"] = encoder_outputs
        beam_input_ids = decoder_input_ids.repeat_interleave(num_beams, dim=0)
    else:
        beam_attention_mask = attention_mask.repeat_interleave(num_beams, dim=0)
        beam_input_ids = input_ids.repeat_interleave(num_beams, dim=0)

    with torch.no_grad():
        output_beam_search = model.beam_search(
            beam_input_ids,
            beam_scorer,
            max_length=max_length,
            attention_mask=beam_attention_mask,
            logits_processor=logits_processor,
            **output_flags,
            **search_kwargs,
        )

    return output_generate, output_beam_search
def _beam_sample_generate(
    self,
    model,
    input_ids,
    attention_mask,
    max_length,
    num_return_sequences,
    beam_scorer,
    beam_kwargs,
    logits_warper,
    logits_warper_kwargs,
    output_scores=False,
    output_attentions=False,
    output_hidden_states=False,
    return_dict_in_generate=False,
):
    """Decode once through `generate()` and once through `beam_sample()` with equivalent settings.

    The torch RNG is re-seeded with 0 right before each of the two decoding calls.

    Returns:
        `(output_generate, output_beam_sample)`.
    """
    # flags forwarded unchanged to both entry points
    output_flags = {
        "output_scores": output_scores,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "return_dict_in_generate": return_dict_in_generate,
    }

    torch.manual_seed(0)
    output_generate = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=max_length,
        **output_flags,
        **beam_kwargs,
        **logits_warper_kwargs,
    )

    # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
    expansion = beam_scorer.num_beams * num_return_sequences
    sample_kwargs = {}
    if model.config.is_encoder_decoder:
        # from here on `input_ids` / `attention_mask` are the decoder-side inputs
        encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
            model,
            input_ids,
            attention_mask,
            num_interleave=expansion,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sample_kwargs["encoder_outputs"] = encoder_outputs
    else:
        attention_mask = attention_mask.repeat_interleave(expansion, dim=0)

    torch.manual_seed(0)
    with torch.no_grad():
        expanded_input_ids = input_ids.repeat_interleave(expansion, dim=0)
        output_beam_sample = model.beam_sample(
            expanded_input_ids,
            beam_scorer,
            max_length=max_length,
            attention_mask=attention_mask,
            logits_warper=logits_warper,
            **output_flags,
            **sample_kwargs,
        )

    return output_generate, output_beam_sample
def _group_beam_search_generate(
    self,
    model,
    input_ids,
    attention_mask,
    max_length,
    beam_scorer,
    beam_kwargs,
    logits_processor,
    logits_process_kwargs,
    output_scores=False,
    output_attentions=False,
    output_hidden_states=False,
    return_dict_in_generate=False,
):
    """Decode once through `generate()` and once through `group_beam_search()` with equivalent settings.

    Returns:
        `(output_generate, output_group_beam_search)`.
    """
    # flags forwarded unchanged to both entry points
    output_flags = {
        "output_scores": output_scores,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "return_dict_in_generate": return_dict_in_generate,
    }

    output_generate = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=False,
        max_length=max_length,
        **output_flags,
        **beam_kwargs,
        **logits_process_kwargs,
    )

    # group_beam_search does not automatically interleave `batch_size` dim for `num_beams`
    num_beams = beam_scorer.num_beams
    search_kwargs = {}
    if model.config.is_encoder_decoder:
        encoder_outputs, decoder_input_ids, group_attention_mask = self._get_encoder_outputs(
            model,
            input_ids,
            attention_mask,
            num_interleave=num_beams,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        search_kwargs["encoder_outputs"] = encoder_outputs
        group_input_ids = decoder_input_ids.repeat_interleave(num_beams, dim=0)
    else:
        group_attention_mask = attention_mask.repeat_interleave(num_beams, dim=0)
        group_input_ids = input_ids.repeat_interleave(num_beams, dim=0)

    with torch.no_grad():
        output_group_beam_search = model.group_beam_search(
            group_input_ids,
            beam_scorer,
            max_length=max_length,
            attention_mask=group_attention_mask,
            logits_processor=logits_processor,
            **output_flags,
            **search_kwargs,
        )

    return output_generate, output_group_beam_search
def test_greedy_generate(self):
    """`generate()` and `greedy_search()` must produce the same token ids (plain tensor outputs)."""
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

        # test old generation output for backwards compatibility
        model = model_class(config).to(torch_device).eval()
        greedy_ids, generate_ids = self._greedy_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
        )

        self.assertListEqual(greedy_ids.tolist(), generate_ids.tolist())
def test_greedy_generate_dict_outputs(self):
    """With all output flags on and the cache off, both greedy paths return matching dict outputs."""
    dict_output_flags = {
        "output_scores": True,
        "output_hidden_states": True,
        "output_attentions": True,
        "return_dict_in_generate": True,
    }

    for model_class in self.all_generative_model_classes:
        # disable cache
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        config.use_cache = False
        model = model_class(config).to(torch_device).eval()

        output_greedy, output_generate = self._greedy_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            **dict_output_flags,
        )

        expected_class = (
            GreedySearchEncoderDecoderOutput
            if model.config.is_encoder_decoder
            else GreedySearchDecoderOnlyOutput
        )
        self.assertIsInstance(output_greedy, expected_class)
        self.assertIsInstance(output_generate, expected_class)

        self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist())

        for output in (output_greedy, output_generate):
            self._check_outputs(output, input_ids, model.config)
def test_greedy_generate_dict_outputs_use_cache(self):
    """With all output flags on and the cache on, both greedy paths return matching dict outputs."""
    dict_output_flags = {
        "output_scores": True,
        "output_hidden_states": True,
        "output_attentions": True,
        "return_dict_in_generate": True,
    }

    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

        if not hasattr(config, "use_cache"):
            # only relevant if model has "use_cache"
            return

        # enable cache
        config.use_cache = True
        model = model_class(config).to(torch_device).eval()

        output_greedy, output_generate = self._greedy_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            **dict_output_flags,
        )

        self.assertListEqual(output_generate.sequences.tolist(), output_greedy.sequences.tolist())

        for output in (output_greedy, output_generate):
            self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_sample_generate(self):
    """`generate()` and `sample()` must produce the same token ids for 1 and for 3 returned sequences."""
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        model = model_class(config).to(torch_device).eval()

        process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1], model.config.eos_token_id
        )
        logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

        if model.config.is_encoder_decoder:
            max_length = 4

        # first check `generate()` and `sample()` are equal for a single returned sequence,
        # then check they yield equal results for several `num_return_sequences`
        for num_return_sequences in (1, 3):
            sample_ids, generate_ids = self._sample_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
                logits_warper_kwargs=logits_warper_kwargs,
                process_kwargs=process_kwargs,
            )
            self.assertListEqual(sample_ids.tolist(), generate_ids.tolist())
def test_sample_generate_dict_output(self):
    """With all output flags on and the cache off, both sampling paths return matching dict outputs."""
    num_return_sequences = 2
    dict_output_flags = {
        "output_scores": True,
        "output_hidden_states": True,
        "output_attentions": True,
        "return_dict_in_generate": True,
    }

    for model_class in self.all_generative_model_classes:
        # disable cache
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        config.use_cache = False
        model = model_class(config).to(torch_device).eval()

        process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1], model.config.eos_token_id
        )
        logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

        if model.config.is_encoder_decoder:
            max_length = 4

        output_sample, output_generate = self._sample_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            logits_processor=logits_processor,
            logits_warper=logits_warper,
            logits_warper_kwargs=logits_warper_kwargs,
            process_kwargs=process_kwargs,
            **dict_output_flags,
        )

        expected_class = (
            SampleEncoderDecoderOutput if model.config.is_encoder_decoder else SampleDecoderOnlyOutput
        )
        self.assertIsInstance(output_sample, expected_class)
        self.assertIsInstance(output_generate, expected_class)

        self.assertListEqual(output_generate.sequences.tolist(), output_sample.sequences.tolist())

        for output in (output_sample, output_generate):
            self._check_outputs(output, input_ids, model.config, num_return_sequences=2)
def test_beam_search_generate(self):
    """`generate()` and `beam_search()` must produce the same token ids for 1 and for 2 returned sequences."""
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        model = model_class(config).to(torch_device).eval()

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1], config.eos_token_id
        )

        if model.config.is_encoder_decoder:
            max_length = 4

        # first check `generate()` and `beam_search()` are equal with one returned sequence,
        # then check they are equal for a larger `num_return_sequences`
        for num_return_sequences in (1, 2):
            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
            )
            generate_ids, beam_search_ids = self._beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_process_kwargs=logits_process_kwargs,
                logits_processor=logits_processor,
            )
            self.assertListEqual(generate_ids.tolist(), beam_search_ids.tolist())
def test_beam_search_generate_dict_output(self):
    """With all output flags on and the cache off, both beam search paths return matching dict outputs."""
    dict_output_flags = {
        "output_scores": True,
        "output_hidden_states": True,
        "output_attentions": True,
        "return_dict_in_generate": True,
    }

    for model_class in self.all_generative_model_classes:
        # disable cache
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        config.use_cache = False
        model = model_class(config).to(torch_device).eval()

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1], config.eos_token_id
        )

        if model.config.is_encoder_decoder:
            max_length = 4
        beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

        output_generate, output_beam_search = self._beam_search_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            beam_scorer=beam_scorer,
            beam_kwargs=beam_kwargs,
            logits_process_kwargs=logits_process_kwargs,
            logits_processor=logits_processor,
            **dict_output_flags,
        )

        expected_class = (
            BeamSearchEncoderDecoderOutput if model.config.is_encoder_decoder else BeamSearchDecoderOnlyOutput
        )
        self.assertIsInstance(output_beam_search, expected_class)
        self.assertIsInstance(output_generate, expected_class)

        self.assertListEqual(output_generate.sequences.tolist(), output_beam_search.sequences.tolist())

        # scores agree between both paths, there is one score per returned sequence, all negative
        generate_scores = output_generate["sequences_scores"]
        self.assertTrue(torch.allclose(generate_scores, output_beam_search["sequences_scores"], atol=1e-3))
        self.assertTrue(generate_scores.shape == (output_generate["sequences"].shape[0],))
        self.assertTrue((generate_scores < 0).all().item())

        for output in (output_beam_search, output_generate):
            self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
def test_beam_search_generate_dict_outputs_use_cache(self):
    """With all output flags on and the cache on, both beam search paths return matching dict outputs.

    The config and inputs are prepared once per model class and the model is built once, after
    `use_cache` has been switched on.
    """
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

        if not hasattr(config, "use_cache"):
            # only relevant if model has "use_cache"
            return

        # enable cache before the (single) model instance is created
        config.use_cache = True
        model = model_class(config).to(torch_device).eval()

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1], config.eos_token_id
        )

        if model.config.is_encoder_decoder:
            max_length = 4
        beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

        # `_beam_search_generate` returns the `generate()` result first, the `beam_search()` result second
        output_generate, output_beam = self._beam_search_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            beam_scorer=beam_scorer,
            beam_kwargs=beam_kwargs,
            logits_process_kwargs=logits_process_kwargs,
            logits_processor=logits_processor,
            output_scores=True,
            output_hidden_states=True,
            output_attentions=True,
            return_dict_in_generate=True,
        )

        self.assertListEqual(output_generate.sequences.tolist(), output_beam.sequences.tolist())

        for output in (output_beam, output_generate):
            self._check_outputs(
                output, input_ids, model.config, use_cache=True, num_return_sequences=beam_scorer.num_beams
            )
def test_beam_sample_generate(self):
    """`generate()` and `beam_sample()` must produce the same token ids (plain tensor outputs)."""
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

        model = model_class(config).to(torch_device).eval()

        # check `generate()` and `beam_sample()` are equal
        # change `num_return_sequences = 2` but not for `beam_scorer`
        num_return_sequences = 2
        if model.config.is_encoder_decoder:
            max_length = 4
        # the scorer's batch size covers every returned sequence of every input row
        beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
            input_ids.shape[0] * num_return_sequences, max_length
        )
        beam_kwargs["num_return_sequences"] = num_return_sequences

        output_generate, output_beam_sample = self._beam_sample_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            beam_scorer=beam_scorer,
            beam_kwargs=beam_kwargs,
            logits_warper=logits_warper,
            logits_warper_kwargs=logits_warper_kwargs,
        )
        self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist())
def test_beam_sample_generate_dict_output(self):
    """With all output flags on and the cache off, both beam sample paths return matching dict outputs."""
    for model_class in self.all_generative_model_classes:
        # disable cache
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        config.use_cache = False
        model = model_class(config).to(torch_device).eval()
        logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

        num_return_sequences = 2
        if model.config.is_encoder_decoder:
            max_length = 4
        # the scorer's batch size covers every returned sequence of every input row
        beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
            input_ids.shape[0] * num_return_sequences, max_length
        )
        beam_kwargs["num_return_sequences"] = num_return_sequences

        # `_beam_sample_generate` returns the `generate()` result first, the `beam_sample()` result second
        output_generate, output_beam_sample = self._beam_sample_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            beam_scorer=beam_scorer,
            beam_kwargs=beam_kwargs,
            logits_warper=logits_warper,
            logits_warper_kwargs=logits_warper_kwargs,
            output_scores=True,
            output_hidden_states=True,
            output_attentions=True,
            return_dict_in_generate=True,
        )

        if model.config.is_encoder_decoder:
            self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput)
            self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
        else:
            self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput)
            self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)

        self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
        self.assertTrue(
            torch.allclose(output_generate["sequences_scores"], output_beam_sample["sequences_scores"], atol=1e-3)
        )
        # one score per returned sequence; assertEqual reports both shapes on failure
        self.assertEqual(
            tuple(output_generate["sequences_scores"].shape), (output_generate["sequences"].shape[0],)
        )
        self.assertTrue((output_generate["sequences_scores"] < 0).all().item())

        for output in (output_beam_sample, output_generate):
            self._check_outputs(
                output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
            )
def test_generate_without_input_ids(self):
    """`generate()` called without `input_ids` must return a non-`None` result when a BOS token is set."""
    config, _, _, max_length = self._get_input_ids_and_config()

    # if no bos token id => cannot generate from None
    if config.bos_token_id is not None:
        for model_class in self.all_generative_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()

            generated_ids = model.generate(do_sample=False, max_length=max_length)

            self.assertIsNotNone(generated_ids)
def test_group_beam_search_generate(self):
    """`generate()` and `group_beam_search()` must produce the same ids for 1 and for 2 returned sequences."""
    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
        )

        model = model_class(config).to(torch_device).eval()

        if model.config.is_encoder_decoder:
            max_length = 4

        # first check `generate()` and `group_beam_search()` are equal with one returned sequence,
        # then check they are equal for a larger `num_return_sequences`
        for num_return_sequences in (1, 2):
            beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
            )
            generate_ids, group_beam_ids = self._group_beam_search_generate(
                model=model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                beam_scorer=beam_scorer,
                beam_kwargs=beam_kwargs,
                logits_processor=logits_processor,
                logits_process_kwargs=logits_process_kwargs,
            )
            self.assertListEqual(generate_ids.tolist(), group_beam_ids.tolist())
def test_group_beam_search_generate_dict_output(self):
    """With all output flags on and the cache off, both group beam search paths return matching dict outputs."""
    num_return_sequences = 1
    dict_output_flags = {
        "output_scores": True,
        "output_hidden_states": True,
        "output_attentions": True,
        "return_dict_in_generate": True,
    }

    for model_class in self.all_generative_model_classes:
        config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
        config.use_cache = False
        model = model_class(config).to(torch_device).eval()

        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
            input_ids.shape[-1], config.eos_token_id, diversity_penalty=2.0
        )

        if model.config.is_encoder_decoder:
            max_length = 4
        beam_kwargs, beam_scorer = self._get_diverse_beam_scorer_and_kwargs(
            input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
        )

        output_generate, output_group_beam_search = self._group_beam_search_generate(
            model=model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            beam_scorer=beam_scorer,
            beam_kwargs=beam_kwargs,
            logits_processor=logits_processor,
            logits_process_kwargs=logits_process_kwargs,
            **dict_output_flags,
        )

        expected_class = (
            BeamSearchEncoderDecoderOutput if model.config.is_encoder_decoder else BeamSearchDecoderOnlyOutput
        )
        self.assertIsInstance(output_group_beam_search, expected_class)
        self.assertIsInstance(output_generate, expected_class)

        self.assertListEqual(output_generate.sequences.tolist(), output_group_beam_search.sequences.tolist())

        # scores agree between both paths, there is one score per returned sequence, all negative
        generate_scores = output_generate["sequences_scores"]
        self.assertTrue(
            torch.allclose(generate_scores, output_group_beam_search["sequences_scores"], atol=1e-3)
        )
        self.assertTrue(generate_scores.shape == (output_generate["sequences"].shape[0],))
        self.assertTrue((generate_scores < 0).all().item())

        for output in (output_group_beam_search, output_generate):
            self._check_outputs(
                output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
            )
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
    """Assert that a generation output carries well-formed scores, attentions and hidden states.

    Args:
        output: generation result exposing ``sequences`` and ``scores`` plus, depending on
            ``config.is_encoder_decoder``, either the ``encoder_*`` / ``decoder_*`` attention and
            hidden-state fields or the plain ``attentions`` / ``hidden_states`` fields.
        input_ids: 2-D input whose shape is unpacked as ``(batch_size, seq_length)``.
        config: model config; ``is_encoder_decoder``, ``num_attention_heads`` and ``hidden_size``
            are read here, and it is forwarded to the per-field helpers.
        use_cache: forwarded to the helpers; for decoder-only models it also drops the first
            generation step from the checked tuples.
        num_return_sequences: factor applied to ``batch_size`` to obtain the expected number of rows.
    """
    batch_size, seq_length = input_ids.shape
    num_sequences_in_output = batch_size * num_return_sequences
    is_encoder_decoder = config.is_encoder_decoder
    total_length = output.sequences.shape[-1]

    # Number of generated positions: everything after the first position for
    # encoder-decoder models, everything after the prompt otherwise.
    prefix_length = 1 if is_encoder_decoder else seq_length

    # scores
    self._check_scores(
        num_sequences_in_output, output.scores, length=total_length - prefix_length, config=config
    )

    # Length at which the decoder-side checks start. With caching, the first decoder-only
    # step looks the same as without caching, so it is skipped and the checks begin one
    # position later.
    if is_encoder_decoder:
        first_step_length = 1
    elif use_cache:
        first_step_length = seq_length + 1
    else:
        first_step_length = seq_length

    # attentions
    if is_encoder_decoder:
        # encoder: one square (seq_length x seq_length) map per layer
        expected_encoder_attention = (batch_size, config.num_attention_heads, seq_length, seq_length)
        encoder_attentions = output.encoder_attentions
        self.assertIsInstance(encoder_attentions, tuple)
        self.assertListEqual(
            [layer.shape for layer in encoder_attentions],
            [expected_encoder_attention] * len(encoder_attentions),
        )
        step_attentions = output.decoder_attentions
    else:
        step_attentions = output.attentions[1:] if use_cache else output.attentions
    self._check_attentions_for_generate(
        num_sequences_in_output,
        step_attentions,
        min_length=first_step_length,
        max_length=total_length,
        config=config,
        use_cache=use_cache,
    )

    # hidden states
    if is_encoder_decoder:
        # encoder: one (batch, seq_length, hidden_size) tensor per layer
        expected_encoder_hidden = (batch_size, seq_length, config.hidden_size)
        encoder_hidden_states = output.encoder_hidden_states
        self.assertIsInstance(encoder_hidden_states, tuple)
        self.assertListEqual(
            [layer.shape for layer in encoder_hidden_states],
            [expected_encoder_hidden] * len(encoder_hidden_states),
        )
        step_hidden_states = output.decoder_hidden_states
    else:
        step_hidden_states = output.hidden_states[1:] if use_cache else output.hidden_states
    self._check_hidden_states_for_generate(
        num_sequences_in_output,
        step_hidden_states,
        min_length=first_step_length,
        max_length=total_length,
        config=config,
        use_cache=use_cache,
    )
def _check_scores(self, batch_size, scores, length, config):
    """Assert that ``scores`` is a tuple of ``length`` entries shaped ``(batch_size, config.vocab_size)``."""
    per_step_shape = (batch_size, config.vocab_size)
    self.assertIsInstance(scores, tuple)
    num_steps = len(scores)
    self.assertEqual(num_steps, length)
    observed_shapes = [step_scores.shape for step_scores in scores]
    self.assertListEqual(observed_shapes, [per_step_shape] * num_steps)
def _check_attentions_for_generate(
    self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
    """Assert the layout of per-step attention outputs.

    ``attentions`` must be a tuple holding ``(max_length - min_length) * num_beam_groups``
    tuples. Every element of entry ``idx`` must have shape
    ``(batch_size * num_beam_groups, config.num_attention_heads, tgt_len, src_len)`` with
    ``src_len = min_length + idx`` and ``tgt_len`` equal to ``src_len`` without cache, or ``1``
    when ``use_cache`` is set.
    """
    self.assertIsInstance(attentions, tuple)
    num_steps = len(attentions)
    self.assertListEqual([isinstance(step, tuple) for step in attentions], [True] * num_steps)
    self.assertEqual(num_steps, (max_length - min_length) * num_beam_groups)

    for step_idx, step_attentions in enumerate(attentions):
        src_len = min_length + step_idx
        # with caching only the newest position is attended from
        tgt_len = 1 if use_cache else src_len
        expected_shape = (batch_size * num_beam_groups, config.num_attention_heads, tgt_len, src_len)
        observed_shapes = [layer.shape for layer in step_attentions]
        self.assertListEqual(observed_shapes, [expected_shape] * len(step_attentions))
def _check_hidden_states_for_generate(
    self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
    """Assert the layout of per-step hidden-state outputs.

    ``hidden_states`` must be a tuple holding ``(max_length - min_length) * num_beam_groups``
    tuples. Every element of entry ``idx`` must have shape
    ``(batch_size * num_beam_groups, seq_len, config.hidden_size)`` with ``seq_len`` equal to
    ``min_length + idx`` without cache, or ``1`` when ``use_cache`` is set.
    """
    self.assertIsInstance(hidden_states, tuple)
    num_steps = len(hidden_states)
    self.assertListEqual([isinstance(step, tuple) for step in hidden_states], [True] * num_steps)
    self.assertEqual(num_steps, (max_length - min_length) * num_beam_groups)

    for step_idx, step_hidden_states in enumerate(hidden_states):
        # with caching only the newest position is returned at each step
        seq_len = 1 if use_cache else min_length + step_idx
        expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
        observed_shapes = [layer.shape for layer in step_hidden_states]
        self.assertListEqual(observed_shapes, [expected_shape] * len(step_hidden_states))
@require_torch
class UtilsFunctionsTest(unittest.TestCase):
    """Unit tests for stand-alone generation utility functions."""

    # checks which logits `top_k_top_p_filtering` leaves untouched (i.e. not set to -inf)
    def test_top_k_top_p_filtering(self):
        # Two rows of 30 logits each; the four largest entries of every row are annotated
        # because they are the ones expected to survive the filtering.
        # NOTE(review): four survivors per row coincides with `min_tokens_to_keep=4` in the
        # call below; whether `top_p=0.6` alone would keep fewer was not verified here.
        first_row = [
            8.2220991,  # 3rd largest, index 0
            -0.5620044,
            5.23229752,
            4.0386393,
            -6.8798378,
            -0.54785802,
            -3.2012153,
            2.92777176,
            1.88171953,
            7.35341276,
            8.43207833,  # 2nd largest, index 10
            -9.85711836,
            -5.96209236,
            -1.13039161,
            -7.1115294,
            -0.8369633,
            -5.3186408,
            7.06427407,
            0.81369344,
            -0.82023817,
            -5.9179796,
            0.58813443,
            -6.99778438,
            4.71551189,
            -0.18771637,
            7.44020759,  # 4th largest, index 25
            9.38450987,  # largest, index 26
            2.12662941,
            -9.32562038,
            2.35652522,
        ]
        second_row = [
            0.58425518,
            4.53139238,
            -5.57510464,
            -6.28030699,
            -7.19529503,
            -4.02122551,
            1.39337037,
            -6.06707057,
            1.59480517,
            -9.643119,
            0.03907799,
            0.67231762,
            -8.88206726,
            6.27115922,  # 4th largest, index 13
            2.28520723,
            4.82767506,
            4.30421368,
            8.8275313,  # 2nd largest, index 17
            5.44029958,
            -4.4735794,
            7.38579536,  # 3rd largest, index 20
            -2.91051663,
            2.61946077,
            -2.5674762,
            -9.48959302,
            -4.02922645,
            -1.35416918,
            9.67702323,  # largest, index 27
            -5.89478553,
            1.85370467,
        ]
        logits = torch.tensor([first_row, second_row], dtype=torch.float, device=torch_device)

        # (row, column) positions of the annotated entries, in row-major order
        expected_kept_positions = torch.tensor(
            [[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
            dtype=torch.long,
            device=torch_device,
        )
        # the annotated values in the same order, rounded to four decimals
        expected_kept_values = torch.tensor(
            [8.2221, 8.4321, 7.4402, 9.3845, 6.2712, 8.8275, 7.3858, 9.6770],
            dtype=torch.float,
            device=torch_device,
        )

        filtered = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)

        # everything that was not replaced by -inf counts as kept
        kept_mask = filtered != float("-inf")
        kept_values = filtered[kept_mask].to(device=torch_device)
        kept_positions = kept_mask.nonzero().to(device=torch_device)

        # NOTE(review): `atol=1e-12` is far tighter than the four-decimal rounding of the
        # expected values, so this comparison presumably relies on `allclose`'s default
        # relative tolerance -- confirm.
        values_match = torch.allclose(expected_kept_values, kept_values, atol=1e-12)
        self.assertTrue(values_match)
        positions_match = torch.all(torch.eq(expected_kept_positions, kept_positions))
        self.assertTrue(positions_match)
@require_torch
class GenerationIntegrationTests(unittest.TestCase):
    """Slow integration tests that run a pretrained checkpoint through `generate`."""

    # Summarizes a short news snippet with `facebook/bart-large-cnn` using 4 beams split into
    # 4 groups and a diversity penalty of 2.0, then compares the two returned summaries against
    # fixed reference strings.
    @slow
    def test_diverse_beam_search(self):
        checkpoint = "facebook/bart-large-cnn"
        expected_summaries = [
            "The couple announced the birth of their son, Silas Randall Timberlake, in a statement. Silas was the middle name of Timberlake's maternal grandfather Bill Bomar. Randall is the musician's own middle name, as well as his father's first. It is the first baby for both of them.",
            "Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
        ]
        source_text = """Justin Timberlake and Jessica Biel, welcome to parenthood.
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
The couple announced the pregnancy in January, with an Instagram post. It is the first baby for both."""

        tokenizer = BartTokenizer.from_pretrained(checkpoint)
        model = BartForConditionalGeneration.from_pretrained(checkpoint).to(torch_device)

        encoded = tokenizer(source_text, return_tensors="pt")
        input_ids = encoded.input_ids.to(torch_device)

        generation_kwargs = {
            "num_beams": 4,
            "num_return_sequences": 2,
            "num_beam_groups": 4,
            "diversity_penalty": 2.0,
        }
        summary_ids = model.generate(input_ids, **generation_kwargs)

        summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
        self.assertListEqual(summaries, expected_summaries)