..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

T5
-----------------------------------------------------------------------------------------------------------------------

**DISCLAIMER:** This model is still a work in progress; if you see something strange, file a `Github Issue
<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__.

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The T5 model was presented in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
<https://arxiv.org/pdf/1910.10683.pdf>`_ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.

The abstract from the paper is the following:

*Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream
task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning
has given rise to a diversity of approaches, methodology, and practice. In this paper, we explore the landscape of
transfer learning techniques for NLP by introducing a unified framework that converts every language problem into a
text-to-text format. Our systematic study compares pretraining objectives, architectures, unlabeled datasets, transfer
approaches, and other factors on dozens of language understanding tasks. By combining the insights from our exploration
with scale and our new "Colossal Clean Crawled Corpus", we achieve state-of-the-art results on many benchmarks covering
summarization, question answering, text classification, and more. To facilitate future work on transfer learning for
NLP, we release our dataset, pre-trained models, and code.*

Tips:

- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
  each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
  different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
  for summarization: *summarize: ...*.

  For more information about which prefix to use, it is easiest to look into Appendix D of the `paper
  <https://arxiv.org/pdf/1910.10683.pdf>`__.
- For sequence-to-sequence generation, it is recommended to use
  :meth:`~transformers.generation_utils.GenerationMixin.generate`. This method takes care of feeding the encoded input
  via cross-attention layers to the decoder and auto-regressively generates the decoder output (a short example follows
  this list).
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.

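As a minimal sketch of prefixed, text-to-text generation, using the same :obj:`t5-small` checkpoint and translation
prefix as the training examples further below:

.. code-block::

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    # the task prefix tells the model which text-to-text task to perform
    input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids

    # generate() encodes the input once and decodes the output auto-regressively
    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
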
This model was contributed by `thomwolf <https://huggingface.co/thomwolf>`__. The original code can be found `here
<https://github.com/google-research/text-to-text-transfer-transformer>`__.

Training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

T5 is an encoder-decoder model that converts all NLP problems into a text-to-text format. It is trained using teacher
forcing, which means that for training we always need an input sequence and a target sequence. The input sequence is
fed to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a
start-sequence token, and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target
sequence is then appended with the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the
start-sequence token. T5 can be trained / fine-tuned both in a supervised and an unsupervised fashion.

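A minimal sketch of how the decoder input relates to the labels. The shift happens automatically inside the model
whenever :obj:`labels` is passed, so this snippet is only an illustration (assuming PyTorch tensors):

.. code-block::

    import torch
    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

    # shift the target sequence one position to the right and use the PAD token as the start-sequence token
    decoder_input_ids = torch.full_like(labels, tokenizer.pad_token_id)
    decoder_input_ids[:, 1:] = labels[:, :-1]
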
- Unsupervised denoising training

  In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a.* unique mask tokens) and
  the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
  sentinel token represents a unique mask token for this sentence and should start with :obj:`<extra_id_0>`,
  :obj:`<extra_id_1>`, ... up to :obj:`<extra_id_99>`. As a default, 100 sentinel tokens are available in
  :class:`~transformers.T5Tokenizer`. (A short sketch of generating from such a denoising input follows the supervised
  example below.)

  For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
  processed as follows:

  .. code-block::

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
    labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids

    # the forward function automatically creates the correct decoder_input_ids
    loss = model(input_ids=input_ids, labels=labels).loss

- Supervised training

  In this setup, the input and output sequences form a standard sequence-to-sequence mapping. In translation, for
  instance, with the input sequence "The house is wonderful." and the output sequence "Das Haus ist wunderbar.", the
  sentences should be processed as follows:

  .. code-block::

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    input_ids = tokenizer('translate English to German: The house is wonderful.', return_tensors='pt').input_ids
    labels = tokenizer('Das Haus ist wunderbar.', return_tensors='pt').input_ids

    # the forward function automatically creates the correct decoder_input_ids
    loss = model(input_ids=input_ids, labels=labels).loss

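As a hedged illustration of the unsupervised denoising setup above (the exact text produced depends on the checkpoint
and is not guaranteed), the masked input can also be fed to
:meth:`~transformers.generation_utils.GenerationMixin.generate` to inspect the sentinel-formatted prediction:

.. code-block::

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids

    # the model predicts each sentinel token followed by its guess for the corresponding masked span
    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0]))
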
T5Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5Config
    :members:


T5Tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5Tokenizer
    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
        create_token_type_ids_from_sequences, save_vocabulary

T5TokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5TokenizerFast
    :members:


T5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5Model
    :members: forward, parallelize, deparallelize


T5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5ForConditionalGeneration
    :members: forward, parallelize, deparallelize


T5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.T5EncoderModel
    :members: forward, parallelize, deparallelize

TFT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFT5Model
    :members: call


TFT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFT5ForConditionalGeneration
    :members: call


TFT5EncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFT5EncoderModel
    :members: call


FlaxT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxT5Model
    :members: __call__, encode, decode


FlaxT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxT5ForConditionalGeneration
    :members: __call__, encode, decode