From b24201fa44e1a14e83be890dcbc231e926c37ec1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 13 Apr 2022 11:36:54 +0200 Subject: [PATCH] [Doctests] Fix all T5 doc tests (#16646) * [Doctests] Fix all T5 doc tests * make style * Update docs/source/en/model_doc/t5.mdx Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply Sylvains comments * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/model_doc/byt5.mdx | 97 ++++++++-- docs/source/en/model_doc/t5.mdx | 274 +++++++++++++++------------- docs/source/en/model_doc/t5v1.1.mdx | 4 +- utils/documentation_tests.txt | 3 + 4 files changed, 234 insertions(+), 144 deletions(-) diff --git a/docs/source/en/model_doc/byt5.mdx b/docs/source/en/model_doc/byt5.mdx index 06ed1952265..dc4c5a6caf8 100644 --- a/docs/source/en/model_doc/byt5.mdx +++ b/docs/source/en/model_doc/byt5.mdx @@ -48,37 +48,98 @@ fine-tuning. If you are doing multi-task fine-tuning, you should use a prefix. ByT5 works on raw UTF-8 bytes, so it can be used without a tokenizer: ```python -from transformers import T5ForConditionalGeneration -import torch +>>> from transformers import T5ForConditionalGeneration +>>> import torch -model = T5ForConditionalGeneration.from_pretrained("google/byt5-small") +>>> model = T5ForConditionalGeneration.from_pretrained("google/byt5-small") -input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + 3 # add 3 for special tokens -labels = ( - torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + 3 -) # add 3 for special tokens +>>> num_special_tokens = 3 +>>> # Model has 3 special tokens which take up the input ids 0,1,2 of ByT5. +>>> # => Need to shift utf-8 character encodings by 3 before passing ids to model. -loss = model(input_ids, labels=labels).loss # forward pass +>>> input_ids = torch.tensor([list("Life is like a box of chocolates.".encode("utf-8"))]) + num_special_tokens + +>>> labels = torch.tensor([list("La vie est comme une boîte de chocolat.".encode("utf-8"))]) + num_special_tokens + +>>> loss = model(input_ids, labels=labels).loss +>>> loss.item() +2.66 ``` For batched inference and training it is however recommended to make use of the tokenizer: ```python -from transformers import T5ForConditionalGeneration, AutoTokenizer +>>> from transformers import T5ForConditionalGeneration, AutoTokenizer -model = T5ForConditionalGeneration.from_pretrained("google/byt5-small") -tokenizer = AutoTokenizer.from_pretrained("google/byt5-small") +>>> model = T5ForConditionalGeneration.from_pretrained("google/byt5-small") +>>> tokenizer = AutoTokenizer.from_pretrained("google/byt5-small") -model_inputs = tokenizer( - ["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt" -) -labels = tokenizer( - ["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt" -).input_ids +>>> model_inputs = tokenizer( +... ["Life is like a box of chocolates.", "Today is Monday."], padding="longest", return_tensors="pt" +... ) +>>> labels_dict = tokenizer( +... ["La vie est comme une boîte de chocolat.", "Aujourd'hui c'est lundi."], padding="longest", return_tensors="pt" +... 
)
+>>> labels = labels_dict.input_ids

-loss = model(**model_inputs, labels=labels).loss  # forward pass
+>>> loss = model(**model_inputs, labels=labels).loss
+>>> loss.item()
+17.9
 ```

+Similar to [T5](t5), ByT5 was trained on the span-mask denoising task. However,
+since the model works directly on characters, the pretraining task is a bit
+different. Let's corrupt some characters of the
+input sentence `"The dog chases a ball in the park."` and ask ByT5 to predict them
+for us.
+
+```python
+>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+>>> import torch
+
+>>> tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
+>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-base")
+
+>>> input_ids_prompt = "The dog chases a ball in the park."
+>>> input_ids = tokenizer(input_ids_prompt).input_ids
+
+>>> # Note that we cannot add "{extra_id_...}" to the string directly
+>>> # as the Byte tokenizer would incorrectly merge the tokens
+>>> # For ByT5, we need to work directly on the character level
+>>> # Contrary to T5, ByT5 does not use sentinel tokens for masking, but instead
+>>> # uses the final utf character ids as mask tokens.
+>>> # UTF-8 is represented by 8 bits and ByT5 has 3 special tokens.
+>>> # => There are 2**8+3 = 259 input ids and mask tokens count down from index 258.
+>>> # => mask to "The dog [258]a ball [257]park."
+
+>>> input_ids = torch.tensor([input_ids[:8] + [258] + input_ids[14:21] + [257] + input_ids[28:]])
+>>> input_ids
+tensor([[ 87, 107, 104, 35, 103, 114, 106, 35, 258, 35, 100, 35, 101, 100, 111, 111, 257, 35, 115, 100, 117, 110, 49, 1]])
+
+>>> # ByT5 produces only one char at a time so we need to produce many more output characters here -> set `max_length=100`.
+>>> output_ids = model.generate(input_ids, max_length=100)[0].tolist()
+>>> output_ids
+[0, 258, 108, 118, 35, 119, 107, 104, 35, 114, 113, 104, 35, 122, 107, 114, 35, 103, 114, 104, 118, 257, 35, 108, 113, 35, 119, 107, 104, 35, 103, 108, 118, 102, 114, 256, 108, 113, 35, 119, 107, 104, 35, 115, 100, 117, 110, 49, 35, 87, 107, 104, 35, 103, 114, 106, 35, 108, 118, 35, 119, 107, 104, 35, 114, 113, 104, 35, 122, 107, 114, 35, 103, 114, 104, 118, 35, 100, 35, 101, 100, 111, 111, 35, 108, 113, 255, 35, 108, 113, 35, 119, 107, 104, 35, 115, 100, 117, 110, 49]
+
+>>> # ^- Note how 258 descends to 257, 256, 255
+
+>>> # Now we need to split on the sentinel tokens, let's write a short loop for this
+>>> output_ids_list = []
+>>> start_token = 0
+>>> sentinel_token = 258
+>>> while sentinel_token in output_ids:
+...     split_idx = output_ids.index(sentinel_token)
+...     output_ids_list.append(output_ids[start_token:split_idx])
+...     start_token = split_idx
+...     sentinel_token -= 1

+>>> output_ids_list.append(output_ids[start_token:])
+>>> output_string = tokenizer.batch_decode(output_ids_list)
+>>> output_string
+['', 'is the one who does', ' in the disco', 'in the park. The dog is the one who does a ball in', ' in the park.']
+```
+
+
 ## ByT5Tokenizer

 [[autodoc]] ByT5Tokenizer
diff --git a/docs/source/en/model_doc/t5.mdx b/docs/source/en/model_doc/t5.mdx
index ef605a3523a..c312b3df815 100644
--- a/docs/source/en/model_doc/t5.mdx
+++ b/docs/source/en/model_doc/t5.mdx
@@ -32,9 +32,9 @@ NLP, we release our dataset, pre-trained models, and code.*
 Tips:

 - T5 is an encoder-decoder model pre-trained on a multi-task mixture of unsupervised and supervised tasks and for which
-  each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
-  different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
-  for summarization: *summarize: ...*.
+each task is converted into a text-to-text format. T5 works well on a variety of tasks out-of-the-box by prepending a
+different prefix to the input corresponding to each task, e.g., for translation: *translate English to German: ...*,
+for summarization: *summarize: ...*.

 - T5 uses relative scalar embeddings. Encoder input padding can be done on the left and on the right.

@@ -83,130 +83,140 @@ language modeling head on top of the decoder.

 - Unsupervised denoising training

-  In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a* unique mask tokens) and
-  the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
-  sentinel token represents a unique mask token for this sentence and should start with `<extra_id_0>`,
-  `<extra_id_1>`, ... up to `<extra_id_99>`. As a default, 100 sentinel tokens are available in
-  [`T5Tokenizer`].
+In this setup, spans of the input sequence are masked by so-called sentinel tokens (*a.k.a.* unique mask tokens) and
+the output sequence is formed as a concatenation of the same sentinel tokens and the *real* masked tokens. Each
+sentinel token represents a unique mask token for this sentence and should start with `<extra_id_0>`,
+`<extra_id_1>`, ... up to `<extra_id_99>`. By default, 100 sentinel tokens are available in
+[`T5Tokenizer`].

-  For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
-  processed as follows:
+For instance, the sentence "The cute dog walks in the park" with the masks put on "cute dog" and "the" should be
+processed as follows:

-  ```python
-  from transformers import T5Tokenizer, T5ForConditionalGeneration
+```python
+>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

-  tokenizer = T5Tokenizer.from_pretrained("t5-small")
-  model = T5ForConditionalGeneration.from_pretrained("t5-small")
+>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
+>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

-  input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
-  labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
-  # the forward function automatically creates the correct decoder_input_ids
-  loss = model(input_ids=input_ids, labels=labels).loss
-  ```
+>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
+>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids

-  If you're interested in pre-training T5 on a new corpus, check out the [run_t5_mlm_flax.py](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) script in the Examples
-  directory.
+>>> # the forward function automatically creates the correct decoder_input_ids
+>>> loss = model(input_ids=input_ids, labels=labels).loss
+>>> loss.item()
+3.7837
+```
+
+If you're interested in pre-training T5 on a new corpus, check out the [run_t5_mlm_flax.py](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) script in the Examples
+directory.

 - Supervised training

-  In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
-  Suppose that we want to fine-tune the model for translation for example, and we have a training example: the input
-  sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
-  the model as follows:
+In this setup, the input sequence and output sequence are a standard sequence-to-sequence input-output mapping.
+Suppose that we want to fine-tune the model for translation, for example, and we have a training example: the input
+sequence "The house is wonderful." and output sequence "Das Haus ist wunderbar.", then they should be prepared for
+the model as follows:

-  ```python
-  from transformers import T5Tokenizer, T5ForConditionalGeneration
+```python
+>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

-  tokenizer = T5Tokenizer.from_pretrained("t5-small")
-  model = T5ForConditionalGeneration.from_pretrained("t5-small")
+>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
+>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

-  input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
-  labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
-  # the forward function automatically creates the correct decoder_input_ids
-  loss = model(input_ids=input_ids, labels=labels).loss
-  ```
+>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
+>>> labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

-  As you can see, only 2 inputs are required for the model in order to compute a loss: `input_ids` (which are the
-  `input_ids` of the encoded input sequence) and `labels` (which are the `input_ids` of the encoded
-  target sequence). The model will automatically create the `decoder_input_ids` based on the `labels`, by
-  shifting them one position to the right and prepending the `config.decoder_start_token_id`, which for T5 is
-  equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
-  English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
-  during T5's pre-training.
+>>> # the forward function automatically creates the correct decoder_input_ids
+>>> loss = model(input_ids=input_ids, labels=labels).loss
+>>> loss.item()
+0.2542
+```

-  However, the example above only shows a single training example. In practice, one trains deep learning models in
-  batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
-  typically defines a `max_source_length` and `max_target_length`, which determine the maximum length of the
-  input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
-  the task.
+As you can see, only 2 inputs are required for the model in order to compute a loss: `input_ids` (which are the
+`input_ids` of the encoded input sequence) and `labels` (which are the `input_ids` of the encoded
+target sequence). The model will automatically create the `decoder_input_ids` based on the `labels`, by
+shifting them one position to the right and prepending the `config.decoder_start_token_id`, which for T5 is
+equal to 0 (i.e. the id of the pad token). Also note the task prefix: we prepend the input sequence with 'translate
+English to German: ' before encoding it. This will help in improving the performance, as this task prefix was used
+during T5's pre-training.

-  In addition, we must make sure that padding token id's of the `labels` are not taken into account by the loss
-  function. In PyTorch and Tensorflow, this can be done by replacing them with -100, which is the `ignore_index`
-  of the `CrossEntropyLoss`. In Flax, one can use the `decoder_attention_mask` to ignore padded tokens from
-  the loss (see the [Flax summarization script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization) for details). We also pass
-  `attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
-  ignored. The code example below illustrates all of this.
+However, the example above only shows a single training example. In practice, one trains deep learning models in
+batches. This entails that we must pad/truncate examples to the same length. For encoder-decoder models, one
+typically defines a `max_source_length` and `max_target_length`, which determine the maximum length of the
+input and output sequences respectively (otherwise they are truncated). These should be carefully set depending on
+the task.

-  ```python
-  from transformers import T5Tokenizer, T5ForConditionalGeneration
-  import torch
+In addition, we must make sure that padding token ids of the `labels` are not taken into account by the loss
+function. In PyTorch and TensorFlow, this can be done by replacing them with -100, which is the `ignore_index`
+of the `CrossEntropyLoss`. In Flax, one can use the `decoder_attention_mask` to ignore padded tokens from
+the loss (see the [Flax summarization script](https://github.com/huggingface/transformers/tree/main/examples/flax/summarization) for details). We also pass
+`attention_mask` as additional input to the model, which makes sure that padding tokens of the inputs are
+ignored. The code example below illustrates all of this.

-  tokenizer = T5Tokenizer.from_pretrained("t5-small")
-  model = T5ForConditionalGeneration.from_pretrained("t5-small")
+```python
+>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
+>>> import torch

-  # the following 2 hyperparameters are task-specific
-  max_source_length = 512
-  max_target_length = 128
+>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
+>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

-  # Suppose we have the following 2 training examples:
-  input_sequence_1 = "Welcome to NYC"
-  output_sequence_1 = "Bienvenue à NYC"
+>>> # the following 2 hyperparameters are task-specific
+>>> max_source_length = 512
+>>> max_target_length = 128

-  input_sequence_2 = "HuggingFace is a company"
-  output_sequence_2 = "HuggingFace est une entreprise"
+>>> # Suppose we have the following 2 training examples:
+>>> input_sequence_1 = "Welcome to NYC"
+>>> output_sequence_1 = "Bienvenue à NYC"

-  # encode the inputs
-  task_prefix = "translate English to French: "
-  input_sequences = [input_sequence_1, input_sequence_2]
-  encoding = tokenizer(
-      [task_prefix + sequence for sequence in input_sequences],
-      padding="longest",
-      max_length=max_source_length,
-      truncation=True,
-      return_tensors="pt",
-  )
-  input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
+>>> input_sequence_2 = "HuggingFace is a company"
+>>> output_sequence_2 = "HuggingFace est une entreprise"

-  # encode the targets
-  target_encoding = tokenizer(
-      [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
-  )
-  labels = target_encoding.input_ids
+>>> # encode the inputs
+>>> task_prefix = "translate English to French: "
+>>> input_sequences = [input_sequence_1, input_sequence_2]

-  # replace padding token id's of the labels by -100
-  labels = torch.tensor(labels)
-  labels[labels == tokenizer.pad_token_id] = -100
+>>> encoding = tokenizer(
+...     [task_prefix + sequence for sequence in input_sequences],
...     padding="longest",
+...     max_length=max_source_length,
+...     truncation=True,
+...     return_tensors="pt",
+... )

-  # forward pass
-  loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
-  ```
+>>> input_ids, attention_mask = encoding.input_ids, encoding.attention_mask
+
+>>> # encode the targets
+>>> target_encoding = tokenizer(
+...     [output_sequence_1, output_sequence_2], padding="longest", max_length=max_target_length, truncation=True
+... )
+>>> labels = target_encoding.input_ids
+
+>>> # replace padding token ids of the labels by -100 so they're ignored by the loss
+>>> labels = torch.tensor(labels)
+>>> labels[labels == tokenizer.pad_token_id] = -100
+
+>>> # forward pass
+>>> loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
+>>> loss.item()
+0.188
+```

 Additional training tips:

 - T5 models need a slightly higher learning rate than the default one set in the `Trainer` when using the AdamW
-  optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
-  answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.
+optimizer. Typically, 1e-4 and 3e-4 work well for most problems (classification, summarization, translation, question
+answering, question generation). Note that T5 was pre-trained using the AdaFactor optimizer.

-- According to [this forum post](https://discuss.huggingface.co/t/t5-finetuning-tips/684), task prefixes matter when
-  (1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
-  pre-training mixture (see Appendix D of the [paper](https://arxiv.org/pdf/1910.10683.pdf) for the task prefixes
-  used).
+- According to [this forum post](https://discuss.huggingface.co/t/t5-finetuning-tips/684), task prefixes matter when
+(1) doing multi-task training (2) your task is similar or related to one of the supervised tasks used in T5's
+pre-training mixture (see Appendix D of the [paper](https://arxiv.org/pdf/1910.10683.pdf) for the task prefixes
+used).

-- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
-  *pad_to_multiple_of* to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
-  batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
-  encountered during training thus significantly slowing down the training. only padding up to the longest example in a
-  batch) leads to very slow training on TPU.
+- If training on TPU, it is recommended to pad all examples of the dataset to the same length or make use of
+*pad_to_multiple_of* to have a small number of predefined bucket sizes to fit all examples in. Dynamically padding
+batches to the longest example is not recommended on TPU as it triggers a recompilation for every batch shape that is
+encountered during training, thus significantly slowing down the training. In other words, dynamic padding (i.e. only
+padding up to the longest example in a batch) leads to very slow training on TPU.

@@ -219,15 +229,15 @@ There's also [this blog post](https://huggingface.co/blog/encoder-decoder#encode
 generation works in general in encoder-decoder models.
 ```python
-from transformers import T5Tokenizer, T5ForConditionalGeneration
+>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

-tokenizer = T5Tokenizer.from_pretrained("t5-small")
-model = T5ForConditionalGeneration.from_pretrained("t5-small")
+>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
+>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

-input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
-outputs = model.generate(input_ids)
-print(tokenizer.decode(outputs[0], skip_special_tokens=True))
-# Das Haus ist wunderbar.
+>>> input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
+>>> outputs = model.generate(input_ids)
+>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+Das Haus ist wunderbar.
 ```

 Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when doing generation without using
@@ -236,31 +246,47 @@ Note that T5 uses the `pad_token_id` as the `decoder_start_token_id`, so when do
 The example above only shows a single example. You can also do batched inference, like so:

 ```python
-from transformers import T5Tokenizer, T5ForConditionalGeneration
+>>> from transformers import T5Tokenizer, T5ForConditionalGeneration

-tokenizer = T5Tokenizer.from_pretrained("t5-small")
-model = T5ForConditionalGeneration.from_pretrained("t5-small")
+>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
+>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")

-# when generating, we will use the logits of right-most token to predict the next token
-# so the padding should be on the left
-tokenizer.padding_side = "left"
-tokenizer.pad_token = tokenizer.eos_token  # to avoid an error
+>>> task_prefix = "translate English to German: "
+>>> sentences = [
+...     "The house is wonderful.",
+...     "I like to work in NYC.",
+... ]  # use different length sentences to test batching
+>>> inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)

-task_prefix = "translate English to German: "
-sentences = ["The house is wonderful.", "I like to work in NYC."]  # use different length sentences to test batching
-inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True)
+>>> output_sequences = model.generate(
+...     input_ids=inputs["input_ids"],
+...     attention_mask=inputs["attention_mask"],
+...     do_sample=False,  # disable sampling to test if batching affects output
+... )

-output_sequences = model.generate(
-    input_ids=inputs["input_ids"],
-    attention_mask=inputs["attention_mask"],
-    do_sample=False,  # disable sampling to test if batching affects output
-)
-
-print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
-
-# ['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
+>>> print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
+['Das Haus ist wunderbar.', 'Ich arbeite gerne in NYC.']
 ```

+Because T5 has been trained with the span-mask denoising objective,
+it can be used to predict the sentinel (masked-out) tokens during inference.
+The predicted tokens will then be placed between the sentinel tokens.
+
+```python
+>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
+>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
+
+>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
+
+>>> sequence_ids = model.generate(input_ids)
+>>> sequences = tokenizer.batch_decode(sequence_ids)
+>>> sequences
+['<pad><extra_id_0> park offers<extra_id_1> the<extra_id_2> park.</s>']
+```
+
+
 ## Performance
diff --git a/docs/source/en/model_doc/t5v1.1.mdx b/docs/source/en/model_doc/t5v1.1.mdx
index 512c5c59ced..b15188961d3 100644
--- a/docs/source/en/model_doc/t5v1.1.mdx
+++ b/docs/source/en/model_doc/t5v1.1.mdx
@@ -20,9 +20,9 @@ repository by Colin Raffel et al. It's an improved version of the original T5 mo
 One can directly plug in the weights of T5v1.1 into a T5 model, like so:

 ```python
-from transformers import T5ForConditionalGeneration
+>>> from transformers import T5ForConditionalGeneration

-model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-base")
+>>> model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-base")
 ```

 T5 Version 1.1 includes the following improvements compared to the original T5 model:
diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt
index eee14c33749..170076244cd 100644
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -1,6 +1,9 @@
 docs/source/en/quicktour.mdx
 docs/source/en/task_summary.mdx
 docs/source/en/model_doc/speech_to_text.mdx
+docs/source/en/model_doc/t5.mdx
+docs/source/en/model_doc/t5v1_1.mdx
+docs/source/en/model_doc/byt5.mdx
 docs/source/en/model_doc/tapex.mdx
 src/transformers/generation_utils.py
 src/transformers/models/bart/modeling_bart.py
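
As a complement to the sentinel-prediction snippet added to `t5.mdx` above, the string predicted for each mask can be recovered by splitting the generated sequence on the sentinel tokens, much like the byte-level loop in the ByT5 example. The following is a minimal post-processing sketch (not taken from the patch itself): it assumes the default `<extra_id_0>` ... `<extra_id_99>` sentinel tokens of [`T5Tokenizer`], and the variable names and splitting strategy are purely illustrative.

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
output_ids = model.generate(input_ids)[0].tolist()

# ids of the 100 default sentinel tokens (<extra_id_0> ... <extra_id_99>)
sentinel_ids = set(tokenizer.convert_tokens_to_ids([f"<extra_id_{i}>" for i in range(100)]))
special_ids = {tokenizer.pad_token_id, tokenizer.eos_token_id}

# collect the tokens generated between consecutive sentinel/special tokens
spans, current = [], []
for token_id in output_ids:
    if token_id in sentinel_ids or token_id in special_ids:
        if current:
            spans.append(tokenizer.decode(current))
        current = []
    else:
        current.append(token_id)
if current:
    spans.append(tokenizer.decode(current))

print(spans)  # one decoded string per predicted span (exact strings depend on the checkpoint)
```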