Add flags to return scores, hidden states, and/or attention weights in GenerationMixin (#9150)

* Define new output dataclasses for greedy generation

* Add output_[...] flags in greedy generation methods

Added output_attentions, output_hidden_states, output_scores flags in
generate and greedy_search methods in GenerationMixin.

* [WIP] Implement logic and tests for output flags in generation

* Update GreedySearchOutput classes & docstring

* Implement greedy search output accumulation logic

Update greedy_search unittests

Fix generate method return value docstring

Properly init flags with the default config

* Update configuration to add output_scores flag

* Fix test_generation_utils

Sort imports and fix isinstance tests for GreedySearchOutputs

* Fix typo in generation_utils

* Add return_dict_in_generate for backwards compatibility

* Add return_dict_in_generate flag in config

* Fix typo in configuration

* Fix handling of attentions and hidden_states flags

* Make style & quality

* first attempt attentions

* some corrections

* improve tests

* special models require special tests

* disable xlm test for now

* clean tests

* fix for tf

* isort

* Add output dataclasses for other generation methods

* Add logic to return dict in sample generation

* Complete test for sample generation

- Pass output_attentions and output_hidden_states flags to encoder in
encoder-decoder models
- Fix import statement order in the test_generation_utils file

* Add logic to return dict in sample generation

- Refactor tests to avoid using self.assertTrue, which provides
little information when a test fails
- Add tests for the three beam_search methods: vanilla, sample and
grouped

* Style doc

* Fix copy-paste error in generation tests

* Rename logits to scores and refactor

* Refactor group_beam_search for consistency

* make style

* add sequences_scores

* fix all tests

* add docs

* fix beam search finalize test

* correct docstring

* clean some files

* Made suggested changes to the documentation

* Style doc?

* Style doc using the Python util

* Update src/transformers/generation_utils.py

* fix empty lines

* fix all tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Simon Brandeis 2021-01-06 17:11:42 +01:00 committed by GitHub
parent 7a9f1b5c99
commit c89f1bc92e
11 changed files with 2014 additions and 321 deletions


@@ -13,13 +13,102 @@

Utilities for Generation
-----------------------------------------------------------------------------------------------------------------------

-This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`,
-:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`,
-:meth:`~transformers.PretrainedModel.beam_search`, :meth:`~transformers.PretrainedModel.beam_sample`, and
-:meth:`~transformers.PretrainedModel.group_beam_search`.
+This page lists all the utility functions used by :meth:`~transformers.PreTrainedModel.generate`,
+:meth:`~transformers.PreTrainedModel.greedy_search`, :meth:`~transformers.PreTrainedModel.sample`,
+:meth:`~transformers.PreTrainedModel.beam_search`, :meth:`~transformers.PreTrainedModel.beam_sample`, and
+:meth:`~transformers.PreTrainedModel.group_beam_search`.

Most of those are only useful if you are studying the code of the generate methods in the library.
Generate Outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The output of :meth:`~transformers.PreTrainedModel.generate` is an instance of a subclass of
:class:`~transformers.file_utils.ModelOutput`. This output is a data structure containing all the information returned
by :meth:`~transformers.PreTrainedModel.generate`, but one that can also be used as a tuple or a dictionary.

Here's an example:
.. code-block::

    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
    generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
The ``generation_output`` object is a :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`. As we can
see in the documentation of that class below, it has the following attributes:

- ``sequences``: the generated sequences of tokens
- ``scores`` (optional): the prediction scores of the language modeling head, for each generation step
- ``hidden_states`` (optional): the hidden states of the model, for each generation step
- ``attentions`` (optional): the attention weights of the model, for each generation step

Here we have the ``scores`` since we passed along ``output_scores=True``, but we don't have ``hidden_states`` and
``attentions`` because we didn't pass ``output_hidden_states=True`` or ``output_attentions=True``.
You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
will get ``None``. Here, for instance, ``generation_output.scores`` are all the generated prediction scores of the
language modeling head, and ``generation_output.attentions`` is ``None``.

When using our ``generation_output`` object as a tuple, it only keeps the attributes that don't have ``None`` values.
Here, for instance, it has two elements, ``sequences`` then ``scores``, so

.. code-block::

    generation_output[:2]

will return the tuple ``(generation_output.sequences, generation_output.scores)``.

When using our ``generation_output`` object as a dictionary, it only keeps the attributes that don't have ``None``
values. Here, for instance, it has two keys that are ``sequences`` and ``scores``.
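
To make those three access patterns concrete, here is a short sketch continuing the example above (an illustration for
this page, not part of the committed file; it assumes only ``sequences`` and ``scores`` were returned):

.. code-block::

    # attribute access: outputs that were not requested come back as None
    assert generation_output.attentions is None

    # tuple-style access: None attributes are skipped
    sequences, scores = generation_output[:2]
    assert sequences is generation_output.sequences

    # dict-style access: only the non-None keys are present
    assert "scores" in generation_output.keys()
    assert "attentions" not in generation_output.keys()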
We document here all output types.

GreedySearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.GreedySearchDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
    :members:

SampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.SampleDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
    :members:

BeamSearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.BeamSearchDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.BeamSearchEncoderDecoderOutput
    :members:

BeamSampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.BeamSampleDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.BeamSampleEncoderDecoderOutput
    :members:
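
By way of further illustration (ours, not from the diff), the flags compose: requesting every optional output during
greedy generation, with the GPT-2 setup from the example at the top of this page, would look like this:

.. code-block::

    outputs = model.generate(
        **inputs,
        max_length=12,
        return_dict_in_generate=True,
        output_scores=True,
        output_hidden_states=True,
        output_attentions=True,
    )

    # one entry per generated token; each entry is (batch_size, vocab_size)
    print(len(outputs.scores), outputs.scores[0].shape)

    # `attentions` and `hidden_states` are tuples over generation steps,
    # each step holding one tensor per layer
    print(len(outputs.attentions[0]))  # number of layers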
LogitsProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@@ -124,6 +124,11 @@ class PretrainedConfig(object):

    - **num_return_sequences** (:obj:`int`, `optional`, defaults to 1) -- Number of independently computed returned
      sequences for each element in the batch that will be used by default in the :obj:`generate` method of the
      model.
    - **output_scores** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should return the
      logits when used for generation.
    - **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should
      return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor`.
Parameters for fine-tuning tasks

@@ -203,6 +208,8 @@ class PretrainedConfig(object):

        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
        self.output_scores = kwargs.pop("output_scores", False)
        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
@@ -343,6 +350,7 @@ class PretrainedConfig(object):

                Passing :obj:`use_auth_token=True` is required when you want to use a private model.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.

@@ -372,6 +380,8 @@ class PretrainedConfig(object):

        From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
        :class:`~transformers.PretrainedConfig` using ``from_dict``.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
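
Because the two new flags are ordinary config attributes, they can also be set once on a model's config to become the
defaults picked up by `generate` (a sketch for illustration, not part of the diff):

    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # set generation defaults once instead of passing flags on every call
    model.config.output_scores = True
    model.config.return_dict_in_generate = True

    inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
    outputs = model.generate(**inputs)  # now a ModelOutput with `scores` populated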


@@ -281,7 +281,7 @@ class BeamSearchScorer(BeamScorer):

        final_beam_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
-   ) -> torch.LongTensor:
+   ) -> Tuple[torch.LongTensor]:
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses

@@ -300,14 +300,20 @@ class BeamSearchScorer(BeamScorer):

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []
        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
-               best_hyp = sorted_hyps.pop()[1]
+               best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append to lists
                best.append(best_hyp)
                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)

@@ -322,7 +328,12 @@ class BeamSearchScorer(BeamScorer):

            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < self.max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
-       return decoded
+       return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
            }
        )

class BeamHypotheses:
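
Since `finalize` now returns a `UserDict` rather than a bare tensor, callers unpack it by key, mirroring the updated
tests further down (a sketch; the positional argument names are assumptions based on the signature above):

    sequence_output = beam_scorer.finalize(
        input_ids,
        final_beam_scores,   # assumed names for the final beam tensors
        final_beam_tokens,
        final_beam_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
    )
    sequences = sequence_output["sequences"]              # (batch * kept_hyps, sent_max_len)
    sequence_scores = sequence_output["sequence_scores"]  # (batch * kept_hyps,)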

File diff suppressed because it is too large.


@@ -136,7 +136,6 @@ class OpenAIGPTConfig(PretrainedConfig):

        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
-       use_cache=True,
        **kwargs
    ):
        super().__init__(**kwargs)

@@ -159,7 +158,6 @@ class OpenAIGPTConfig(PretrainedConfig):

        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
-       self.use_cache = use_cache

    @property
    def max_position_embeddings(self):


@@ -190,7 +190,7 @@ class BeamSearchTester:

        input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)

        # finalize
-       decoded = beam_scorer.finalize(
+       sequence_output = beam_scorer.finalize(
            input_ids,
            output_scores,
            output_tokens,

@@ -198,19 +198,27 @@ class BeamSearchTester:

            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
        )

        sequences = sequence_output["sequences"]
        sequence_scores = sequence_output["sequence_scores"]

        # since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
-       self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length])
+       self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
        self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])

        # check sequence_scores
        self.parent.assertFalse((sequence_scores > 0).any().item())

        # first batch has to finish with eos_token
-       self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id)
+       self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)

        # other batches cannot finish with eos token
-       self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id)
+       self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
-       self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id)
+       self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)

        # now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
        beam_scorer.num_beam_hyps_to_keep = self.num_beams
-       decoded = beam_scorer.finalize(
+       sequence_output = beam_scorer.finalize(
            input_ids,
            output_scores,
            output_tokens,

@@ -218,7 +226,11 @@ class BeamSearchTester:

            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
        )
-       self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length])
+       sequences = sequence_output["sequences"]
+       sequence_scores = sequence_output["sequence_scores"]
+       self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
+       self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])

@require_torch
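
The non-positivity assertion on `sequence_scores` holds because a finished hypothesis is scored by its summed token
log-probabilities divided by a length penalty, and log-probabilities are never positive (a sketch with assumed values,
for illustration):

    import torch

    # token log-probabilities are <= 0, so
    # sum_logprobs / len(hyp) ** length_penalty can never be positive
    log_probs = torch.log_softmax(torch.randn(5, 100), dim=-1)  # 5 steps, vocab of 100
    sum_logprobs = log_probs.max(dim=-1).values.sum()

    length_penalty = 1.0
    score = sum_logprobs / (5 ** length_penalty)
    assert score.item() <= 0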

File diff suppressed because it is too large.


@@ -638,6 +638,69 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod

            model = ReformerModelWithLMHead.from_pretrained(model_name)
            self.assertIsNotNone(model)
    def _check_attentions_for_generate(
        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

        for idx, iter_attentions in enumerate(attentions):
            tgt_len = min_length + idx if not use_cache else 1
            num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0)
            tgt_chunk_len = config.local_attn_chunk_length
            src_chunk_len = config.local_attn_chunk_length * (
                1 + config.local_num_chunks_after + config.local_num_chunks_before
            )

            if use_cache:
                expected_shape = (
                    batch_size * num_beam_groups,
                    config.num_attention_heads,
                    tgt_len,
                    min_length // config.local_attn_chunk_length + 1 + idx,
                )
            else:
                expected_shape = (
                    batch_size * num_beam_groups,
                    config.num_attention_heads,
                    num_chunks,
                    tgt_chunk_len,
                    src_chunk_len,
                )
            # check attn size
            self.assertListEqual(
                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
            )

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)

        for idx, iter_hidden_states in enumerate(hidden_states):
            seq_len = min_length + idx
            seq_len = config.local_attn_chunk_length * (
                seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0)
            )

            if use_cache:
                seq_len = 1

            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
            # check hidden size
            self.assertListEqual(
                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
                [expected_shape] * len(iter_hidden_states),
            )
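
To see where the chunked shapes come from, a worked example with assumed config values (an illustration, not from the
diff): with `local_attn_chunk_length=4` and one chunk of context on each side, a target length of 10 is split into 3
chunks, each attending over 12 key positions:

    local_attn_chunk_length = 4
    local_num_chunks_before, local_num_chunks_after = 1, 1
    tgt_len = 10

    num_chunks = tgt_len // local_attn_chunk_length + (tgt_len % local_attn_chunk_length != 0)
    src_chunk_len = local_attn_chunk_length * (1 + local_num_chunks_after + local_num_chunks_before)

    print(num_chunks, local_attn_chunk_length, src_chunk_len)  # 3 4 12
    # per-layer attention shape: (batch, heads, 3, 4, 12) instead of (batch, heads, 10, 10)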
@require_torch
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

@@ -696,13 +759,77 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, Generation

        self.model_tester = ReformerModelTester(self, **tester_kwargs)
        self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
    def _check_attentions_for_generate(
        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

        for idx, iter_attentions in enumerate(attentions):
            tgt_len = min_length + idx if not use_cache else 1
            num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0)
            tgt_chunk_len = config.lsh_attn_chunk_length
            src_chunk_len = config.lsh_attn_chunk_length * (
                1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
            )

            if use_cache:
                expected_shape = (
                    batch_size * num_beam_groups,
                    config.num_attention_heads,
                    config.num_hashes,
                    tgt_len,
                    config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
                )
            else:
                expected_shape = (
                    batch_size * num_beam_groups,
                    config.num_attention_heads,
                    num_chunks * config.num_hashes,
                    tgt_chunk_len,
                    src_chunk_len,
                )
            # check attn size
            self.assertListEqual(
                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
            )

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)

        for idx, iter_hidden_states in enumerate(hidden_states):
            seq_len = min_length + idx if not use_cache else 1
            seq_len = config.lsh_attn_chunk_length * (
                seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0)
            )

            if use_cache:
                seq_len = 1

            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
            # check hidden size
            self.assertListEqual(
                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
                [expected_shape] * len(iter_hidden_states),
            )
@require_torch
@require_sentencepiece
@require_tokenizers
class ReformerIntegrationTests(unittest.TestCase):
    """
-   These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
+   These integration tests test the current layer activations and gradients againts the output of the Hugging Face Reformer model at time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "lsh" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
    """

    def _get_basic_config_and_input(self):


@@ -304,6 +304,50 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC

            # transfo-xl requires special resize for lm-head
            return
    def _check_attentions_for_generate(
        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

        for idx, iter_attentions in enumerate(attentions):
            tgt_len = min_length if idx == 0 else (min_length - 2)
            src_len = (min_length + config.mem_len) if idx == 0 else (min_length + config.mem_len - 2)

            expected_shape = (
                batch_size * num_beam_groups,
                config.num_attention_heads,
                tgt_len,
                src_len,
            )
            # check attn size
            self.assertListEqual(
                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
            )

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)

        for idx, iter_hidden_states in enumerate(hidden_states):
            seq_len = min_length if idx == 0 else min_length - 2
            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
            # check hidden size
            self.assertListEqual(
                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
                [expected_shape] * len(iter_hidden_states),
            )
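
The shape checks above encode Transfo-XL's memory mechanism: attention keys span the `mem_len` cached states in
addition to the current tokens, so the source length is the current length plus the memory length. A worked example
with assumed values (for illustration):

    mem_len = 30
    cur_len = 5  # e.g. min_length at the first checked step

    src_len = cur_len + mem_len
    print(src_len)  # 35 -> per-layer attention shape (batch, heads, 5, 35)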
@require_torch
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):


@@ -400,6 +400,52 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
    def _check_attentions_for_generate(
        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

        for idx, iter_attentions in enumerate(attentions):
            # adds PAD dummy token
            tgt_len = min_length + idx + 1
            src_len = min_length + idx + 1

            expected_shape = (
                batch_size * num_beam_groups,
                config.num_attention_heads,
                tgt_len,
                src_len,
            )
            # check attn size
            self.assertListEqual(
                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
            )

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)

        for idx, iter_hidden_states in enumerate(hidden_states):
            # adds PAD dummy token
            seq_len = min_length + idx + 1
            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
            # check hidden size
            self.assertListEqual(
                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
                [expected_shape] * len(iter_hidden_states),
            )
        pass
    @slow
    def test_model_from_pretrained(self):
        for model_name in XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:


@@ -593,6 +593,60 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)

        # xlnet cannot keep gradients in attentions or hidden states
        return
    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)

        for idx, iter_hidden_states in enumerate(hidden_states):
            # check hidden size
            for i, layer_hidden_states in enumerate(iter_hidden_states):
                # every 2nd tensor is from extra stream
                if i % 2 != 0:
                    seq_len = 1
                else:
                    # for first item dummy PAD token is appended so need one more
                    seq_len = (min_length + 1) if idx == 0 else min_length
                expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
                self.assertEqual(layer_hidden_states.shape, expected_shape)

    def _check_attentions_for_generate(
        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
    ):
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)

        for idx, attentions_item in enumerate(attentions):
            for iter_attentions in attentions_item:
                tgt_len = min_length

                # for first item dummy PAD token is appended so need one more
                if idx == 0:
                    tgt_len += 1

                src_len = min_length + idx + 1

                expected_shape = (
                    batch_size * num_beam_groups,
                    config.num_attention_heads,
                    tgt_len,
                    src_len,
                )
                # check attn size
                self.assertListEqual(
                    [layer_attention.shape for layer_attention in iter_attentions],
                    [expected_shape] * len(iter_attentions),
                )
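
The alternating shapes come from XLNet's two-stream attention: during generation the model runs a content stream over
the known tokens plus a one-token query stream for the position being predicted, so every second hidden-state tensor
has sequence length 1. A sketch of inspecting this (for illustration; `outputs` is a hypothetical result of `generate`
with `return_dict_in_generate=True` and `output_hidden_states=True`):

    first_step = outputs.hidden_states[0]
    for i, h in enumerate(first_step):
        stream = "query" if i % 2 else "content"
        print(stream, tuple(h.shape))  # query-stream tensors have seq_len == 1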
    @slow
    def test_model_from_pretrained(self):
        for model_name in XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: