mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Improve T5 docs (#13240)
* Remove disclaimer * First draft * Fix rebase * Improve docs some more * Add inference section * Improve example scripts section * Improve code examples of modeling files * Add docs regarding task prefix * Address @craffel's comments * Apply suggestions from @patrickvonplaten's review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Add suggestions from code review * Apply @sgugger's suggestions * Fix Flax code examples * Fix index.rst Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
ba1b3db709
commit
4766e009b0
@ -272,6 +272,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[Splinter](https://huggingface.co/transformers/model_doc/splinter.html)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy.
|
||||
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
|
||||
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
1. **[T5v1.1](https://huggingface.co/transformers/model_doc/t5v1.1.html)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
|
@ -282,35 +282,40 @@ Supported models
|
||||
62. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
||||
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
63. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
63. :doc:`T5v1.1 <model_doc/t5v1.1>` (from Google AI) released in the repository
|
||||
`google-research/text-to-text-transfer-transformer
|
||||
<https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511>`__ by
|
||||
Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi
|
||||
Zhou and Wei Li and Peter J. Liu.
|
||||
64. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
||||
Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
64. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
65. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
65. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
|
||||
66. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
|
||||
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
|
||||
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
|
||||
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
|
||||
66. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
|
||||
67. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
|
||||
Performant Baseline for Vision and Language <https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark
|
||||
Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
|
||||
67. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
68. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
|
||||
Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
68. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
69. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||
69. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
70. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
70. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
71. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||
Zettlemoyer and Veselin Stoyanov.
|
||||
71. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
72. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
72. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
73. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
|
||||
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
|
||||
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
|
||||
|
||||
@ -608,6 +613,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
model_doc/splinter
|
||||
model_doc/squeezebert
|
||||
model_doc/t5
|
||||
model_doc/t5v1.1
|
||||
model_doc/tapas
|
||||
model_doc/transformerxl
|
||||
model_doc/vit
|
||||
|
@ -39,8 +39,11 @@ experiments.*
|
||||
This model was contributed by `patrickvonplaten <https://huggingface.co/patrickvonplaten>`__. The original code can be
|
||||
found `here <https://github.com/google-research/byt5>`__.
|
||||
|
||||
ByT5's architecture is based on the T5v1.1 model, so one can refer to :doc:`T5v1.1's documentation page <t5v1.1>`. They
|
||||
only differ in how inputs should be prepared for the model, see the code examples below.
|
||||
|
||||
ByT5's architecture is based on the T5 model, so one can refer to :doc:`T5's documentation page <t5>`.
|
||||
Since ByT5 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task
|
||||
fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
|
||||
|
||||
|
||||
Example
|
||||
|
@ -28,6 +28,23 @@ multilingual variant of T5 that was pre-trained on a new Common Crawl-based data
|
||||
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
|
||||
benchmarks. All of the code and model checkpoints*
|
||||
|
||||
Note: mT5 was only pre-trained on `mC4 <https://huggingface.co/datasets/mc4>`__ excluding any supervised training.
|
||||
Therefore, this model has to be fine-tuned before it is useable on a downstream task, unlike the original T5 model.
|
||||
Since mT5 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task
|
||||
fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
|
||||
|
||||
Google has released the following variants:
|
||||
|
||||
- `google/mt5-small <https://huggingface.co/google/mt5-small>`__
|
||||
|
||||
- `google/mt5-base <https://huggingface.co/google/mt5-base>`__
|
||||
|
||||
- `google/mt5-large <https://huggingface.co/google/mt5-large>`__
|
||||
|
||||
- `google/mt5-xl <https://huggingface.co/google/mt5-xl>`__
|
||||
|
||||
- `google/mt5-xxl <https://huggingface.co/google/mt5-xxl>`__.
|
||||
|
||||
This model was contributed by `patrickvonplaten <https://huggingface.co/patrickvonplaten>`__. The original code can be
|
||||
found `here <https://github.com/google-research/multilingual-t5>`__.
|
||||
|
||||
|
@ -13,9 +13,6 @@
|
||||
T5
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
**DISCLAIMER:** This model is still a work in progress, if you see something strange, file a `Github Issue
|
||||
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__.
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@ -42,28 +39,56 @@ Tips:
|
||||
different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
|
||||
for summarization: *summarize: ...*.
|
||||
|
||||
For more information about which prefix to use, it is easiest to look into Appendix D of the `paper
|
||||
<https://arxiv.org/pdf/1910.10683.pdf>`__. - For sequence-to-sequence generation, it is recommended to use
|
||||
:meth:`~transformers.generation_utils.GenerationMixin.generate`. This method takes care of feeding the encoded input
|
||||
via cross-attention layers to the decoder and auto-regressively generates the decoder output. - T5 uses relative
|
||||
scalar embeddings. Encoder input padding can be done on the left and on the right.
|
||||
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
|
||||
|
||||
- See the :ref:`training`, :ref:`inference` and :ref:`scripts` sections below for all details regarding usage.
|
||||
|
||||
T5 comes in different sizes:
|
||||
|
||||
- `t5-small <https://huggingface.co/t5-small>`__
|
||||
|
||||
- `t5-base <https://huggingface.co/t5-base>`__
|
||||
|
||||
- `t5-large <https://huggingface.co/t5-large>`__
|
||||
|
||||
- `t5-3b <https://huggingface.co/t5-3b>`__
|
||||
|
||||
- `t5-11b <https://huggingface.co/t5-11b>`__.
|
||||
|
||||
Based on the original T5 model, Google has released some follow-up works:
|
||||
|
||||
- **T5v1.1**: T5v1.1 is an improved version of T5 with some architectural tweaks, and is pre-trained on C4 only without
|
||||
mixing in the supervised tasks. Refer to the documentation of T5v1.1 which can be found :doc:`here <t5v1.1>`.
|
||||
|
||||
- **mT5**: mT5 is a multilingual T5 model. It is pre-trained on the mC4 corpus, which includes 101 languages. Refer to
|
||||
the documentation of mT5 which can be found :doc:`here <mt5>`.
|
||||
|
||||
- **byT5**: byT5 is a T5 model pre-trained on byte sequences rather than SentencePiece subword token sequences. Refer
|
||||
to the documentation of byT5 which can be found :doc:`here <byt5>`.
|
||||
|
||||
All checkpoints can be found on the `hub <https://huggingface.co/models?search=t5>`__.
|
||||
|
||||
This model was contributed by `thomwolf <https://huggingface.co/thomwolf>`__. The original code can be found `here
|
||||
<https://github.com/google-research/text-to-text-transfer-transformer>`__.
|
||||
|
||||
.. _training:
|
||||
|
||||
Training
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher
|
||||
forcing. This means that for training we always need an input sequence and a target sequence. The input sequence is fed
|
||||
to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a start-sequence
|
||||
token and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target sequence is then
|
||||
appended by the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the start-sequence
|
||||
token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
|
||||
forcing. This means that for training, we always need an input sequence and a corresponding target sequence. The input
|
||||
sequence is fed to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a
|
||||
start-sequence token and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target
|
||||
sequence is then appended by the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the
|
||||
start-sequence token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
|
||||
|
||||
One can use :class:`~transformers.T5ForConditionalGeneration` (or the Tensorflow/Flax variant), which includes the
|
||||
language modeling head on top of the decoder.
|
||||
|
||||
- Unsupervised denoising training
|
||||
|
||||
In this setup spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
|
||||
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
|
||||
the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
|
||||
sentinel token represents a unique mask token for this sentence and should start with :obj:`<extra_id_0>`,
|
||||
:obj:`<extra_id_1>`, ... up to :obj:`<extra_id_99>`. As a default, 100 sentinel tokens are available in
|
||||
@ -72,34 +97,201 @@ token. T5 can be trained / fine-tuned both in a supervised and unsupervised fash
|
||||
For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
|
||||
processed as follows:
|
||||
|
||||
.. code-block::
|
||||
.. code-block::
|
||||
|
||||
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
|
||||
labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
|
||||
# the forward function automatically creates the correct decoder_input_ids
|
||||
loss = model(input_ids=input_ids, labels=labels).loss
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
|
||||
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
|
||||
labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
|
||||
# the forward function automatically creates the correct decoder_input_ids
|
||||
loss = model(input_ids=input_ids, labels=labels).loss
|
||||
|
||||
If you're interested in pre-training T5 on a new corpus, check out the `run_t5_mlm_flax.py
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling>`__ script in the Examples
|
||||
directory.
|
||||
|
||||
- Supervised training
|
||||
|
||||
In this setup the input sequence and output sequence are standard sequence-to-sequence input output mapping. In
|
||||
translation, for instance with the input sequence "The house is wonderful." and output sequence "Das Haus ist
|
||||
wunderbar.", the sentences should be processed as follows:
|
||||
In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
|
||||
Suppose that we want to fine-tune the model for translation for example, and we have a training example: the input
|
||||
sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
|
||||
the model as follows:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
|
||||
input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
|
||||
labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids
|
||||
# the forward function automatically creates the correct decoder_input_ids
|
||||
loss = model(input_ids=input_ids, labels=labels).loss
|
||||
|
||||
As you can see, only 2 inputs are required for the model in order to compute a loss: :obj:`input_ids` (which are the
|
||||
:obj:`input_ids` of the encoded input sequence) and :obj:`labels` (which are the :obj:`input_ids` of the encoded
|
||||
target sequence). The model will automatically create the :obj:`decoder_input_ids` based on the :obj:`labels`, by
|
||||
shifting them one position to the right and prepending the :obj:`config.decoder_start_token_id`, which for T5 is
|
||||
equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
|
||||
English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
|
||||
during T5's pre-training.
|
||||
|
||||
However, the example above only shows a single training example. In practice, one trains deep learning models in
|
||||
batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
|
||||
typically defines a :obj:`max_source_length` and :obj:`max_target_length`, which determine the maximum length of the
|
||||
input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
|
||||
the task.
|
||||
|
||||
In addition, we must make sure that padding token id's of the :obj:`labels` are not taken into account by the loss
|
||||
function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the :obj:`ignore_index`
|
||||
of the :obj:`CrossEntropyLoss`. In Flax, one can use the :obj:`decoder_attention_mask` to ignore padded tokens from
|
||||
the loss (see the `Flax summarization script
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/flax/summarization>`__ for details). We also pass
|
||||
:obj:`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
|
||||
ignored. The code example below illustrates all of this.
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
import torch
|
||||
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
|
||||
# the following 2 hyperparameters are task-specific
|
||||
max_source_length = 512
|
||||
max_target_length = 128
|
||||
|
||||
# Suppose we have the following 2 training examples:
|
||||
input_sequence_1 = "Welcome to NYC"
|
||||
output_sequence_1 = "Bienvenue à NYC"
|
||||
|
||||
input_sequence_2 = "HuggingFace is a company"
|
||||
output_sequence_2 = "HuggingFace est une entreprise"
|
||||
|
||||
# encode the inputs
|
||||
task_prefix = "translate English to French: "
|
||||
input_sequences = [input_sequence_1, input_sequence_2]
|
||||
encoding = tokenizer([task_prefix + sequence for sequence in input_sequences],
|
||||
padding='longest',
|
||||
max_length=max_source_length,
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
|
||||
|
||||
# encode the targets
|
||||
target_encoding = tokenizer([output_sequence_1, output_sequence_2],
|
||||
padding='longest',
|
||||
max_length=max_target_length,
|
||||
truncation=True)
|
||||
labels = target_encoding.input_ids
|
||||
|
||||
# replace padding token id's of the labels by -100
|
||||
labels = [
|
||||
[(label if label != tokenizer.pad_token_id else -100) for label in labels_example] for labels_example in labels
|
||||
]
|
||||
labels = torch.tensor(labels)
|
||||
|
||||
# forward pass
|
||||
loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
|
||||
|
||||
Additional training tips:
|
||||
|
||||
- T5 models need a slightly higher learning rate than the default one set in the :obj:`Trainer` when using the AdamW
|
||||
optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
|
||||
answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
|
||||
|
||||
- According to `this forum post <https://discuss.huggingface.co/t/t5-finetuning-tips/684>`__, task prefixes matter when
|
||||
(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
|
||||
pre-training mixture (see Appendix D of the `paper <https://arxiv.org/pdf/1910.10683.pdf>`__ for the task prefixes
|
||||
used).
|
||||
|
||||
- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
|
||||
`pad_to_multiple_of` to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
|
||||
batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
|
||||
encountered during training thus significantly slowing down the training. only padding up to the longest example in a
|
||||
batch) leads to very slow training on TPU.
|
||||
|
||||
.. _inference:
|
||||
|
||||
Inference
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
At inference time, it is recommended to use :meth:`~transformers.generation_utils.GenerationMixin.generate`. This
|
||||
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
|
||||
and auto-regressively generates the decoder output. Check out `this blog post
|
||||
<https://huggingface.co/blog/how-to-generate>`__ to know all the details about generating text with Transformers.
|
||||
There's also `this blog post <https://huggingface.co/blog/encoder-decoder#encoder-decoder>`__ which explains how
|
||||
generation works in general in encoder-decoder models.
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import T5ForConditionalGeneration, T5Tokenizer
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
|
||||
labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids
|
||||
# the forward function automatically creates the correct decoder_input_ids
|
||||
loss = model(input_ids=input_ids, labels=labels).loss
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
|
||||
input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
|
||||
outputs = model.generate(input_ids)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
# Das Haus ist wunderbar.
|
||||
|
||||
Note that T5 uses the :obj:`pad_token_id` as the :obj:`decoder_start_token_id`, so when doing generation without using
|
||||
:meth:`~transformers.generation_utils.GenerationMixin.generate`, make sure you start it with the :obj:`pad_token_id`.
|
||||
|
||||
The example above only shows a single example. You can also do batched inference, like so:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
|
||||
# when generating, we will use the logits of right-most token to predict the next token
|
||||
# so the padding should be on the left
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token = tokenizer.eos_token # to avoid an error
|
||||
|
||||
task_prefix = 'translate English to German: '
|
||||
sentences = ['The house is wonderful.', 'I like to work in NYC.'] # use different length sentences to test batching
|
||||
inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)
|
||||
|
||||
output_sequences = model.generate(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
do_sample=False, # disable sampling to test if batching affects output
|
||||
)
|
||||
|
||||
print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
|
||||
|
||||
# ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
|
||||
|
||||
.. _scripts:
|
||||
|
||||
Example scripts
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
T5 is supported by several example scripts, both for pre-training and fine-tuning.
|
||||
|
||||
* pre-training: the `run_t5_mlm_flax.py
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py>`__
|
||||
script allows you to further pre-train T5 or pre-train T5 from scratch on your own data. The `t5_tokenizer_model.py
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/t5_tokenizer_model.py>`__
|
||||
script allows you to further train a T5 tokenizer or train a T5 Tokenizer from scratch on your own data. Note that
|
||||
Flax (a neural network library on top of JAX) is particularly useful to train on TPU hardware.
|
||||
|
||||
* fine-tuning: T5 is supported by the official summarization scripts (`PyTorch
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/pytorch/summarization>`__, `Tensorflow
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/tensorflow/summarization>`__, and `Flax
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/flax/summarization>`__) and translation scripts
|
||||
(`PyTorch <https://github.com/huggingface/transformers/tree/master/examples/pytorch/translation>`__ and `Tensorflow
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/tensorflow/translation>`__). These scripts allow
|
||||
you to easily fine-tune T5 on custom data for summarization/translation.
|
||||
|
||||
T5Config
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
66
docs/source/model_doc/t5v1.1.rst
Normal file
66
docs/source/model_doc/t5v1.1.rst
Normal file
@ -0,0 +1,66 @@
|
||||
..
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
T5v1.1
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
T5v1.1 was released in the `google-research/text-to-text-transfer-transformer
|
||||
<https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511>`__
|
||||
repository by Colin Raffel et al. It's an improved version of the original T5 model.
|
||||
|
||||
One can directly plug in the weights of T5v1.1 into a T5 model, like so:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-base')
|
||||
|
||||
T5 Version 1.1 includes the following improvements compared to the original T5 model:
|
||||
|
||||
- GEGLU activation in the feed-forward hidden layer, rather than ReLU. See `this paper
|
||||
<https://arxiv.org/abs/2002.05202>`__.
|
||||
|
||||
- Dropout was turned off in pre-training (quality win). Dropout should be re-enabled during fine-tuning.
|
||||
|
||||
- Pre-trained on C4 only without mixing in the downstream tasks.
|
||||
|
||||
- No parameter sharing between the embedding and classifier layer.
|
||||
|
||||
- "xl" and "xxl" replace "3B" and "11B". The model shapes are a bit different - larger :obj:`d_model` and smaller
|
||||
:obj:`num_heads` and :obj:`d_ff`.
|
||||
|
||||
Note: T5 Version 1.1 was only pre-trained on `C4 <https://huggingface.co/datasets/c4>`__ excluding any supervised
|
||||
training. Therefore, this model has to be fine-tuned before it is useable on a downstream task, unlike the original T5
|
||||
model. Since t5v1.1 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task
|
||||
fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
|
||||
|
||||
Google has released the following variants:
|
||||
|
||||
- `google/t5-v1_1-small <https://huggingface.co/google/t5-v1_1-small>`__
|
||||
|
||||
- `google/t5-v1_1-base <https://huggingface.co/google/t5-v1_1-base>`__
|
||||
|
||||
- `google/t5-v1_1-large <https://huggingface.co/google/t5-v1_1-large>`__
|
||||
|
||||
- `google/t5-v1_1-xl <https://huggingface.co/google/t5-v1_1-xl>`__
|
||||
|
||||
- `google/t5-v1_1-xxl <https://huggingface.co/google/t5-v1_1-xxl>`__.
|
||||
|
||||
One can refer to :doc:`T5's documentation page <t5>` for all tips, code examples and notebooks.
|
||||
|
||||
This model was contributed by `patrickvonplaten <https://huggingface.co/patrickvonplaten>`__. The original code can be
|
||||
found `here
|
||||
<https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511>`__.
|
@ -1059,11 +1059,11 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
|
||||
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
|
||||
>>> text = "My friends are cool but they eat too many carbs."
|
||||
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
|
||||
>>> inputs = tokenizer(text, return_tensors='np')
|
||||
>>> encoder_outputs = model.encode(**inputs)
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@ -1118,19 +1118,20 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
||||
Example::
|
||||
|
||||
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
|
||||
>>> import jax.numpy as jnp
|
||||
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
|
||||
>>> text = "My friends are cool but they eat too many carbs."
|
||||
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
|
||||
>>> inputs = tokenizer(text, return_tensors='np')
|
||||
>>> encoder_outputs = model.encode(**inputs)
|
||||
|
||||
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
||||
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
||||
|
||||
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||
>>> last_decoder_hidden_states = outputs.last_hidden_state
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1329,8 +1330,9 @@ FLAX_T5_MODEL_DOCSTRING = """
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids
|
||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> # forward pass
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
|
||||
@ -1469,19 +1471,20 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
|
||||
Example::
|
||||
|
||||
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
|
||||
>>> import jax.numpy as jnp
|
||||
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
|
||||
>>> text = "summarize: My friends are cool but they eat too many carbs."
|
||||
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
|
||||
>>> inputs = tokenizer(text, return_tensors='np')
|
||||
>>> encoder_outputs = model.encode(**inputs)
|
||||
|
||||
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
||||
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
||||
|
||||
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||
>>> last_decoder_hidden_states = outputs.last_hidden_state
|
||||
>>> logits = outputs.logits
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1617,15 +1620,15 @@ FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """
|
||||
|
||||
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
|
||||
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
|
||||
>>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
|
||||
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=512, return_tensors='jax')
|
||||
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors='np')
|
||||
|
||||
>>> # Generate Summary
|
||||
>>> summary_ids = model.generate(inputs['input_ids']).sequences
|
||||
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
|
||||
>>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
|
||||
"""
|
||||
|
||||
|
||||
|
@ -1234,7 +1234,7 @@ num_heads)`.
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
|
||||
"The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
T5_START_DOCSTRING,
|
||||
)
|
||||
class T5Model(T5PreTrainedModel):
|
||||
@ -1344,8 +1344,9 @@ class T5Model(T5PreTrainedModel):
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
|
||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> # forward pass
|
||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
@ -1537,14 +1538,18 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
|
||||
>>> # training
|
||||
>>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
|
||||
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
|
||||
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
|
||||
>>> outputs = model(input_ids=input_ids, labels=labels)
|
||||
>>> loss = outputs.loss
|
||||
>>> logits = outputs.logits
|
||||
|
||||
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1
|
||||
>>> # inference
|
||||
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
|
||||
>>> outputs = model.generate(input_ids)
|
||||
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
>>> # studies have shown that owning a dog is good for you.
|
||||
"""
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
@ -1133,8 +1133,10 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
|
||||
|
||||
>>> # forward pass
|
||||
>>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
|
||||
"""
|
||||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
||||
@ -1321,15 +1323,18 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||
|
||||
>>> # training
|
||||
>>> inputs = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='tf').input_ids
|
||||
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='tf').input_ids
|
||||
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='tf').input_ids
|
||||
>>> outputs = model(inputs, labels=labels)
|
||||
>>> loss = outputs.loss
|
||||
>>> logits = outputs.logits
|
||||
|
||||
>>> inputs = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="tf").input_ids # Batch size 1
|
||||
|
||||
>>> result = model.generate(inputs)
|
||||
>>> # inference
|
||||
>>> inputs = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> outputs = model.generate(inputs)
|
||||
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
>>> # studies have shown that owning a dog is good for you
|
||||
|
||||
"""
|
||||
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
||||
@ -1571,7 +1576,7 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import T5Tokenizer, TFT5Model
|
||||
>>> from transformers import T5Tokenizer, TFT5EncoderModel
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||
>>> model = TFT5EncoderModel.from_pretrained('t5-small')
|
||||
@ -1579,7 +1584,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
|
||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
|
||||
>>> outputs = model(input_ids)
|
||||
|
||||
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
|
Loading…
Reference in New Issue
Block a user