Add flags to return scores, hidden states and/or attention weights in GenerationMixin (#9150)
* Define new output dataclasses for greedy generation
* Add output_[...] flags in greedy generation methods: added output_attentions, output_hidden_states, output_scores flags to the generate and greedy_search methods in GenerationMixin
* [WIP] Implement logic and tests for output flags in generation
* Update GreedySearchOutput classes & docstring
* Implement greedy search output accumulation logic: update greedy_search unit tests, fix the generate method return value docstring, properly init flags with the default config
* Update configuration to add output_scores flag
* Fix test_generation_utils: sort imports and fix isinstance tests for GreedySearchOutputs
* Fix typo in generation_utils
* Add return_dict_in_generate for backwards compatibility
* Add return_dict_in_generate flag in config
* Fix typo in configuration
* Fix handling of attentions and hidden_states flags
* Make style & quality
* First attempt at attentions
* Some corrections
* Improve tests
* Special models require special tests
* Disable XLM test for now
* Clean tests
* Fix for TF
* isort
* Add output dataclasses for other generation methods
* Add logic to return dict in sample generation
* Complete test for sample generation:
  - Pass output_attentions and output_hidden_states flags to the encoder in encoder-decoder models
  - Fix import statements order in the test_generation_utils file
* Add logic to return dict in sample generation:
  - Refactor tests to avoid using self.assertTrue, which provides scarce information when a test fails
  - Add tests for the three beam_search methods: vanilla, sample and grouped
* Style doc
* Fix copy-paste error in generation tests
* Rename logits to scores and refactor
* Refactor group_beam_search for consistency
* Make style
* Add sequences_scores
* Fix all tests
* Add docs
* Fix beam search finalize test
* Correct docstring
* Clean some files
* Make suggested changes to the documentation
* Style doc
* Style doc using the Python util
* Update src/transformers/generation_utils.py
* Fix empty lines
* Fix all tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Parent: 7a9f1b5c99
Commit: c89f1bc92e
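As a quick orientation before the hunks, here is a minimal sketch of the feature the commit message describes; the flag names come from the diff below, and the printed class name assumes the default greedy decoding path:

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
    # With return_dict_in_generate=True, generate() returns a ModelOutput
    # subclass instead of a bare torch.LongTensor of token ids.
    outputs = model.generate(
        **inputs,
        max_length=15,
        return_dict_in_generate=True,
        output_scores=True,
        output_attentions=True,
        output_hidden_states=True,
    )

    print(type(outputs).__name__)  # GreedySearchDecoderOnlyOutput
    print(len(outputs.scores))     # one score tensor per generated step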
@@ -13,13 +13,102 @@
 Utilities for Generation
 -----------------------------------------------------------------------------------------------------------------------

-This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`,
-:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`,
-:meth:`~transformers.PretrainedModel.beam_search`, :meth:`~transformers.PretrainedModel.beam_sample`, and
-:meth:`~transformers.PretrainedModel.group_beam_search`.
+This page lists all the utility functions used by :meth:`~transformers.PreTrainedModel.generate`,
+:meth:`~transformers.PreTrainedModel.greedy_search`, :meth:`~transformers.PreTrainedModel.sample`,
+:meth:`~transformers.PreTrainedModel.beam_search`, :meth:`~transformers.PreTrainedModel.beam_sample`, and
+:meth:`~transformers.PreTrainedModel.group_beam_search`.

 Most of those are only useful if you are studying the code of the generate methods in the library.

+Generate Outputs
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The output of :meth:`~transformers.PreTrainedModel.generate` is an instance of a subclass of
+:class:`~transformers.file_utils.ModelOutput`. This output is a data structure containing all the information returned
+by :meth:`~transformers.PreTrainedModel.generate`, but one that can also be used as a tuple or a dictionary.
+
+Here's an example:
+
+.. code-block::
+
+    from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    model = GPT2LMHeadModel.from_pretrained('gpt2')
+
+    inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
+    generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
+
+The ``generation_output`` object is a :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`. As we can
+see in the documentation of that class below, this means it has the following attributes:
+
+- ``sequences``: the generated sequences of tokens
+- ``scores`` (optional): the prediction scores of the language modeling head, for each generation step
+- ``hidden_states`` (optional): the hidden states of the model, for each generation step
+- ``attentions`` (optional): the attention weights of the model, for each generation step
+
+Here we have the ``scores`` since we passed along ``output_scores=True``, but we don't have ``hidden_states`` and
+``attentions`` because we didn't pass ``output_hidden_states=True`` or ``output_attentions=True``.
+
+You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
+will get ``None``. Here, for instance, ``generation_output.scores`` are all the generated prediction scores of the
+language modeling head, and ``generation_output.attentions`` is ``None``.
+
+When using our ``generation_output`` object as a tuple, it only keeps the attributes that don't have ``None`` values.
+Here, for instance, it has two elements, ``sequences`` then ``scores``, so
+
+.. code-block::
+
+    generation_output[:2]
+
+will return the tuple ``(generation_output.sequences, generation_output.scores)``.
+
+When using our ``generation_output`` object as a dictionary, it only keeps the attributes that don't have ``None``
+values. Here, for instance, it has two keys, ``sequences`` and ``scores``.
+
+We document here all output types.
+
+
+GreedySearchOutput
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: transformers.generation_utils.GreedySearchDecoderOnlyOutput
+    :members:
+
+.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
+    :members:
+
+
+SampleOutput
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: transformers.generation_utils.SampleDecoderOnlyOutput
+    :members:
+
+.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
+    :members:
+
+
+BeamSearchOutput
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: transformers.generation_utils.BeamSearchDecoderOnlyOutput
+    :members:
+
+.. autoclass:: transformers.generation_utils.BeamSearchEncoderDecoderOutput
+    :members:
+
+
+BeamSampleOutput
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: transformers.generation_utils.BeamSampleDecoderOnlyOutput
+    :members:
+
+.. autoclass:: transformers.generation_utils.BeamSampleEncoderDecoderOutput
+    :members:
+
+
 LogitsProcessor
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
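To make the tuple / dictionary duality described in the hunk above concrete, a short sketch reusing the ``generation_output`` from the documentation example (where only ``sequences`` and ``scores`` are populated):

    sequences = generation_output["sequences"]       # dict-style access
    sequences_again, scores = generation_output[:2]  # tuple-style access

    # Iteration only yields the keys whose values are not None.
    print(list(generation_output.keys()))  # ['sequences', 'scores']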
@@ -124,6 +124,11 @@ class PretrainedConfig(object):
         - **num_return_sequences** (:obj:`int`, `optional`, defaults to 1) -- Number of independently computed returned
           sequences for each element in the batch that will be used by default in the :obj:`generate` method of the
           model.
+        - **output_scores** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should return the
+          logits when used for generation.
+        - **return_dict_in_generate** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether the model should
+          return a :class:`~transformers.file_utils.ModelOutput` instead of a :obj:`torch.LongTensor`.

     Parameters for fine-tuning tasks

@@ -203,6 +208,8 @@ class PretrainedConfig(object):
         self.bad_words_ids = kwargs.pop("bad_words_ids", None)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
         self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
+        self.output_scores = kwargs.pop("output_scores", False)
+        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)

         # Fine-tuning task arguments
         self.architectures = kwargs.pop("architectures", None)
@@ -343,6 +350,7 @@ class PretrainedConfig(object):
         Passing :obj:`use_auth_token=True` is required when you want to use a private model.

         Returns:
             :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.
@@ -372,6 +380,8 @@ class PretrainedConfig(object):
         From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
         :class:`~transformers.PretrainedConfig` using ``from_dict``.

         Parameters:
             pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                 The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
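Since the two new kwargs are popped in PretrainedConfig.__init__ (hunk above), they can also be set once on the config and then act as defaults for every generate() call; a minimal sketch:

    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config.from_pretrained("gpt2", output_scores=True, return_dict_in_generate=True)
    model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)

    print(model.config.output_scores)            # True
    print(model.config.return_dict_in_generate)  # True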
@@ -281,7 +281,7 @@ class BeamSearchScorer(BeamScorer):
         final_beam_indices: torch.LongTensor,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
-    ) -> torch.LongTensor:
+    ) -> Tuple[torch.LongTensor]:
         batch_size = len(self._beam_hyps)

         # finalize all open beam hypotheses and add to generated hypotheses
@@ -300,14 +300,20 @@ class BeamSearchScorer(BeamScorer):
         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
         best = []
+        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

         # retrieve best hypotheses
         for i, beam_hyp in enumerate(self._beam_hyps):
             sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
             for j in range(self.num_beam_hyps_to_keep):
-                best_hyp = sorted_hyps.pop()[1]
+                best_hyp_tuple = sorted_hyps.pop()
+                best_score = best_hyp_tuple[0]
+                best_hyp = best_hyp_tuple[1]
                 sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
+
+                # append to lists
                 best.append(best_hyp)
+                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

         # prepare for adding eos
         sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
@@ -322,7 +328,12 @@ class BeamSearchScorer(BeamScorer):
             decoded[i, : sent_lengths[i]] = hypo
             if sent_lengths[i] < self.max_length:
                 decoded[i, sent_lengths[i]] = eos_token_id
-        return decoded
+        return UserDict(
+            {
+                "sequences": decoded,
+                "sequence_scores": best_scores,
+            }
+        )


 class BeamHypotheses:
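With finalize() now returning the per-hypothesis scores alongside the decoded sequences, beam search through generate() can surface them; a sketch reusing the ``model`` and ``inputs`` from the first example, and assuming the ``sequences_scores`` attribute the commit message adds:

    outputs = model.generate(
        **inputs,
        num_beams=3,
        num_return_sequences=3,
        return_dict_in_generate=True,
        output_scores=True,
    )

    print(outputs.sequences.shape)         # (3, generated_length)
    print(outputs.sequences_scores.shape)  # (3,) -- log-probability scores, so never > 0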
File diff suppressed because it is too large.
@@ -136,7 +136,6 @@ class OpenAIGPTConfig(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
-        use_cache=True,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -159,7 +158,6 @@ class OpenAIGPTConfig(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
-        self.use_cache = use_cache

     @property
     def max_position_embeddings(self):
@@ -190,7 +190,7 @@ class BeamSearchTester:
             input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)

         # finalize
-        decoded = beam_scorer.finalize(
+        sequence_output = beam_scorer.finalize(
             input_ids,
             output_scores,
             output_tokens,
@@ -198,19 +198,27 @@ class BeamSearchTester:
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
         )
+
+        sequences = sequence_output["sequences"]
+        sequence_scores = sequence_output["sequence_scores"]
+
         # since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
-        self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length])
+        self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
+        self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
+
+        # check sequence_scores
+        self.parent.assertFalse((sequence_scores > 0).any().item())

         # first batch has to finish with eos_token
-        self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id)
+        self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)

         # other batches cannot finish with eos token
-        self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id)
-        self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id)
+        self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
+        self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)

         # now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
         beam_scorer.num_beam_hyps_to_keep = self.num_beams
-        decoded = beam_scorer.finalize(
+        sequence_output = beam_scorer.finalize(
             input_ids,
             output_scores,
             output_tokens,
@@ -218,7 +226,11 @@ class BeamSearchTester:
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
         )
-        self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length])
+        sequences = sequence_output["sequences"]
+        sequence_scores = sequence_output["sequence_scores"]
+
+        self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
+        self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])


 @require_torch
File diff suppressed because it is too large.
@@ -638,6 +638,69 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
             model = ReformerModelWithLMHead.from_pretrained(model_name)
             self.assertIsNotNone(model)

+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_attentions in enumerate(attentions):
+            tgt_len = min_length + idx if not use_cache else 1
+            num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0)
+            tgt_chunk_len = config.local_attn_chunk_length
+            src_chunk_len = config.local_attn_chunk_length * (
+                1 + config.local_num_chunks_after + config.local_num_chunks_before
+            )
+
+            if use_cache:
+                expected_shape = (
+                    batch_size * num_beam_groups,
+                    config.num_attention_heads,
+                    tgt_len,
+                    min_length // config.local_attn_chunk_length + 1 + idx,
+                )
+            else:
+                expected_shape = (
+                    batch_size * num_beam_groups,
+                    config.num_attention_heads,
+                    num_chunks,
+                    tgt_chunk_len,
+                    src_chunk_len,
+                )
+            # check attn size
+            self.assertListEqual(
+                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
+            )
+
+    def _check_hidden_states_for_generate(
+        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
+            [True] * len(hidden_states),
+        )
+        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_hidden_states in enumerate(hidden_states):
+            seq_len = min_length + idx
+            seq_len = config.local_attn_chunk_length * (
+                seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0)
+            )
+
+            if use_cache:
+                seq_len = 1
+
+            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
+            # check hidden size
+            self.assertListEqual(
+                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
+                [expected_shape] * len(iter_hidden_states),
+            )
+

 @require_torch
 class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -696,13 +759,77 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         self.model_tester = ReformerModelTester(self, **tester_kwargs)
         self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)

+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_attentions in enumerate(attentions):
+            tgt_len = min_length + idx if not use_cache else 1
+            num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0)
+            tgt_chunk_len = config.lsh_attn_chunk_length
+            src_chunk_len = config.lsh_attn_chunk_length * (
+                1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
+            )
+
+            if use_cache:
+                expected_shape = (
+                    batch_size * num_beam_groups,
+                    config.num_attention_heads,
+                    config.num_hashes,
+                    tgt_len,
+                    config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
+                )
+            else:
+                expected_shape = (
+                    batch_size * num_beam_groups,
+                    config.num_attention_heads,
+                    num_chunks * config.num_hashes,
+                    tgt_chunk_len,
+                    src_chunk_len,
+                )
+            # check attn size
+            self.assertListEqual(
+                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
+            )
+
+    def _check_hidden_states_for_generate(
+        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
+            [True] * len(hidden_states),
+        )
+        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_hidden_states in enumerate(hidden_states):
+            seq_len = min_length + idx if not use_cache else 1
+            seq_len = config.lsh_attn_chunk_length * (
+                seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0)
+            )
+
+            if use_cache:
+                seq_len = 1
+
+            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
+            # check hidden size
+            self.assertListEqual(
+                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
+                [expected_shape] * len(iter_hidden_states),
+            )
+

 @require_torch
 @require_sentencepiece
 @require_tokenizers
 class ReformerIntegrationTests(unittest.TestCase):
     """
     These integration tests test the current layer activations and gradients against the output of the Hugging Face Reformer model at the time of integration: 29/06/2020. During integration, the model was tested against the output of the official Trax ReformerLM model for various cases ("lsh" only, "local" only, masked / non-masked, different chunk length, ....). In order to recover the original trax integration tests, one should use patrickvonplaten's fork of trax and the code that lives on the branch `reformer_trax_tests`.
     """

     def _get_basic_config_and_input(self):
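The Reformer checks above assert the nested layout that generate() produces for these fields: one entry per decoding step, then one tensor per layer. A sketch of walking that structure, reusing the ``outputs`` from the first example (which passed ``output_attentions=True``):

    for step, step_attentions in enumerate(outputs.attentions):
        for layer, layer_attention in enumerate(step_attentions):
            print(f"step {step}, layer {layer}: {tuple(layer_attention.shape)}")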
@@ -304,6 +304,50 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         # transfo-xl requires special resize for lm-head
         return

+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_attentions in enumerate(attentions):
+            tgt_len = min_length if idx == 0 else (min_length - 2)
+            src_len = (min_length + config.mem_len) if idx == 0 else (min_length + config.mem_len - 2)
+
+            expected_shape = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                src_len,
+            )
+
+            # check attn size
+            self.assertListEqual(
+                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
+            )
+
+    def _check_hidden_states_for_generate(
+        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
+            [True] * len(hidden_states),
+        )
+        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_hidden_states in enumerate(hidden_states):
+            seq_len = min_length if idx == 0 else min_length - 2
+            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
+            # check hidden size
+            self.assertListEqual(
+                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
+                [expected_shape] * len(iter_hidden_states),
+            )
+

 @require_torch
 class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
@@ -400,6 +400,52 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)

+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_attentions in enumerate(attentions):
+            # adds PAD dummy token
+            tgt_len = min_length + idx + 1
+            src_len = min_length + idx + 1
+
+            expected_shape = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                src_len,
+            )
+            # check attn size
+            self.assertListEqual(
+                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
+            )
+
+    def _check_hidden_states_for_generate(
+        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
+            [True] * len(hidden_states),
+        )
+        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_hidden_states in enumerate(hidden_states):
+            # adds PAD dummy token
+            seq_len = min_length + idx + 1
+            expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
+            # check hidden size
+            self.assertListEqual(
+                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
+                [expected_shape] * len(iter_hidden_states),
+            )
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in XLM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -593,6 +593,60 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         # xlnet cannot keep gradients in attentions or hidden states
         return

+    def _check_hidden_states_for_generate(
+        self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(hidden_states, tuple)
+        self.assertListEqual(
+            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
+            [True] * len(hidden_states),
+        )
+        self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_hidden_states in enumerate(hidden_states):
+            # check hidden size
+            for i, layer_hidden_states in enumerate(iter_hidden_states):
+                # every 2nd tensor is from extra stream
+                if i % 2 != 0:
+                    seq_len = 1
+                else:
+                    # for first item dummy PAD token is appended so need one more
+                    seq_len = (min_length + 1) if idx == 0 else min_length
+
+                expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
+                self.assertEqual(layer_hidden_states.shape, expected_shape)
+
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+        for idx, attentions_item in enumerate(attentions):
+            for iter_attentions in attentions_item:
+                tgt_len = min_length
+
+                # for first item dummy PAD token is appended so need one more
+                if idx == 0:
+                    tgt_len += 1
+
+                src_len = min_length + idx + 1
+
+                expected_shape = (
+                    batch_size * num_beam_groups,
+                    config.num_attention_heads,
+                    tgt_len,
+                    src_len,
+                )
+                # check attn size
+                self.assertListEqual(
+                    [layer_attention.shape for layer_attention in iter_attentions],
+                    [expected_shape] * len(iter_attentions),
+                )
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: