[Flax] Refactor gpt2 & bert example docs (#13024)

* fix_torch_device_generate_test

* remove @

* improve docs for clm

* speed-ups

* correct t5 example as well

* push final touches

* Update examples/flax/language-modeling/README.md

* correct docs for mlm

* Update examples/flax/language-modeling/README.md

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Authored by Patrick von Platen on 2021-08-09 13:37:50 +02:00, committed via GitHub
parent 3ff2cde5ca
commit 13a9c9a354
3 changed files with 84 additions and 62 deletions


@ -49,21 +49,15 @@ Next we clone the model repository to add the tokenizer and model files.
git clone https://huggingface.co/<your-username>/norwegian-roberta-base
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
To set up all relevant files for training, let's go into the cloned model directory.
```
```bash
cd norwegian-roberta-base
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to `run_mlm_flax.py`.
```bash
export MODEL_DIR="./norwegian-roberta-base"
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
```
@ -71,15 +65,13 @@ ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_fla
In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in `${MODEL_DIR}`
and consequently saved in the cloned model directory.
This can take up to 10 minutes depending on your hardware ☕.
```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
model_dir = "./norwegian-roberta-base" # ${MODEL_DIR}
# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
@ -100,7 +92,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=
])
# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
tokenizer.save("./tokenizer.json")
```
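To quickly sanity-check the result, the saved `tokenizer.json` can be reloaded; a minimal sketch (the wrapper class and the test sentence are just one possible choice, not part of the original instructions):
```python
from transformers import PreTrainedTokenizerFast

# Reload the freshly trained tokenizer file and tokenize an arbitrary test sentence.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tokenizer.json")
print(tokenizer.tokenize("Dette er en test."))
```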
### Create configuration
@ -112,22 +104,23 @@ in the local model folder:
```python
from transformers import RobertaConfig
model_dir = "./norwegian-roberta-base" # ${MODEL_DIR}
config = RobertaConfig.from_pretrained("roberta-base", vocab_size=tokenizer.get_vocab_size())
config.save_pretrained(model_dir)
config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265)
config.save_pretrained("./")
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
### Train model
Next we can run the example script to pretrain the model:
```bash
./run_mlm_flax.py \
--output_dir="${MODEL_DIR}" \
--output_dir="./" \
--model_type="roberta" \
--config_name="${MODEL_DIR}" \
--tokenizer_name="${MODEL_DIR}" \
--config_name="./" \
--tokenizer_name="./" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="128" \
@ -180,25 +173,51 @@ Next we clone the model repository to add the tokenizer and model files.
git clone https://huggingface.co/<your-username>/norwegian-gpt2
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```
cd norwegian-gpt2
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_clm_flax.py`.
To set up all relevant files for training, let's go into the cloned model directory.
```bash
cd norwegian-gpt2
```
Next, let's add a symbolic link to the training script `run_clm_flax.py`.
```bash
export MODEL_DIR="./norwegian-gpt2"
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
```
Next, we'll follow the same steps as above in [Train tokenizer](#train-tokenizer) to train the tokenizer.
### Train tokenizer
In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in the cloned model directory.
This can take up to 10 minutes depending on your hardware ☕.
```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]
# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50257, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])
# Save files to disk
tokenizer.save("./tokenizer.json")
```
### Create configuration
@ -209,22 +228,23 @@ in the local model folder:
```python
from transformers import GPT2Config
model_dir = "./norwegian-gpt2" # ${MODEL_DIR}
config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=tokenizer.get_vocab_size())
config.save_pretrained(model_dir)
config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257)
config.save_pretrained("./")
```
Great, we have set up our model repository. During training, we will now automatically
push the training logs and model weights to the repo.
### Train model
Next we can run the example script to pretrain the model:
Finally, we can run the example script to pretrain the model:
```bash
./run_clm_flax.py \
--output_dir="${MODEL_DIR}" \
--output_dir="./" \
--model_type="gpt2" \
--config_name="${MODEL_DIR}" \
--tokenizer_name="${MODEL_DIR}" \
--config_name="./" \
--tokenizer_name="./" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--do_train --do_eval \
@ -246,6 +266,9 @@ of 3.24 and 25.72 respectively after 20 epochs on a single TPUv3-8.
This should take less than 21 hours.
Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA).
For a step-by-step walkthrough of how to do causal language modeling in Flax, please have a
look at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/causal_language_modeling_flax.ipynb) Google Colab notebook.
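Once a checkpoint has been saved (or pushed to the hub), it is easy to run a quick qualitative check; a minimal sketch, assuming the default output directory `./` and an arbitrary Norwegian prompt:
```python
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel

# Load tokenizer and Flax weights from the training output directory
# ("./" here, or "<your-username>/norwegian-gpt2" once pushed to the hub).
tokenizer = AutoTokenizer.from_pretrained("./")
model = FlaxGPT2LMHeadModel.from_pretrained("./")

# Sample a short continuation for an arbitrary prompt.
inputs = tokenizer("Norge er et land som", return_tensors="np")
outputs = model.generate(**inputs, max_length=30, do_sample=True, top_k=50)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```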
## T5-like span-masked language modeling
In the following, we demonstrate how to train a T5 model using the span-masked language model
@ -272,21 +295,15 @@ Next we clone the model repository to add the tokenizer and model files.
git clone https://huggingface.co/<your-username>/norwegian-t5-base
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
To set up all relevant files for training, let's go into the cloned model directory.
```
```bash
cd norwegian-t5-base
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add symbolic links to the `run_t5_mlm_flax.py` and `t5_tokenizer_model.py` scripts.
```bash
export MODEL_DIR="./norwegian-t5-base"
ln -s ~/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py run_t5_mlm_flax.py
ln -s ~/transformers/examples/flax/language-modeling/t5_tokenizer_model.py t5_tokenizer_model.py
```
@ -299,7 +316,7 @@ a sentencepiece unigram tokenizer as shown in [t5_tokenizer_model.py](https://gi
which is heavily inspired by [yandex-research/DeDLOC's tokenizer model](https://github.com/yandex-research/DeDLOC/blob/5c994bc64e573702a9a79add3ecd68b38f14b548/sahajbert/tokenizer/tokenizer_model.py).
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in `${MODEL_DIR}`
and consequently saved in the cloned model directory.
This can take up to 120 minutes depending on your hardware ☕☕☕ .
```python
@ -310,7 +327,6 @@ from t5_tokenizer_model import SentencePieceUnigramTokenizer
vocab_size = 32_000
input_sentence_size = None
model_dir = "./norwegian-t5-base" # ${MODEL_DIR}
# Initialize a dataset
dataset = datasets.load_dataset("oscar", name="unshuffled_deduplicated_no", split="train")
@ -335,7 +351,7 @@ tokenizer.train_from_iterator(
)
# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
tokenizer.save("./tokenizer.json")
```
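For intuition, the span-masked language modeling objective that `run_t5_mlm_flax.py` optimizes replaces contiguous spans of the input with sentinel tokens and trains the model to reconstruct the dropped spans (the masking is applied on the fly by the script's data collator). A purely illustrative, made-up example:
```python
# Illustrative only -- not the exact output of the example script's preprocessing.
original = "Thank you for inviting me to your party last week ."
inputs   = "Thank you <extra_id_0> me to your party <extra_id_1> week ."
targets  = "<extra_id_0> for inviting <extra_id_1> last <extra_id_2>"
```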
### Create configuration
@ -347,12 +363,13 @@ in the local model folder:
```python
from transformers import T5Config
model_dir = "./norwegian-t5-base" # ${MODEL_DIR}
config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
config.save_pretrained(model_dir)
config.save_pretrained("./")
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
### Train model
Next we can run the example script to pretrain the model:


@ -31,6 +31,7 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
@ -51,6 +52,7 @@ from transformers import (
    HfArgumentParser,
    TrainingArguments,
    is_tensorboard_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger
@ -182,18 +184,16 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
    steps_per_epoch = len(dataset) // batch_size
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        batch_idx = np.random.permutation(len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))
        batch_idx = np.arange(len(dataset))
    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)
        batch = {k: np.array(v) for k, v in batch.items()}
        yield batch
@ -269,6 +269,9 @@ def main():
    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")
    # Set seed before initializing model.
    set_seed(training_args.seed)
    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
@ -577,7 +580,7 @@ def main():
    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
@ -591,6 +594,7 @@ def main():
        # train
        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)
@ -617,6 +621,7 @@ def main():
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
            # Model forward
            batch = next(eval_loader)
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)
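For context on why `shard(batch)` now sits next to the `p_train_step`/`p_eval_step` calls instead of inside `data_loader`: `shard` from `flax.training.common_utils` splits the leading batch dimension across the local devices so the batch can be fed to the `pmap`-ed step functions. A minimal sketch (the batch and sequence sizes are arbitrary):
```python
import jax
import numpy as np
from flax.training.common_utils import shard

# A host-level batch with leading (batch) dimension 64 and sequence length 128.
batch = {"input_ids": np.zeros((64, 128), dtype=np.int32)}

# shard() reshapes every array to (local_device_count, per_device_batch, ...).
sharded = shard(batch)
print(jax.local_device_count(), sharded["input_ids"].shape)
# On a TPUv3-8 this prints: 8 (8, 8, 128); on a single-device host: 1 (1, 64, 128).
```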


@ -214,7 +214,7 @@ class FlaxDataCollatorForLanguageModeling:
    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """