diff --git a/README.md b/README.md index 1c0a7189c09..cc0cb87a0ef 100644 --- a/README.md +++ b/README.md @@ -272,6 +272,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[Splinter](https://huggingface.co/transformers/model_doc/splinter.html)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy. 1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. +1. **[T5v1.1](https://huggingface.co/transformers/model_doc/t5v1.1.html)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. 1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. 1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. 1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. diff --git a/docs/source/index.rst b/docs/source/index.rst index cbdd1f8fd1a..5049c65b90a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -282,35 +282,40 @@ Supported models 62. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -63. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +63. 
:doc:`T5v1.1 ` (from Google AI) released in the repository + `google-research/text-to-text-transfer-transformer + `__ by + Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi + Zhou and Wei Li and Peter J. Liu. +64. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -64. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +65. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -65. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 +66. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -66. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and +67. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and Performant Baseline for Vision and Language `__ by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. -67. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +68. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -68. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +69. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -69. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +70. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -70. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +71. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -71. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +72. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -72. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +73. 
:doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -608,6 +613,7 @@ Flax), PyTorch, and/or TensorFlow. model_doc/splinter model_doc/squeezebert model_doc/t5 + model_doc/t5v1.1 model_doc/tapas model_doc/transformerxl model_doc/vit diff --git a/docs/source/model_doc/byt5.rst b/docs/source/model_doc/byt5.rst index ad8e272d0e3..590f651010d 100644 --- a/docs/source/model_doc/byt5.rst +++ b/docs/source/model_doc/byt5.rst @@ -39,8 +39,11 @@ experiments.* This model was contributed by `patrickvonplaten `__. The original code can be found `here `__. +ByT5's architecture is based on the T5v1.1 model, so one can refer to :doc:`T5v1.1's documentation page `. They +only differ in how inputs should be prepared for the model, see the code examples below. -ByT5's architecture is based on the T5 model, so one can refer to :doc:`T5's documentation page `. +Since ByT5 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task +fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix. Example diff --git a/docs/source/model_doc/mt5.rst b/docs/source/model_doc/mt5.rst index 36dbf37b02f..6d752502d32 100644 --- a/docs/source/model_doc/mt5.rst +++ b/docs/source/model_doc/mt5.rst @@ -28,6 +28,23 @@ multilingual variant of T5 that was pre-trained on a new Common Crawl-based data the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual benchmarks. All of the code and model checkpoints* +Note: mT5 was only pre-trained on `mC4 `__ excluding any supervised training. +Therefore, this model has to be fine-tuned before it is useable on a downstream task, unlike the original T5 model. +Since mT5 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task +fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix. + +Google has released the following variants: + +- `google/mt5-small `__ + +- `google/mt5-base `__ + +- `google/mt5-large `__ + +- `google/mt5-xl `__ + +- `google/mt5-xxl `__. + This model was contributed by `patrickvonplaten `__. The original code can be found `here `__. diff --git a/docs/source/model_doc/t5.rst b/docs/source/model_doc/t5.rst index 3bdabe239a9..238f6454f2d 100644 --- a/docs/source/model_doc/t5.rst +++ b/docs/source/model_doc/t5.rst @@ -13,9 +13,6 @@ T5 ----------------------------------------------------------------------------------------------------------------------- -**DISCLAIMER:** This model is still a work in progress, if you see something strange, file a `Github Issue -`__. - Overview ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -42,28 +39,56 @@ Tips: different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*, for summarization: *summarize: ...*. - For more information about which prefix to use, it is easiest to look into Appendix D of the `paper - `__. - For sequence-to-sequence generation, it is recommended to use - :meth:`~transformers.generation_utils.GenerationMixin.generate`. This method takes care of feeding the encoded input - via cross-attention layers to the decoder and auto-regressively generates the decoder output. - T5 uses relative - scalar embeddings. 
Encoder input padding can be done on the left and on the right. +- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right. + +- See the :ref:`training`, :ref:`inference` and :ref:`scripts` sections below for all details regarding usage. + +T5 comes in different sizes: + +- `t5-small `__ + +- `t5-base `__ + +- `t5-large `__ + +- `t5-3b `__ + +- `t5-11b `__. + +Based on the original T5 model, Google has released some follow-up works: + +- **T5v1.1**: T5v1.1 is an improved version of T5 with some architectural tweaks, and is pre-trained on C4 only without + mixing in the supervised tasks. Refer to the documentation of T5v1.1 which can be found :doc:`here `. + +- **mT5**: mT5 is a multilingual T5 model. It is pre-trained on the mC4 corpus, which includes 101 languages. Refer to + the documentation of mT5 which can be found :doc:`here `. + +- **byT5**: byT5 is a T5 model pre-trained on byte sequences rather than SentencePiece subword token sequences. Refer + to the documentation of byT5 which can be found :doc:`here `. + +All checkpoints can be found on the `hub `__. This model was contributed by `thomwolf `__. The original code can be found `here `__. +.. _training: + Training ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher -forcing. This means that for training we always need an input sequence and a target sequence. The input sequence is fed -to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a start-sequence -token and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target sequence is then -appended by the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the start-sequence -token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion. +forcing. This means that for training, we always need an input sequence and a corresponding target sequence. The input +sequence is fed to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a +start-sequence token and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target +sequence is then appended by the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the +start-sequence token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion. + +One can use :class:`~transformers.T5ForConditionalGeneration` (or the Tensorflow/Flax variant), which includes the +language modeling head on top of the decoder. - Unsupervised denoising training - In this setup spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and + In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each sentinel token represents a unique mask token for this sentence and should start with :obj:``, :obj:``, ... up to :obj:``. As a default, 100 sentinel tokens are available in @@ -72,34 +97,201 @@ token. 
T5 can be trained / fine-tuned both in a supervised and unsupervised fash For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be processed as follows: -.. code-block:: - from transformers import T5ForConditionalGeneration, T5Tokenizer - model = T5ForConditionalGeneration.from_pretrained("t5-small") - tokenizer = T5Tokenizer.from_pretrained("t5-small") - input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids - labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids - # the forward function automatically creates the correct decoder_input_ids - loss = model(input_ids=input_ids, labels=labels).loss + .. code-block:: + + from transformers import T5Tokenizer, T5ForConditionalGeneration + + tokenizer = T5Tokenizer.from_pretrained("t5-small") + model = T5ForConditionalGeneration.from_pretrained("t5-small") + + input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids + labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids + # the forward function automatically creates the correct decoder_input_ids + loss = model(input_ids=input_ids, labels=labels).loss + + If you're interested in pre-training T5 on a new corpus, check out the `run_t5_mlm_flax.py + `__ script in the Examples + directory. - Supervised training - In this setup the input sequence and output sequence are standard sequence-to-sequence input output mapping. In - translation, for instance with the input sequence "The house is wonderful." and output sequence "Das Haus ist - wunderbar.", the sentences should be processed as follows: + In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping. + Suppose, for example, that we want to fine-tune the model for translation, and we have a training example: the input + sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.". They should be prepared for + the model as follows: + + .. code-block:: + + from transformers import T5Tokenizer, T5ForConditionalGeneration + + tokenizer = T5Tokenizer.from_pretrained("t5-small") + model = T5ForConditionalGeneration.from_pretrained("t5-small") + + input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids + labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids + # the forward function automatically creates the correct decoder_input_ids + loss = model(input_ids=input_ids, labels=labels).loss + + As you can see, only 2 inputs are required for the model in order to compute a loss: :obj:`input_ids` (which are the + :obj:`input_ids` of the encoded input sequence) and :obj:`labels` (which are the :obj:`input_ids` of the encoded + target sequence). The model will automatically create the :obj:`decoder_input_ids` based on the :obj:`labels`, by + shifting them one position to the right and prepending the :obj:`config.decoder_start_token_id`, which for T5 is + equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate + English to German: ' before encoding it. This helps improve performance, as this task prefix was used + during T5's pre-training. + + However, the example above only shows a single training example. In practice, one trains deep learning models in + batches. This entails that we must pad/truncate examples to the same length.
For encoder-decoder models, one + typically defines a :obj:`max_source_length` and :obj:`max_target_length`, which determine the maximum length of the + input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on + the task. + + In addition, we must make sure that padding token ids of the :obj:`labels` are not taken into account by the loss + function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the :obj:`ignore_index` + of the :obj:`CrossEntropyLoss`. In Flax, one can use the :obj:`decoder_attention_mask` to ignore padded tokens from + the loss (see the `Flax summarization script + `__ for details). We also pass + :obj:`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are + ignored. The code example below illustrates all of this. + + .. code-block:: + + from transformers import T5Tokenizer, T5ForConditionalGeneration + import torch + + tokenizer = T5Tokenizer.from_pretrained("t5-small") + model = T5ForConditionalGeneration.from_pretrained("t5-small") + + # the following 2 hyperparameters are task-specific + max_source_length = 512 + max_target_length = 128 + + # Suppose we have the following 2 training examples: + input_sequence_1 = "Welcome to NYC" + output_sequence_1 = "Bienvenue à NYC" + + input_sequence_2 = "HuggingFace is a company" + output_sequence_2 = "HuggingFace est une entreprise" + + # encode the inputs + task_prefix = "translate English to French: " + input_sequences = [input_sequence_1, input_sequence_2] + encoding = tokenizer([task_prefix + sequence for sequence in input_sequences], + padding='longest', + max_length=max_source_length, + truncation=True, + return_tensors="pt") + input_ids, attention_mask = encoding.input_ids, encoding.attention_mask + + # encode the targets + target_encoding = tokenizer([output_sequence_1, output_sequence_2], + padding='longest', + max_length=max_target_length, + truncation=True) + labels = target_encoding.input_ids + + # replace padding token ids of the labels by -100 + labels = [ + [(label if label != tokenizer.pad_token_id else -100) for label in labels_example] for labels_example in labels + ] + labels = torch.tensor(labels) + + # forward pass + loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss + +Additional training tips: + +- T5 models need a slightly higher learning rate than the default one set in the :obj:`Trainer` when using the AdamW + optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question + answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer (a minimal optimizer setup + is sketched after these tips). + +- According to `this forum post `__, task prefixes matter when + (1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's + pre-training mixture (see Appendix D of the `paper `__ for the task prefixes + used). + +- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of + `pad_to_multiple_of` to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding + batches to the longest example is not recommended on TPU, as it triggers a recompilation for every batch shape that is + encountered during training and thus significantly slows down training.
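As a minimal sketch of the optimizer tip above (assuming PyTorch; the Adafactor settings shown are one commonly used fine-tuning configuration rather than the only possible one), the setup could look as follows:

.. code-block::

    import torch
    from transformers import Adafactor, T5ForConditionalGeneration

    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # option 1: AdamW with a higher learning rate than the Trainer default
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # option 2: Adafactor, the optimizer used during T5's pre-training
    # (fixed learning rate, relative steps and parameter scaling disabled)
    optimizer = Adafactor(
        model.parameters(), lr=1e-3, scale_parameter=False, relative_step=False, warmup_init=False
    )

+ +.. 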
_inference: + +Inference +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +At inference time, it is recommended to use :meth:`~transformers.generation_utils.GenerationMixin.generate`. This +method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder +and auto-regressively generates the decoder output. Check out `this blog post +`__ to know all the details about generating text with Transformers. +There's also `this blog post `__ which explains how +generation works in general in encoder-decoder models. .. code-block:: - from transformers import T5ForConditionalGeneration, T5Tokenizer - model = T5ForConditionalGeneration.from_pretrained("t5-small") - tokenizer = T5Tokenizer.from_pretrained("t5-small") + from transformers import T5Tokenizer, T5ForConditionalGeneration - input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids - labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids - # the forward function automatically creates the correct decoder_input_ids - loss = model(input_ids=input_ids, labels=labels).loss + tokenizer = T5Tokenizer.from_pretrained("t5-small") + model = T5ForConditionalGeneration.from_pretrained("t5-small") + input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids + outputs = model.generate(input_ids) + print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + # Das Haus ist wunderbar. + +Note that T5 uses the :obj:`pad_token_id` as the :obj:`decoder_start_token_id`, so when doing generation without using +:meth:`~transformers.generation_utils.GenerationMixin.generate`, make sure you start it with the :obj:`pad_token_id`. + +The example above only shows a single example. You can also do batched inference, like so: + +.. code-block:: + + from transformers import T5Tokenizer, T5ForConditionalGeneration + + tokenizer = T5Tokenizer.from_pretrained("t5-small") + model = T5ForConditionalGeneration.from_pretrained("t5-small") + + # when generating, we will use the logits of right-most token to predict the next token + # so the padding should be on the left + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token # to avoid an error + + task_prefix = 'translate English to German: ' + sentences = ['The house is wonderful.', 'I like to work in NYC.'] # use different length sentences to test batching + inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True) + + output_sequences = model.generate( + input_ids=inputs['input_ids'], + attention_mask=inputs['attention_mask'], + do_sample=False, # disable sampling to test if batching affects output + ) + + print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True)) + + # ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.'] + +.. _scripts: + +Example scripts +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +T5 is supported by several example scripts, both for pre-training and fine-tuning. + +* pre-training: the `run_t5_mlm_flax.py + `__ + script allows you to further pre-train T5 or pre-train T5 from scratch on your own data. The `t5_tokenizer_model.py + `__ + script allows you to further train a T5 tokenizer or train a T5 Tokenizer from scratch on your own data. 
Note that + Flax (a neural network library on top of JAX) is particularly useful to train on TPU hardware. + +* fine-tuning: T5 is supported by the official summarization scripts (`PyTorch + `__, `Tensorflow + `__, and `Flax + `__) and translation scripts + (`PyTorch `__ and `Tensorflow + `__). These scripts allow + you to easily fine-tune T5 on custom data for summarization/translation. T5Config ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/t5v1.1.rst b/docs/source/model_doc/t5v1.1.rst new file mode 100644 index 00000000000..9de0881e215 --- /dev/null +++ b/docs/source/model_doc/t5v1.1.rst @@ -0,0 +1,66 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +T5v1.1 +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +T5v1.1 was released in the `google-research/text-to-text-transfer-transformer +`__ +repository by Colin Raffel et al. It's an improved version of the original T5 model. + +One can directly plug in the weights of T5v1.1 into a T5 model, like so: + +.. code-block:: + + from transformers import T5ForConditionalGeneration + + model = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-base') + +T5 Version 1.1 includes the following improvements compared to the original T5 model: + +- GEGLU activation in the feed-forward hidden layer, rather than ReLU. See `this paper + `__. + +- Dropout was turned off in pre-training (quality win). Dropout should be re-enabled during fine-tuning. + +- Pre-trained on C4 only without mixing in the downstream tasks. + +- No parameter sharing between the embedding and classifier layer. + +- "xl" and "xxl" replace "3B" and "11B". The model shapes are a bit different - larger :obj:`d_model` and smaller + :obj:`num_heads` and :obj:`d_ff`. + +Note: T5 Version 1.1 was only pre-trained on `C4 `__ excluding any supervised +training. Therefore, this model has to be fine-tuned before it is useable on a downstream task, unlike the original T5 +model. Since t5v1.1 was pre-trained unsupervisedly, there's no real advantage to using a task prefix during single-task +fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix. + +Google has released the following variants: + +- `google/t5-v1_1-small `__ + +- `google/t5-v1_1-base `__ + +- `google/t5-v1_1-large `__ + +- `google/t5-v1_1-xl `__ + +- `google/t5-v1_1-xxl `__. + +One can refer to :doc:`T5's documentation page ` for all tips, code examples and notebooks. + +This model was contributed by `patrickvonplaten `__. The original code can be +found `here +`__. 
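As a rough sketch of what such fine-tuning looks like, the snippet below runs a single supervised training step with one of the checkpoints listed above, reusing the translation pair from the T5 documentation. No task prefix is added, since T5v1.1 was pre-trained without the supervised task mixture; batching, padding and the optimizer are omitted here and are covered on the T5 page.

.. code-block::

    from transformers import T5Tokenizer, T5ForConditionalGeneration

    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
    model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")

    # no task prefix: T5v1.1 was pre-trained on C4 only, without the supervised tasks
    input_ids = tokenizer("The house is wonderful.", return_tensors="pt").input_ids
    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

    # as with the original T5, the decoder_input_ids are created automatically from the labels
    loss = model(input_ids=input_ids, labels=labels).loss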
diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index 5f493b57676..0aea4133101 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -1059,11 +1059,11 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel): >>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration - >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=512, return_tensors='jax') + >>> inputs = tokenizer(text, return_tensors='np') >>> encoder_outputs = model.encode(**inputs) """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1118,19 +1118,20 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel): Example:: >>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp - >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> text = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer(text, max_length=512, return_tensors='jax') + >>> inputs = tokenizer(text, return_tensors='np') >>> encoder_outputs = model.encode(**inputs) >>> decoder_start_token_id = model.config.decoder_start_token_id >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state + >>> logits = outputs.logits """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1329,8 +1330,9 @@ FLAX_T5_MODEL_DOCSTRING = """ >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) >>> last_hidden_states = outputs.last_hidden_state """ @@ -1469,19 +1471,20 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): Example:: >>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration + >>> import jax.numpy as jnp - >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> text = "summarize: My friends are cool but they eat too many carbs." 
- >>> inputs = tokenizer(text, max_length=512, return_tensors='jax') + >>> inputs = tokenizer(text, return_tensors='np') >>> encoder_outputs = model.encode(**inputs) >>> decoder_start_token_id = model.config.decoder_start_token_id >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, encoder_outputs) - >>> last_decoder_hidden_states = outputs.last_hidden_state + >>> logits = outputs.logits """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1617,15 +1620,15 @@ FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """ >>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration - >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=512, return_tensors='jax') + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors='np') >>> # Generate Summary >>> summary_ids = model.generate(inputs['input_ids']).sequences - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)) """ diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 8f07f04c550..27ef440bfb1 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1234,7 +1234,7 @@ num_heads)`. 
@add_start_docstrings( - "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", + "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.", T5_START_DOCSTRING, ) class T5Model(T5PreTrainedModel): @@ -1344,8 +1344,9 @@ class T5Model(T5PreTrainedModel): >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) >>> last_hidden_states = outputs.last_hidden_state """ use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -1537,14 +1538,18 @@ class T5ForConditionalGeneration(T5PreTrainedModel): >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') + >>> # training >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids - >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids >>> outputs = model(input_ids=input_ids, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits - >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 + >>> # inference + >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. 
""" use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 17f5a1dd887..abff17f36f1 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1133,8 +1133,10 @@ class TFT5Model(TFT5PreTrainedModel): >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1 >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1 - >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + >>> # forward pass + >>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state """ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask @@ -1321,15 +1323,18 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') >>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small') + >>> # training >>> inputs = tokenizer('The walks in park', return_tensors='tf').input_ids - >>> labels = tokenizer(' cute dog the ', return_tensors='tf').input_ids + >>> labels = tokenizer(' cute dog the ', return_tensors='tf').input_ids >>> outputs = model(inputs, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits - >>> inputs = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="tf").input_ids # Batch size 1 - - >>> result = model.generate(inputs) + >>> # inference + >>> inputs = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1 + >>> outputs = model.generate(inputs) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you """ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask @@ -1571,7 +1576,7 @@ class TFT5EncoderModel(TFT5PreTrainedModel): Examples:: - >>> from transformers import T5Tokenizer, TFT5Model + >>> from transformers import T5Tokenizer, TFT5EncoderModel >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') >>> model = TFT5EncoderModel.from_pretrained('t5-small') @@ -1579,7 +1584,6 @@ class TFT5EncoderModel(TFT5PreTrainedModel): >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1 >>> outputs = model(input_ids) - """ inputs = input_processing( func=self.call,