[Doctests] Fix all T5 doc tests (#16646)

* [Doctests] Fix all T5 doc tests

* make style

* Update docs/source/en/model_doc/t5.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply Sylvains comments

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen 2022-04-13 11:36:54 +02:00 committed by GitHub
parent f7196f2e63
commit b24201fa44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 234 additions and 144 deletions

View File

@ -48,37 +48,98 @@ fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix.
ByT5 works on raw UTF-8 bytes, so it can be used without a tokenizer:
```python
from transformers import T5ForConditionalGeneration
import torch
>>> from transformers import T5ForConditionalGeneration
>>> import torch
model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3 # add 3 for special tokens
labels = (
torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + 3
) # add 3 for special tokens
>>> num_special_tokens = 3
>>> # Model has 3 special tokens which take up the input ids 0,1,2 of ByT5.
>>> # => Need to shift utf-8 character encodings by 3 before passing ids to model.
loss = model(input_ids, labels=labels).loss # forward pass
>>> input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + num_special_tokens
>>> labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + num_special_tokens
>>> loss = model(input_ids, labels=labels).loss
>>> loss.item()
2.66
```
For batched inference and training it is however recommended to make use of the tokenizer:
```python
from transformers import T5ForConditionalGeneration, AutoTokenizer
>>> from transformers import T5ForConditionalGeneration, AutoTokenizer
model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
>>> tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
model_inputs = tokenizer(
["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt"
)
labels = tokenizer(
["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt"
).input_ids
>>> model_inputs = tokenizer(
... ["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt"
... )
>>> labels_dict = tokenizer(
... ["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt"
... )
>>> labels = labels_dict.input_ids
loss = model(**model_inputs, labels=labels).loss # forward pass
>>> loss = model(**model_inputs, labels=labels).loss
>>> loss.item()
17.9
```
Similar to [T5](t5), ByT5 was trained on the span-mask denoising task. However,
since the model works directly on characters, the pretraining task is a bit
different. Let's corrupt some characters of the
input sentence `"The dog chases a ball in the park."` and ask ByT5 to predict them
for us.
```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-base")
>>> input_ids_prompt = "The dog chases a ball in the park."
>>> input_ids = tokenizer(input_ids_prompt).input_ids
>>> # Note that we cannot add "{extra_id_...}" to the string directly
>>> # as the Byte tokenizer would incorrectly merge the tokens
>>> # For ByT5, we need to work directly on the character level
>>> # Contrary to T5, ByT5 does not use sentinel tokens for masking, but instead
>>> # uses final utf character ids.
>>> # UTF-8 is represented by 8 bits and ByT5 has 3 special tokens.
>>> # => There are 2**8+2 = 259 input ids and mask tokens count down from index 258.
>>> # => mask to "The dog [258]a ball [257]park."
>>> input_ids = torch.tensor([input_ids[:8] + [258] + input_ids[14:21] + [257] + input_ids[28:]])
>>> input_ids
tensor([[ 87, 107, 104, 35, 103, 114, 106, 35, 258, 35, 100, 35, 101, 100, 111, 111, 257, 35, 115, 100, 117, 110, 49, 1]])
>>> # ByT5 produces only one char at a time so we need to produce many more output characters here -> set `max_length=100`.
>>> output_ids = model.generate(input_ids, max_length=100)[0].tolist()
>>> output_ids
[0, 258, 108, 118, 35, 119, 107, 104, 35, 114, 113, 104, 35, 122, 107, 114, 35, 103, 114, 104, 118, 257, 35, 108, 113, 35, 119, 107, 104, 35, 103, 108, 118, 102, 114, 256, 108, 113, 35, 119, 107, 104, 35, 115, 100, 117, 110, 49, 35, 87, 107, 104, 35, 103, 114, 106, 35, 108, 118, 35, 119, 107, 104, 35, 114, 113, 104, 35, 122, 107, 114, 35, 103, 114, 104, 118, 35, 100, 35, 101, 100, 111, 111, 35, 108, 113, 255, 35, 108, 113, 35, 119, 107, 104, 35, 115, 100, 117, 110, 49]
>>> # ^- Note how 258 descends to 257, 256, 255
>>> # Now we need to split on the sentinel tokens, let's write a short loop for this
>>> output_ids_list = []
>>> start_token = 0
>>> sentinel_token = 258
>>> while sentinel_token in output_ids:
... split_idx = output_ids.index(sentinel_token)
... output_ids_list.append(output_ids[start_token:split_idx])
... start_token = split_idx
... sentinel_token -= 1
>>> output_ids_list.append(output_ids[start_token:])
>>> output_string = tokenizer.batch_decode(output_ids_list)
>>> output_string
['<pad>', 'is the one who does', ' in the disco', 'in the park. The dog is the one who does a ball in', ' in the park.']
```
## ByT5Tokenizer
[[autodoc]] ByT5Tokenizer

View File

@ -32,9 +32,9 @@ NLP, we release our dataset, pre-trained models, and code.*
Tips:
- T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
for summarization: *summarize: ...*.
each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
for summarization: *summarize: ...*.
- T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.
@ -83,130 +83,140 @@ language modeling head on top of the decoder.
- Unsupervised denoising training
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
sentinel token represents a unique mask token for this sentence and should start with `<extra_id_0>`,
`<extra_id_1>`, ... up to `<extra_id_99>`. As a default, 100 sentinel tokens are available in
[`T5Tokenizer`].
In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
sentinel token represents a unique mask token for this sentence and should start with `<extra_id_0>`,
`<extra_id_1>`, ... up to `<extra_id_99>`. As a default, 100 sentinel tokens are available in
[`T5Tokenizer`].
For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
processed as follows:
For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
processed as follows:
```python
from transformers import T5Tokenizer, T5ForConditionalGeneration
```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
```
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
If you're interested in pre-training T5 on a new corpus, check out the [run_t5_mlm_flax.py](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) script in the Examples
directory.
>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_ids=input_ids, labels=labels).loss
>>> loss.item()
3.7837
```
If you're interested in pre-training T5 on a new corpus, check out the [run_t5_mlm_flax.py](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) script in the Examples
directory.
- Supervised training
In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
Suppose that we want to fine-tune the model for translation for example, and we have a training example: the input
sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
the model as follows:
In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
Suppose that we want to fine-tune the model for translation for example, and we have a training example: the input
sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
the model as follows:
```python
from transformers import T5Tokenizer, T5ForConditionalGeneration
```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
# the forward function automatically creates the correct decoder_input_ids
loss = model(input_ids=input_ids, labels=labels).loss
```
>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
>>> labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
As you can see, only 2 inputs are required for the model in order to compute a loss: `input_ids` (which are the
`input_ids` of the encoded input sequence) and `labels` (which are the `input_ids` of the encoded
target sequence). The model will automatically create the `decoder_input_ids` based on the `labels`, by
shifting them one position to the right and prepending the `config.decoder_start_token_id`, which for T5 is
equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
during T5's pre-training.
>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(input_ids=input_ids, labels=labels).loss
>>> loss.item()
0.2542
```
However, the example above only shows a single training example. In practice, one trains deep learning models in
batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
typically defines a `max_source_length` and `max_target_length`, which determine the maximum length of the
input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
the task.
As you can see, only 2 inputs are required for the model in order to compute a loss: `input_ids` (which are the
`input_ids` of the encoded input sequence) and `labels` (which are the `input_ids` of the encoded
target sequence). The model will automatically create the `decoder_input_ids` based on the `labels`, by
shifting them one position to the right and prepending the `config.decoder_start_token_id`, which for T5 is
equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
during T5's pre-training.
In addition, we must make sure that padding token id's of the `labels` are not taken into account by the loss
function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the `ignore_index`
of the `CrossEntropyLoss`. In Flax, one can use the `decoder_attention_mask` to ignore padded tokens from
the loss (see the [Flax summarization script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization) for details). We also pass
`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
ignored. The code example below illustrates all of this.
However, the example above only shows a single training example. In practice, one trains deep learning models in
batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
typically defines a `max_source_length` and `max_target_length`, which determine the maximum length of the
input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
the task.
```python
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
In addition, we must make sure that padding token id's of the `labels` are not taken into account by the loss
function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the `ignore_index`
of the `CrossEntropyLoss`. In Flax, one can use the `decoder_attention_mask` to ignore padded tokens from
the loss (see the [Flax summarization script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization) for details). We also pass
`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
ignored. The code example below illustrates all of this.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
>>> import torch
# the following 2 hyperparameters are task-specific
max_source_length = 512
max_target_length = 128
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
# Suppose we have the following 2 training examples:
input_sequence_1 = "Welcome to NYC"
output_sequence_1 = "Bienvenue à NYC"
>>> # the following 2 hyperparameters are task-specific
>>> max_source_length = 512
>>> max_target_length = 128
input_sequence_2 = "HuggingFace is a company"
output_sequence_2 = "HuggingFace est une entreprise"
>>> # Suppose we have the following 2 training examples:
>>> input_sequence_1 = "Welcome to NYC"
>>> output_sequence_1 = "Bienvenue à NYC"
# encode the inputs
task_prefix = "translate English to French: "
input_sequences = [input_sequence_1, input_sequence_2]
encoding = tokenizer(
[task_prefix + sequence for sequence in input_sequences],
padding="longest",
max_length=max_source_length,
truncation=True,
return_tensors="pt",
)
input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
>>> input_sequence_2 = "HuggingFace is a company"
>>> output_sequence_2 = "HuggingFace est une entreprise"
# encode the targets
target_encoding = tokenizer(
[output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
)
labels = target_encoding.input_ids
>>> # encode the inputs
>>> task_prefix = "translate English to French: "
>>> input_sequences = [input_sequence_1, input_sequence_2]
# replace padding token id's of the labels by -100
labels = torch.tensor(labels)
labels[labels == tokenizer.pad_token_id] = -100
>>> encoding = tokenizer(
... [task_prefix + sequence for sequence in input_sequences],
... padding="longest",
... max_length=max_source_length,
... truncation=True,
... return_tensors="pt",
... )
# forward pass
loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
```
>>> input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
>>> # encode the targets
>>> target_encoding = tokenizer(
... [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
... )
>>> labels = target_encoding.input_ids
>>> # replace padding token id's of the labels by -100 so it's ignored by the loss
>>> labels = torch.tensor(labels)
>>> labels[labels == tokenizer.pad_token_id] = -100
>>> # forward pass
>>> loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
>>> loss.item()
0.188
```
Additional training tips:
- T5 models need a slightly higher learning rate than the default one set in the `Trainer` when using the AdamW
optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
- According to [this forum post](https://discuss.huggingface.co/t/t5-finetuning-tips/684), task prefixes matter when
(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
pre-training mixture (see Appendix D of the [paper](https://arxiv.org/pdf/1910.10683.pdf) for the task prefixes
used).
According to [this forum post](https://discuss.huggingface.co/t/t5-finetuning-tips/684), task prefixes matter when
(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
pre-training mixture (see Appendix D of the [paper](https://arxiv.org/pdf/1910.10683.pdf) for the task prefixes
used).
- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
*pad_to_multiple_of* to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
encountered during training thus significantly slowing down the training. only padding up to the longest example in a
batch) leads to very slow training on TPU.
If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
*pad_to_multiple_of* to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
encountered during training thus significantly slowing down the training. only padding up to the longest example in a
batch) leads to very slow training on TPU.
<a id='inference'></a>
@ -219,15 +229,15 @@ There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encode
generation works in general in encoder-decoder models.
```python
from transformers import T5Tokenizer, T5ForConditionalGeneration
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# Das Haus ist wunderbar.
>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Das Haus ist wunderbar.
```
Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
@ -236,31 +246,47 @@ Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when do
The example above only shows a single example. You can also do batched inference, like so:
```python
from transformers import T5Tokenizer, T5ForConditionalGeneration
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
# when generating, we will use the logits of right-most token to predict the next token
# so the padding should be on the left
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token # to avoid an error
>>> task_prefix = "translate English to German: "
>>> sentences = [
... "The house is wonderful.",
... "I like to work in NYC.",
>>> ] # use different length sentences to test batching
>>> inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)
task_prefix = "translate English to German: "
sentences = ["The house is wonderful.", "I like to work in NYC."] # use different length sentences to test batching
inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)
>>> output_sequences = model.generate(
... input_ids=inputs["input_ids"],
... attention_mask=inputs["attention_mask"],
... do_sample=False, # disable sampling to test if batching affects output
... )
output_sequences = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
do_sample=False, # disable sampling to test if batching affects output
)
print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
# ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
>>> print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
```
Because T5 has been trained with the span-mask denoising objective,
it can be used to predict the sentinel (masked-out) tokens during inference.
The predicted tokens will then be placed between the sentinel tokens.
```python
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
>>> sequence_ids = model.generate(input_ids)
>>> sequences = tokenizer.batch_decode(sequence_ids)
>>> sequences
['<pad> <extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>']
```
<a id='scripts'></a>
## Performance

View File

@ -20,9 +20,9 @@ repository by Colin Raffel et al. It's an improved version of the original T5 mo
One can directly plug in the weights of T5v1.1 into a T5 model, like so:
```python
from transformers import T5ForConditionalGeneration
>>> from transformers import T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-base")
>>> model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-base")
```
T5 Version 1.1 includes the following improvements compared to the original T5 model:

View File

@ -1,6 +1,9 @@
docs/source/en/quicktour.mdx
docs/source/en/task_summary.mdx
docs/source/en/model_doc/speech_to_text.mdx
docs/source/en/model_doc/t5.mdx
docs/source/en/model_doc/t5v1_1.mdx
docs/source/en/model_doc/byt5.mdx
docs/source/en/model_doc/tapex.mdx
src/transformers/generation_utils.py
src/transformers/models/bart/modeling_bart.py