..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

T5
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The T5 model was presented in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
<https://arxiv.org/pdf/1910.10683.pdf>`_ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.

The abstract from the paper is the following:

*Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream
task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning
has given rise to a diversity of approaches, methodology, and practice. In this paper, we explore the landscape of
transfer learning techniques for NLP by introducing a unified framework that converts every language problem into a
text-to-text format. Our systematic study compares pretraining objectives, architectures, unlabeled datasets, transfer
approaches, and other factors on dozens of language understanding tasks. By combining the insights from our exploration
with scale and our new "Colossal Clean Crawled Corpus", we achieve state-of-the-art results on many benchmarks covering
summarization, question answering, text classification, and more. To facilitate future work on transfer learning for
NLP, we release our dataset, pre-trained models, and code.*

Tips:

- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
  each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
  different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
  for summarization: *summarize: ...* (a short example follows this list).

- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.

- See the :ref:`training`, :ref:`inference` and :ref:`scripts` sections below for all details regarding usage.

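As a quick illustration of the task-prefix idea, here is a minimal sketch using the public *t5-small* checkpoint (the
input text is a toy example, and the generated text depends on the checkpoint and generation settings):

.. code-block::

    from transformers import T5Tokenizer, T5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # the task prefix tells T5 which text-to-text task to perform
    text = "summarize: studies have shown that owning a dog is good for your health and happiness."
    input_ids = tokenizer(text, return_tensors="pt").input_ids

    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
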
T5 comes in different sizes:

- `t5-small <https://huggingface.co/t5-small>`__

- `t5-base <https://huggingface.co/t5-base>`__

- `t5-large <https://huggingface.co/t5-large>`__

- `t5-3b <https://huggingface.co/t5-3b>`__

- `t5-11b <https://huggingface.co/t5-11b>`__.

Based on the original T5 model, Google has released some follow-up works:

- **T5v1.1**: T5v1.1 is an improved version of T5 with some architectural tweaks, and is pre-trained on C4 only without
  mixing in the supervised tasks. Refer to the documentation of T5v1.1 which can be found :doc:`here <t5v1.1>`.

- **mT5**: mT5 is a multilingual T5 model. It is pre-trained on the mC4 corpus, which includes 101 languages. Refer to
  the documentation of mT5 which can be found :doc:`here <mt5>`.

- **byT5**: byT5 is a T5 model pre-trained on byte sequences rather than SentencePiece subword token sequences. Refer
  to the documentation of byT5 which can be found :doc:`here <byt5>`.

All checkpoints can be found on the `hub <https://huggingface.co/models?search=t5>`__.

This model was contributed by `thomwolf <https://huggingface.co/thomwolf>`__. The original code can be found `here
<https://github.com/google-research/text-to-text-transfer-transformer>`__.

.. _training:

Training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher
forcing. This means that for training, we always need an input sequence and a corresponding target sequence. The input
sequence is fed to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a
start-sequence token, and fed to the decoder using :obj:`decoder_input_ids`. In teacher-forcing style, the target
sequence is then appended with the EOS token and corresponds to the :obj:`labels`. The PAD token is used as the
start-sequence token. T5 can be trained/fine-tuned both in a supervised and an unsupervised fashion.

One can use :class:`~transformers.T5ForConditionalGeneration` (or the Tensorflow/Flax variant), which includes the
language modeling head on top of the decoder.

- Unsupervised denoising training

  In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a.* unique mask tokens) and
  the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
  sentinel token represents a unique mask token for this sentence and should start with :obj:`<extra_id_0>`,
  :obj:`<extra_id_1>`, ... up to :obj:`<extra_id_99>`. By default, 100 sentinel tokens are available in
  :class:`~transformers.T5Tokenizer`.

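  If you want to double-check which sentinel tokens a given tokenizer provides, a minimal sketch (assuming the public
  *t5-small* checkpoint) is to inspect its special tokens directly:

  .. code-block::

      from transformers import T5Tokenizer

      tokenizer = T5Tokenizer.from_pretrained("t5-small")

      # the sentinel tokens are registered as additional special tokens
      print(tokenizer.additional_special_tokens[:3])  # e.g. ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>']
      # their ids sit at the very end of the vocabulary
      print(tokenizer.convert_tokens_to_ids("<extra_id_0>"))
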
  For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
  processed as follows:

  .. code-block::

      from transformers import T5Tokenizer, T5ForConditionalGeneration

      tokenizer = T5Tokenizer.from_pretrained("t5-small")
      model = T5ForConditionalGeneration.from_pretrained("t5-small")

      input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
      labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
      # the forward function automatically creates the correct decoder_input_ids
      loss = model(input_ids=input_ids, labels=labels).loss

  If you're interested in pre-training T5 on a new corpus, check out the `run_t5_mlm_flax.py
  <https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling>`__ script in the Examples
  directory.

- Supervised training

  In this setup, the input sequence and output sequence form a standard sequence-to-sequence input-output mapping.
  Suppose, for example, that we want to fine-tune the model for translation and we have a training example with the
  input sequence "The house is wonderful." and the output sequence "Das Haus ist wunderbar.". The two sequences should
  be prepared for the model as follows:

  .. code-block::

      from transformers import T5Tokenizer, T5ForConditionalGeneration

      tokenizer = T5Tokenizer.from_pretrained("t5-small")
      model = T5ForConditionalGeneration.from_pretrained("t5-small")

      input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
      labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids
      # the forward function automatically creates the correct decoder_input_ids
      loss = model(input_ids=input_ids, labels=labels).loss

  As you can see, only 2 inputs are required for the model in order to compute a loss: :obj:`input_ids` (which are the
  :obj:`input_ids` of the encoded input sequence) and :obj:`labels` (which are the :obj:`input_ids` of the encoded
  target sequence). The model will automatically create the :obj:`decoder_input_ids` based on the :obj:`labels`, by
  shifting them one position to the right and prepending the :obj:`config.decoder_start_token_id`, which for T5 is
  equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
  English to German: ' before encoding it. This helps improve performance, as this task prefix was used during T5's
  pre-training.

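  If you are curious what the shifted sequence looks like, a minimal sketch is to call the PyTorch model's internal
  :obj:`_shift_right` helper yourself (note that this is a private helper and may change between versions):

  .. code-block::

      from transformers import T5Tokenizer, T5ForConditionalGeneration

      tokenizer = T5Tokenizer.from_pretrained("t5-small")
      model = T5ForConditionalGeneration.from_pretrained("t5-small")

      labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids
      # _shift_right prepends config.decoder_start_token_id (0, the pad token for T5)
      # and drops the final position, mirroring what the forward pass does with the labels
      decoder_input_ids = model._shift_right(labels)
      print(labels)
      print(decoder_input_ids)
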
  However, the example above only shows a single training example. In practice, one trains deep learning models in
  batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
  typically defines a :obj:`max_source_length` and :obj:`max_target_length`, which determine the maximum length of the
  input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
  the task.

  In addition, we must make sure that padding token ids of the :obj:`labels` are not taken into account by the loss
  function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the :obj:`ignore_index`
  of the :obj:`CrossEntropyLoss`. In Flax, one can use the :obj:`decoder_attention_mask` to ignore padded tokens from
  the loss (see the `Flax summarization script
  <https://github.com/huggingface/transformers/tree/master/examples/flax/summarization>`__ for details). We also pass
  :obj:`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
  ignored. The code example below illustrates all of this.

  .. code-block::

      from transformers import T5Tokenizer, T5ForConditionalGeneration
      import torch

      tokenizer = T5Tokenizer.from_pretrained("t5-small")
      model = T5ForConditionalGeneration.from_pretrained("t5-small")

      # the following 2 hyperparameters are task-specific
      max_source_length = 512
      max_target_length = 128

      # Suppose we have the following 2 training examples:
      input_sequence_1 = "Welcome to NYC"
      output_sequence_1 = "Bienvenue à NYC"

      input_sequence_2 = "HuggingFace is a company"
      output_sequence_2 = "HuggingFace est une entreprise"

      # encode the inputs
      task_prefix = "translate English to French: "
      input_sequences = [input_sequence_1, input_sequence_2]
      encoding = tokenizer([task_prefix + sequence for sequence in input_sequences],
                           padding='longest',
                           max_length=max_source_length,
                           truncation=True,
                           return_tensors="pt")
      input_ids, attention_mask = encoding.input_ids, encoding.attention_mask

      # encode the targets
      target_encoding = tokenizer([output_sequence_1, output_sequence_2],
                                  padding='longest',
                                  max_length=max_target_length,
                                  truncation=True)
      labels = target_encoding.input_ids

      # replace padding token ids of the labels by -100
      labels = [
          [(label if label != tokenizer.pad_token_id else -100) for label in labels_example] for labels_example in labels
      ]
      labels = torch.tensor(labels)

      # forward pass
      loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss

Additional training tips:

- T5 models need a slightly higher learning rate than the default one set in the :obj:`Trainer` when using the AdamW
  optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
  answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer (a short sketch follows
  this list).

- According to `this forum post <https://discuss.huggingface.co/t/t5-finetuning-tips/684>`__, task prefixes matter when
  (1) doing multi-task training or (2) your task is similar or related to one of the supervised tasks used in T5's
  pre-training mixture (see Appendix D of the `paper <https://arxiv.org/pdf/1910.10683.pdf>`__ for the task prefixes
  used).

- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
  `pad_to_multiple_of` to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
  batches to the longest example is not recommended on TPU, as it triggers a recompilation for every batch shape
  encountered during training and thus significantly slows down training.

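As a rough sketch of the optimizer tip above (the output directory and learning rates below are illustrative
assumptions, not tuned values), you can either raise the learning rate of the default AdamW optimizer via
:class:`~transformers.TrainingArguments`, or use the :obj:`Adafactor` implementation that ships with the library:

.. code-block::

    from transformers import Adafactor, T5ForConditionalGeneration, TrainingArguments

    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # option 1: AdamW via the Trainer, with a higher-than-default learning rate
    training_args = TrainingArguments(output_dir="t5-finetuned", learning_rate=3e-4)

    # option 2: Adafactor, the optimizer T5 was pre-trained with
    # (a commonly used fixed-learning-rate configuration; tune it for your task)
    optimizer = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False,
                          relative_step=False, warmup_init=False)
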
.. _inference:

Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

At inference time, it is recommended to use :meth:`~transformers.generation_utils.GenerationMixin.generate`. This
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
and auto-regressively generates the decoder output. Check out `this blog post
<https://huggingface.co/blog/how-to-generate>`__ to learn all the details about generating text with Transformers.
There's also `this blog post <https://huggingface.co/blog/encoder-decoder#encoder-decoder>`__ which explains how
generation works in general in encoder-decoder models.

.. code-block::

    from transformers import T5Tokenizer, T5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    # Das Haus ist wunderbar.

Note that T5 uses the :obj:`pad_token_id` as the :obj:`decoder_start_token_id`, so when doing generation without using
:meth:`~transformers.generation_utils.GenerationMixin.generate`, make sure you start it with the :obj:`pad_token_id`.

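For example, here is a minimal sketch of a single manual greedy decoding step without :obj:`generate`, starting the
decoder with the pad token:

.. code-block::

    import torch
    from transformers import T5Tokenizer, T5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids

    # start the decoder with the pad token, which T5 uses as decoder_start_token_id
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    # one greedy step: pick the most likely first output token
    logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
    next_token_id = logits[:, -1].argmax(-1).item()
    print(tokenizer.decode([next_token_id]))
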
The example above only shows a single example. You can also do batched inference, like so:

.. code-block::

    from transformers import T5Tokenizer, T5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    task_prefix = 'translate English to German: '
    sentences = ['The house is wonderful.', 'I like to work in NYC.']  # use different length sentences to test batching

    # T5's tokenizer pads encoder inputs on the right with T5's own pad token;
    # the attention_mask makes sure the padding is ignored during generation
    inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)

    output_sequences = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        do_sample=False,  # disable sampling to test if batching affects output
    )

    print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
    # ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']

.. _scripts:

Example scripts
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

T5 is supported by several example scripts, both for pre-training and fine-tuning.

* pre-training: the `run_t5_mlm_flax.py
  <https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py>`__
  script allows you to further pre-train T5 or pre-train T5 from scratch on your own data. The `t5_tokenizer_model.py
  <https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/t5_tokenizer_model.py>`__
  script allows you to further train a T5 tokenizer or train a T5 tokenizer from scratch on your own data. Note that
  Flax (a neural network library on top of JAX) is particularly useful to train on TPU hardware.

* fine-tuning: T5 is supported by the official summarization scripts (`PyTorch
  <https://github.com/huggingface/transformers/tree/master/examples/pytorch/summarization>`__, `Tensorflow
  <https://github.com/huggingface/transformers/tree/master/examples/tensorflow/summarization>`__, and `Flax
  <https://github.com/huggingface/transformers/tree/master/examples/flax/summarization>`__) and translation scripts
  (`PyTorch <https://github.com/huggingface/transformers/tree/master/examples/pytorch/translation>`__ and `Tensorflow
  <https://github.com/huggingface/transformers/tree/master/examples/tensorflow/translation>`__). These scripts allow
  you to easily fine-tune T5 on custom data for summarization/translation.

T5Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5Config
    :members:


T5Tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5Tokenizer
    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
        create_token_type_ids_from_sequences, save_vocabulary


T5TokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5TokenizerFast
    :members:


T5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5Model
    :members: forward, parallelize, deparallelize


T5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5ForConditionalGeneration
    :members: forward, parallelize, deparallelize


T5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5EncoderModel
    :members: forward, parallelize, deparallelize


TFT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFT5Model
    :members: call


TFT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFT5ForConditionalGeneration
    :members: call


TFT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFT5EncoderModel
    :members: call


FlaxT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxT5Model
    :members: __call__, encode, decode


FlaxT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxT5ForConditionalGeneration
    :members: __call__, encode, decode