transformers/docs/source/model_doc/t5.rst
NielsRogge 4766e009b0
Improve T5 docs (#13240)
* Remove disclaimer

* First draft

* Fix rebase

* Improve docs some more

* Add inference section

* Improve example scripts section

* Improve code examples of modeling files

* Add docs regarding task prefix

* Address @craffel's comments

* Apply suggestions from @patrickvonplaten's review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add suggestions from code review

* Apply @sgugger's suggestions

* Fix Flax code examples

* Fix index.rst

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-09-01 15:05:40 +02:00

367 lines
19 KiB
ReStructuredText

..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
T5
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The T5 model was presented in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
<https://arxiv.org/pdf/1910.10683.pdf>`_ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
The abstract from the paper is the following:
*Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream
task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning
has given rise to a diversity of approaches, methodology, and practice. In this paper, we explore the landscape of
transfer learning techniques for NLP by introducing a unified framework that converts every language problem into a
text-to-text format. Our systematic study compares pretraining objectives, architectures, unlabeled datasets, transfer
approaches, and other factors on dozens of language understanding tasks. By combining the insights from our exploration
with scale and our new "Colossal Clean Crawled Corpus", we achieve state-of-the-art results on many benchmarks covering
summarization, question answering, text classification, and more. To facilitate future work on transfer learning for
NLP, we release our dataset, pre-trained models, and code.*
Tips:
- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
for summarization: *summarize: ...*.
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
- See the :ref:`training`, :ref:`inference` and :ref:`scripts` sections below for all details regarding usage.
T5 comes in different sizes:
- `t5-small <https://huggingface.co/t5-small>`__
- `t5-base <https://huggingface.co/t5-base>`__
- `t5-large <https://huggingface.co/t5-large>`__
- `t5-3b <https://huggingface.co/t5-3b>`__
- `t5-11b <https://huggingface.co/t5-11b>`__.
Based on the original T5 model, Google has released some follow-up works:
- **T5v1.1**: T5v1.1 is an improved version of T5 with some architectural tweaks, and is pre-trained on C4 only without
mixing in the supervised tasks. Refer to the documentation of T5v1.1 which can be found :doc:`here <t5v1.1>`.
- **mT5**: mT5 is a multilingual T5 model. It is pre-trained on the mC4 corpus, which includes 101 languages. Refer to
the documentation of mT5 which can be found :doc:`here <mt5>`.
- **byT5**: byT5 is a T5 model pre-trained on byte sequences rather than SentencePiece subword token sequences. Refer
to the documentation of byT5 which can be found :doc:`here <byt5>`.
All checkpoints can be found on the `hub <https://huggingface.co/models?search=t5>`__.
This model was contributed by `thomwolf <https://huggingface.co/thomwolf>`__. The original code can be found `here
<https://github.com/google-research/text-to-text-transfer-transformer>`__.
.. _training:
Training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher
forcing. This means that for training, we always need an input sequence and a corresponding target sequence. The input
sequence is fed to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a
start-sequence token and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target
sequence is then appended by the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the
start-sequence token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
One can use :class:`~transformers.T5ForConditionalGeneration` (or the Tensorflow/Flax variant), which includes the
language modeling head on top of the decoder.
- Unsupervised denoising training
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
sentinel token represents a unique mask token for this sentence and should start with :obj:`<extra_id_0>`,
:obj:`<extra_id_1>`, ... up to :obj:`<extra_id_99>`. As a default, 100 sentinel tokens are available in
:class:`~transformers.T5Tokenizer`.
For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
processed as follows:
.. code-block::
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
If you're interested in pre-training T5 on a new corpus, check out the `run_t5_mlm_flax.py
<https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling>`__ script in the Examples
directory.
- Supervised training
In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
Suppose that we want to fine-tune the model for translation for example, and we have a training example: the input
sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
the model as follows:
.. code-block::
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
As you can see, only 2 inputs are required for the model in order to compute a loss: :obj:`input_ids` (which are the
:obj:`input_ids` of the encoded input sequence) and :obj:`labels` (which are the :obj:`input_ids` of the encoded
target sequence). The model will automatically create the :obj:`decoder_input_ids` based on the :obj:`labels`, by
shifting them one position to the right and prepending the :obj:`config.decoder_start_token_id`, which for T5 is
equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
during T5's pre-training.
However, the example above only shows a single training example. In practice, one trains deep learning models in
batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
typically defines a :obj:`max_source_length` and :obj:`max_target_length`, which determine the maximum length of the
input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
the task.
In addition, we must make sure that padding token id's of the :obj:`labels` are not taken into account by the loss
function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the :obj:`ignore_index`
of the :obj:`CrossEntropyLoss`. In Flax, one can use the :obj:`decoder_attention_mask` to ignore padded tokens from
the loss (see the `Flax summarization script
<https://github.com/huggingface/transformers/tree/master/examples/flax/summarization>`__ for details). We also pass
:obj:`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
ignored. The code example below illustrates all of this.
.. code-block::
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
# the following 2 hyperparameters are task-specific
max_source_length = 512
max_target_length = 128
# Suppose we have the following 2 training examples:
input_sequence_1 = "Welcome to NYC"
output_sequence_1 = "Bienvenue à NYC"
input_sequence_2 = "HuggingFace is a company"
output_sequence_2 = "HuggingFace est une entreprise"
# encode the inputs
task_prefix = "translate English to French: "
input_sequences = [input_sequence_1, input_sequence_2]
encoding = tokenizer([task_prefix + sequence for sequence in input_sequences],
padding='longest',
max_length=max_source_length,
truncation=True,
return_tensors="pt")
input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
# encode the targets
target_encoding = tokenizer([output_sequence_1, output_sequence_2],
padding='longest',
max_length=max_target_length,
truncation=True)
labels = target_encoding.input_ids
# replace padding token id's of the labels by -100
labels = [
[(label if label != tokenizer.pad_token_id else -100) for label in labels_example] for labels_example in labels
]
labels = torch.tensor(labels)
# forward pass
loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
Additional training tips:
- T5 models need a slightly higher learning rate than the default one set in the :obj:`Trainer` when using the AdamW
optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
- According to `this forum post <https://discuss.huggingface.co/t/t5-finetuning-tips/684>`__, task prefixes matter when
(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
pre-training mixture (see Appendix D of the `paper <https://arxiv.org/pdf/1910.10683.pdf>`__ for the task prefixes
used).
- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
`pad_to_multiple_of` to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
encountered during training thus significantly slowing down the training. only padding up to the longest example in a
batch) leads to very slow training on TPU.
.. _inference:
Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
At inference time, it is recommended to use :meth:`~transformers.generation_utils.GenerationMixin.generate`. This
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
and auto-regressively generates the decoder output. Check out `this blog post
<https://huggingface.co/blog/how-to-generate>`__ to know all the details about generating text with Transformers.
There's also `this blog post <https://huggingface.co/blog/encoder-decoder#encoder-decoder>`__ which explains how
generation works in general in encoder-decoder models.
.. code-block::
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# Das Haus ist wunderbar.
Note that T5 uses the :obj:`pad_token_id` as the :obj:`decoder_start_token_id`, so when doing generation without using
:meth:`~transformers.generation_utils.GenerationMixin.generate`, make sure you start it with the :obj:`pad_token_id`.
The example above only shows a single example. You can also do batched inference, like so:
.. code-block::
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
# when generating, we will use the logits of right-most token to predict the next token
# so the padding should be on the left
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token # to avoid an error
task_prefix = 'translate English to German: '
sentences = ['The house is wonderful.', 'I like to work in NYC.'] # use different length sentences to test batching
inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)
output_sequences = model.generate(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
do_sample=False, # disable sampling to test if batching affects output
)
print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
# ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
.. _scripts:
Example scripts
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
T5 is supported by several example scripts, both for pre-training and fine-tuning.
* pre-training: the `run_t5_mlm_flax.py
<https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py>`__
script allows you to further pre-train T5 or pre-train T5 from scratch on your own data. The `t5_tokenizer_model.py
<https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/t5_tokenizer_model.py>`__
script allows you to further train a T5 tokenizer or train a T5 Tokenizer from scratch on your own data. Note that
Flax (a neural network library on top of JAX) is particularly useful to train on TPU hardware.
* fine-tuning: T5 is supported by the official summarization scripts (`PyTorch
<https://github.com/huggingface/transformers/tree/master/examples/pytorch/summarization>`__, `Tensorflow
<https://github.com/huggingface/transformers/tree/master/examples/tensorflow/summarization>`__, and `Flax
<https://github.com/huggingface/transformers/tree/master/examples/flax/summarization>`__) and translation scripts
(`PyTorch <https://github.com/huggingface/transformers/tree/master/examples/pytorch/translation>`__ and `Tensorflow
<https://github.com/huggingface/transformers/tree/master/examples/tensorflow/translation>`__). These scripts allow
you to easily fine-tune T5 on custom data for summarization/translation.
T5Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.T5Config
:members:
T5Tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.T5Tokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, save_vocabulary
T5TokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.T5TokenizerFast
:members:
T5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.T5Model
:members: forward, parallelize, deparallelize
T5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.T5ForConditionalGeneration
:members: forward, parallelize, deparallelize
T5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.T5EncoderModel
:members: forward, parallelize, deparallelize
TFT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFT5Model
:members: call
TFT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFT5ForConditionalGeneration
:members: call
TFT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFT5EncoderModel
:members: call
FlaxT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxT5Model
:members: __call__, encode, decode
FlaxT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxT5ForConditionalGeneration
:members: __call__, encode, decode