Improve T5 docs (#13240)

* Remove disclaimer

* First draft

* Fix rebase

* Improve docs some more

* Add inference section

* Improve example scripts section

* Improve code examples of modeling files

* Add docs regarding task prefix

* Address @craffel's comments

* Apply suggestions from @patrickvonplaten's review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add suggestions from code review

* Apply @sgugger's suggestions

* Fix Flax code examples

* Fix index.rst

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
NielsRogge 2021-09-01 15:05:40 +02:00 committed by GitHub
parent ba1b3db709
commit 4766e009b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 363 additions and 66 deletions

View File

@ -272,6 +272,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[Splinter](https://huggingface.co/transformers/model_doc/splinter.html)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy.
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[T5v1.1](https://huggingface.co/transformers/model_doc/t5v1.1.html)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.

View File

@ -282,35 +282,40 @@ Supported models
62. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
63. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
63. :doc:`T5v1.1 <model_doc/t5v1.1>` (from Google AI) released in the repository
`google-research/text-to-text-transfer-transformer
<https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511>`__ by
Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi
Zhou and Wei Li and Peter J. Liu.
64. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
Francesco Piccinno and Julian Martin Eisenschlos.
64. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
65. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
65. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
66. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
66. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
67. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
Performant Baseline for Vision and Language <https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark
Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
67. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
68. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
Zhou, Abdelrahman Mohamed, Michael Auli.
68. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
69. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
69. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
70. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
70. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
71. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
Zettlemoyer and Veselin Stoyanov.
71. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
72. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
72. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
73. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
@ -608,6 +613,7 @@ Flax), PyTorch, and/or TensorFlow.
model_doc/splinter
model_doc/squeezebert
model_doc/t5
model_doc/t5v1.1
model_doc/tapas
model_doc/transformerxl
model_doc/vit

View File

@ -39,8 +39,11 @@ experiments.*
This model was contributed by `patrickvonplaten <https://huggingface.co/patrickvonplaten>`__. The original code can be
found `here <https://github.com/google-research/byt5>`__.
ByT5's architecture is based on the T5v1.1 model, so one can refer to :doc:`T5v1.1's documentation page <t5v1.1>`. They
only differ in how inputs should be prepared for the model, see the code examples below.
ByT5's architecture is based on the T5 model, so one can refer to :doc:`T5's documentation page <t5>`.
Since ByT5 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task
fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
Example

View File

@ -28,6 +28,23 @@ multilingual variant of T5 that was pre-trained on a new Common Crawl-based data
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
benchmarks. All of the code and model checkpoints*
Note: mT5 was only pre-trained on `mC4 <https://huggingface.co/datasets/mc4>`__ excluding any supervised training.
Therefore, this model has to be fine-tuned before it is useable on a downstream task, unlike the original T5 model.
Since mT5 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task
fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
Google has released the following variants:
- `google/mt5-small <https://huggingface.co/google/mt5-small>`__
- `google/mt5-base <https://huggingface.co/google/mt5-base>`__
- `google/mt5-large <https://huggingface.co/google/mt5-large>`__
- `google/mt5-xl <https://huggingface.co/google/mt5-xl>`__
- `google/mt5-xxl <https://huggingface.co/google/mt5-xxl>`__.
This model was contributed by `patrickvonplaten <https://huggingface.co/patrickvonplaten>`__. The original code can be
found `here <https://github.com/google-research/multilingual-t5>`__.

View File

@ -13,9 +13,6 @@
T5
-----------------------------------------------------------------------------------------------------------------------
**DISCLAIMER:** This model is still a work in progress, if you see something strange, file a `Github Issue
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__.
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -42,28 +39,56 @@ Tips:
different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
for summarization: *summarize: ...*.
For more information about which prefix to use, it is easiest to look into Appendix D of the `paper
<https://arxiv.org/pdf/1910.10683.pdf>`__. - For sequence-to-sequence generation, it is recommended to use
:meth:`~transformers.generation_utils.GenerationMixin.generate`. This method takes care of feeding the encoded input
via cross-attention layers to the decoder and auto-regressively generates the decoder output. - T5 uses relative
scalar embeddings. Encoder input padding can be done on the left and on the right.
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
- See the :ref:`training`, :ref:`inference` and :ref:`scripts` sections below for all details regarding usage.
T5 comes in different sizes:
- `t5-small <https://huggingface.co/t5-small>`__
- `t5-base <https://huggingface.co/t5-base>`__
- `t5-large <https://huggingface.co/t5-large>`__
- `t5-3b <https://huggingface.co/t5-3b>`__
- `t5-11b <https://huggingface.co/t5-11b>`__.
Based on the original T5 model, Google has released some follow-up works:
- **T5v1.1**: T5v1.1 is an improved version of T5 with some architectural tweaks, and is pre-trained on C4 only without
mixing in the supervised tasks. Refer to the documentation of T5v1.1 which can be found :doc:`here <t5v1.1>`.
- **mT5**: mT5 is a multilingual T5 model. It is pre-trained on the mC4 corpus, which includes 101 languages. Refer to
the documentation of mT5 which can be found :doc:`here <mt5>`.
- **byT5**: byT5 is a T5 model pre-trained on byte sequences rather than SentencePiece subword token sequences. Refer
to the documentation of byT5 which can be found :doc:`here <byt5>`.
All checkpoints can be found on the `hub <https://huggingface.co/models?search=t5>`__.
This model was contributed by `thomwolf <https://huggingface.co/thomwolf>`__. The original code can be found `here
<https://github.com/google-research/text-to-text-transfer-transformer>`__.
.. _training:
Training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher
forcing. This means that for training we always need an input sequence and a target sequence. The input sequence is fed
to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a start-sequence
token and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target sequence is then
appended by the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the start-sequence
token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
forcing. This means that for training, we always need an input sequence and a corresponding target sequence. The input
sequence is fed to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a
start-sequence token and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target
sequence is then appended by the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the
start-sequence token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
One can use :class:`~transformers.T5ForConditionalGeneration` (or the Tensorflow/Flax variant), which includes the
language modeling head on top of the decoder.
- Unsupervised denoising training
In this setup spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
sentinel token represents a unique mask token for this sentence and should start with :obj:`<extra_id_0>`,
:obj:`<extra_id_1>`, ... up to :obj:`<extra_id_99>`. As a default, 100 sentinel tokens are available in
@ -72,34 +97,201 @@ token. T5 can be trained / fine-tuned both in a supervised and unsupervised fash
For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
processed as follows:
.. code-block::
.. code-block::
from transformers import T5ForConditionalGeneration, T5Tokenizer
model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
from transformers import T5Tokenizer, T5ForConditionalGeneration
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
If you're interested in pre-training T5 on a new corpus, check out the `run_t5_mlm_flax.py
<https://github.com/huggingface/transformers/tree/master/examples/flax/language-modeling>`__ script in the Examples
directory.
- Supervised training
In this setup the input sequence and output sequence are standard sequence-to-sequence input output mapping. In
translation, for instance with the input sequence "The house is wonderful." and output sequence "Das Haus ist
wunderbar.", the sentences should be processed as follows:
In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
Suppose that we want to fine-tune the model for translation for example, and we have a training example: the input
sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
the model as follows:
.. code-block::
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
As you can see, only 2 inputs are required for the model in order to compute a loss: :obj:`input_ids` (which are the
:obj:`input_ids` of the encoded input sequence) and :obj:`labels` (which are the :obj:`input_ids` of the encoded
target sequence). The model will automatically create the :obj:`decoder_input_ids` based on the :obj:`labels`, by
shifting them one position to the right and prepending the :obj:`config.decoder_start_token_id`, which for T5 is
equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
during T5's pre-training.
However, the example above only shows a single training example. In practice, one trains deep learning models in
batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
typically defines a :obj:`max_source_length` and :obj:`max_target_length`, which determine the maximum length of the
input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
the task.
In addition, we must make sure that padding token id's of the :obj:`labels` are not taken into account by the loss
function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the :obj:`ignore_index`
of the :obj:`CrossEntropyLoss`. In Flax, one can use the :obj:`decoder_attention_mask` to ignore padded tokens from
the loss (see the `Flax summarization script
<https://github.com/huggingface/transformers/tree/master/examples/flax/summarization>`__ for details). We also pass
:obj:`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
ignored. The code example below illustrates all of this.
.. code-block::
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
# the following 2 hyperparameters are task-specific
max_source_length = 512
max_target_length = 128
# Suppose we have the following 2 training examples:
input_sequence_1 = "Welcome to NYC"
output_sequence_1 = "Bienvenue à NYC"
input_sequence_2 = "HuggingFace is a company"
output_sequence_2 = "HuggingFace est une entreprise"
# encode the inputs
task_prefix = "translate English to French: "
input_sequences = [input_sequence_1, input_sequence_2]
encoding = tokenizer([task_prefix + sequence for sequence in input_sequences],
padding='longest',
max_length=max_source_length,
truncation=True,
return_tensors="pt")
input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
# encode the targets
target_encoding = tokenizer([output_sequence_1, output_sequence_2],
padding='longest',
max_length=max_target_length,
truncation=True)
labels = target_encoding.input_ids
# replace padding token id's of the labels by -100
labels = [
[(label if label != tokenizer.pad_token_id else -100) for label in labels_example] for labels_example in labels
]
labels = torch.tensor(labels)
# forward pass
loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
Additional training tips:
- T5 models need a slightly higher learning rate than the default one set in the :obj:`Trainer` when using the AdamW
optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
- According to `this forum post <https://discuss.huggingface.co/t/t5-finetuning-tips/684>`__, task prefixes matter when
(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
pre-training mixture (see Appendix D of the `paper <https://arxiv.org/pdf/1910.10683.pdf>`__ for the task prefixes
used).
- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
`pad_to_multiple_of` to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
encountered during training thus significantly slowing down the training. only padding up to the longest example in a
batch) leads to very slow training on TPU.
.. _inference:
Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
At inference time, it is recommended to use :meth:`~transformers.generation_utils.GenerationMixin.generate`. This
method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder
and auto-regressively generates the decoder output. Check out `this blog post
<https://huggingface.co/blog/how-to-generate>`__ to know all the details about generating text with Transformers.
There's also `this blog post <https://huggingface.co/blog/encoder-decoder#encoder-decoder>`__ which explains how
generation works in general in encoder-decoder models.
.. code-block::
from transformers import T5ForConditionalGeneration, T5Tokenizer
model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
from transformers import T5Tokenizer, T5ForConditionalGeneration
input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# Das Haus ist wunderbar.
Note that T5 uses the :obj:`pad_token_id` as the :obj:`decoder_start_token_id`, so when doing generation without using
:meth:`~transformers.generation_utils.GenerationMixin.generate`, make sure you start it with the :obj:`pad_token_id`.
The example above only shows a single example. You can also do batched inference, like so:
.. code-block::
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
# when generating, we will use the logits of right-most token to predict the next token
# so the padding should be on the left
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token # to avoid an error
task_prefix = 'translate English to German: '
sentences = ['The house is wonderful.', 'I like to work in NYC.'] # use different length sentences to test batching
inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)
output_sequences = model.generate(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
do_sample=False, # disable sampling to test if batching affects output
)
print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
# ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
.. _scripts:
Example scripts
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
T5 is supported by several example scripts, both for pre-training and fine-tuning.
* pre-training: the `run_t5_mlm_flax.py
<https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/run_t5_mlm_flax.py>`__
script allows you to further pre-train T5 or pre-train T5 from scratch on your own data. The `t5_tokenizer_model.py
<https://github.com/huggingface/transformers/blob/master/examples/flax/language-modeling/t5_tokenizer_model.py>`__
script allows you to further train a T5 tokenizer or train a T5 Tokenizer from scratch on your own data. Note that
Flax (a neural network library on top of JAX) is particularly useful to train on TPU hardware.
* fine-tuning: T5 is supported by the official summarization scripts (`PyTorch
<https://github.com/huggingface/transformers/tree/master/examples/pytorch/summarization>`__, `Tensorflow
<https://github.com/huggingface/transformers/tree/master/examples/tensorflow/summarization>`__, and `Flax
<https://github.com/huggingface/transformers/tree/master/examples/flax/summarization>`__) and translation scripts
(`PyTorch <https://github.com/huggingface/transformers/tree/master/examples/pytorch/translation>`__ and `Tensorflow
<https://github.com/huggingface/transformers/tree/master/examples/tensorflow/translation>`__). These scripts allow
you to easily fine-tune T5 on custom data for summarization/translation.
T5Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -0,0 +1,66 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
T5v1.1
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
T5v1.1 was released in the `google-research/text-to-text-transfer-transformer
<https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511>`__
repository by Colin Raffel et al. It's an improved version of the original T5 model.
One can directly plug in the weights of T5v1.1 into a T5 model, like so:
.. code-block::
from transformers import T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-base')
T5 Version 1.1 includes the following improvements compared to the original T5 model:
- GEGLU activation in the feed-forward hidden layer, rather than ReLU. See `this paper
<https://arxiv.org/abs/2002.05202>`__.
- Dropout was turned off in pre-training (quality win). Dropout should be re-enabled during fine-tuning.
- Pre-trained on C4 only without mixing in the downstream tasks.
- No parameter sharing between the embedding and classifier layer.
- "xl" and "xxl" replace "3B" and "11B". The model shapes are a bit different - larger :obj:`d_model` and smaller
:obj:`num_heads` and :obj:`d_ff`.
Note: T5 Version 1.1 was only pre-trained on `C4 <https://huggingface.co/datasets/c4>`__ excluding any supervised
training. Therefore, this model has to be fine-tuned before it is useable on a downstream task, unlike the original T5
model. Since t5v1.1 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task
fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
Google has released the following variants:
- `google/t5-v1_1-small <https://huggingface.co/google/t5-v1_1-small>`__
- `google/t5-v1_1-base <https://huggingface.co/google/t5-v1_1-base>`__
- `google/t5-v1_1-large <https://huggingface.co/google/t5-v1_1-large>`__
- `google/t5-v1_1-xl <https://huggingface.co/google/t5-v1_1-xl>`__
- `google/t5-v1_1-xxl <https://huggingface.co/google/t5-v1_1-xxl>`__.
One can refer to :doc:`T5's documentation page <t5>` for all tips, code examples and notebooks.
This model was contributed by `patrickvonplaten <https://huggingface.co/patrickvonplaten>`__. The original code can be
found `here
<https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511>`__.

View File

@ -1059,11 +1059,11 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
>>> inputs = tokenizer(text, return_tensors='np')
>>> encoder_outputs = model.encode(**inputs)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -1118,19 +1118,20 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
Example::
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
>>> import jax.numpy as jnp
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
>>> inputs = tokenizer(text, return_tensors='np')
>>> encoder_outputs = model.encode(**inputs)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> last_decoder_hidden_states = outputs.last_hidden_state
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1329,8 +1330,9 @@ FLAX_T5_MODEL_DOCSTRING = """
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> # forward pass
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
@ -1469,19 +1471,20 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
Example::
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
>>> import jax.numpy as jnp
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> text = "summarize: My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
>>> inputs = tokenizer(text, return_tensors='np')
>>> encoder_outputs = model.encode(**inputs)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> last_decoder_hidden_states = outputs.last_hidden_state
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1617,15 +1620,15 @@ FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
>>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=512, return_tensors='jax')
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors='np')
>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids']).sequences
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
>>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
"""

View File

@ -1234,7 +1234,7 @@ num_heads)`.
@add_start_docstrings(
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
"The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
@ -1344,8 +1344,9 @@ class T5Model(T5PreTrainedModel):
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> # forward pass
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
@ -1537,14 +1538,18 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
>>> # training
>>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1
>>> # inference
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> # studies have shown that owning a dog is good for you.
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

View File

@ -1133,8 +1133,10 @@ class TFT5Model(TFT5PreTrainedModel):
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
>>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
>>> # forward pass
>>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
"""
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
@ -1321,15 +1323,18 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
>>> # training
>>> inputs = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='tf').input_ids
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='tf').input_ids
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='tf').input_ids
>>> outputs = model(inputs, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> inputs = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="tf").input_ids # Batch size 1
>>> result = model.generate(inputs)
>>> # inference
>>> inputs = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
>>> outputs = model.generate(inputs)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> # studies have shown that owning a dog is good for you
"""
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
@ -1571,7 +1576,7 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
Examples::
>>> from transformers import T5Tokenizer, TFT5Model
>>> from transformers import T5Tokenizer, TFT5EncoderModel
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = TFT5EncoderModel.from_pretrained('t5-small')
@ -1579,7 +1584,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel):
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
>>> outputs = model(input_ids)
"""
inputs = input_processing(
func=self.call,