..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Utilities for Generation
-----------------------------------------------------------------------------------------------------------------------


This page lists all the utility functions used by :meth:`~transformers.PreTrainedModel.generate`,
:meth:`~transformers.PreTrainedModel.greedy_search`, :meth:`~transformers.PreTrainedModel.sample`,
:meth:`~transformers.PreTrainedModel.beam_search`, :meth:`~transformers.PreTrainedModel.beam_sample`, and
:meth:`~transformers.PreTrainedModel.group_beam_search`.

Most of those are only useful if you are studying the code of the generate methods in the library.

Generate Outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


The output of :meth:`~transformers.PreTrainedModel.generate` is an instance of a subclass of
:class:`~transformers.file_utils.ModelOutput`. This output is a data structure containing all the information returned
by :meth:`~transformers.PreTrainedModel.generate`, but one that can also be used as a tuple or a dictionary.

Here's an example:

.. code-block::

    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
    generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)


The ``generation_output`` object is a :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`. As we can
see in the documentation of that class below, this means it has the following attributes:

- ``sequences``: the generated sequences of tokens
- ``scores`` (optional): the prediction scores of the language modeling head, for each generation step
- ``hidden_states`` (optional): the hidden states of the model, for each generation step
- ``attentions`` (optional): the attention weights of the model, for each generation step

Here we have the ``scores`` since we passed along ``output_scores=True``, but we don't have ``hidden_states`` and
``attentions`` because we didn't pass ``output_hidden_states=True`` or ``output_attentions=True``.

You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
will get ``None``. Here, for instance, ``generation_output.scores`` are all the generated prediction scores of the
language modeling head, and ``generation_output.attentions`` is ``None``.

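
For instance, we can check that ``scores`` was returned while ``attentions`` was not, and request the missing outputs
by passing the corresponding flags. This is a short sketch reusing the ``inputs`` and ``model`` defined above:

.. code-block::

    print(generation_output.scores is None)      # False: we passed output_scores=True
    print(generation_output.attentions is None)  # True: output_attentions was not set

    # Request every optional output; each attribute is only filled when its flag is set
    full_output = model.generate(
        **inputs,
        return_dict_in_generate=True,
        output_scores=True,
        output_attentions=True,
        output_hidden_states=True,
    )
    print(full_output.attentions is None)        # False this time
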

When using our ``generation_output`` object as a tuple, it only keeps the attributes that don't have ``None`` values.
Here, for instance, it has two elements, ``sequences`` then ``scores``, so

.. code-block::

    generation_output[:2]

will return the tuple ``(generation_output.sequences, generation_output.scores)``.

When using our ``generation_output`` object as a dictionary, it only keeps the attributes that don't have ``None``
values. Here, for instance, it has two keys that are ``sequences`` and ``scores``.

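
You can also look attributes up by key, as with a regular dictionary. The exact keys depend on which flags were passed
to :meth:`~transformers.PreTrainedModel.generate`; for the example above this gives:

.. code-block::

    print(list(generation_output.keys()))   # ['sequences', 'scores']
    scores = generation_output["scores"]    # same values as generation_output.scores
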

We document here all output types.


GreedySearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.GreedySearchDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
    :members:


SampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.SampleDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
    :members:


BeamSearchOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.BeamSearchDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.BeamSearchEncoderDecoderOutput
    :members:


BeamSampleOutput
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.generation_utils.BeamSampleDecoderOnlyOutput
    :members:

.. autoclass:: transformers.generation_utils.BeamSampleEncoderDecoderOutput
    :members:


LogitsProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

A :class:`~transformers.LogitsProcessor` can be used to modify the prediction scores of a language model head for
generation.

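
The sketch below shows how these classes compose; it is not how you would normally call them (inside
:meth:`~transformers.PreTrainedModel.generate` the processors are built from the generation flags and applied at each
decoding step). It reuses the ``tokenizer``, ``model`` and ``inputs`` from the example above, with a placeholder tensor
for the scores:

.. code-block::

    import torch
    from transformers import (
        LogitsProcessorList,
        MinLengthLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
    )

    # A LogitsProcessorList applies each processor to the next-token scores in turn
    processors = LogitsProcessorList([
        MinLengthLogitsProcessor(min_length=10, eos_token_id=tokenizer.eos_token_id),
        NoRepeatNGramLogitsProcessor(ngram_size=2),
    ])

    input_ids = inputs["input_ids"]                    # tokens generated so far
    scores = torch.randn(1, model.config.vocab_size)   # placeholder next-token logits

    filtered_scores = processors(input_ids, scores)    # scores with the constraints applied
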

.. autoclass:: transformers.LogitsProcessor
    :members: __call__

.. autoclass:: transformers.LogitsProcessorList
    :members: __call__

.. autoclass:: transformers.LogitsWarper
    :members: __call__

.. autoclass:: transformers.MinLengthLogitsProcessor
    :members: __call__

.. autoclass:: transformers.TemperatureLogitsWarper
    :members: __call__

.. autoclass:: transformers.RepetitionPenaltyLogitsProcessor
    :members: __call__

.. autoclass:: transformers.TopPLogitsWarper
    :members: __call__

.. autoclass:: transformers.TopKLogitsWarper
    :members: __call__

.. autoclass:: transformers.NoRepeatNGramLogitsProcessor
    :members: __call__

.. autoclass:: transformers.NoBadWordsLogitsProcessor
    :members: __call__

.. autoclass:: transformers.PrefixConstrainedLogitsProcessor
    :members: __call__

.. autoclass:: transformers.HammingDiversityLogitsProcessor
    :members: __call__


BeamSearch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BeamScorer
    :members: process, finalize

.. autoclass:: transformers.BeamSearchScorer
    :members: process, finalize


Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: transformers.top_k_top_p_filtering

.. autofunction:: transformers.tf_top_k_top_p_filtering