# Language model training examples

The following example showcases how to train a language model from scratch using the JAX/Flax backend.

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. Models written in JAX/Flax are immutable and updated in a purely functional way which enables simple and efficient model parallelism.
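
To illustrate what this means in practice, here is a tiny, self-contained sketch (not taken from the example scripts): a pure loss function is differentiated and compiled with `jax.jit`, and parameter updates produce new parameter pytrees instead of mutating the old ones. The training scripts apply the same pattern and replicate it across devices with `jax.pmap`.

```python
import jax
import jax.numpy as jnp

# A pure function: its output depends only on its inputs, with no hidden state.
def mse_loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Compile the gradient computation into fused accelerator code.
grad_fn = jax.jit(jax.grad(mse_loss))

params = {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 4)), jnp.zeros((8, 1))

# Parameters are never mutated in place; an update returns a new pytree.
grads = grad_fn(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```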

## Masked language modeling

In the following, we demonstrate how to train a bi-directional transformer model using the masked language modeling objective as introduced in BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. More specifically, we demonstrate how JAX/Flax can be leveraged to pre-train a **roberta-base** model on Norwegian text on a single TPU v3-8 pod.
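
To make the objective concrete, below is a small NumPy sketch of BERT-style masking: 15% of tokens are selected as prediction targets; of those, 80% are replaced by the mask token, 10% by a random token, and 10% are left unchanged. This is purely illustrative; the actual masking is handled by the data collator inside the training script.

```python
import numpy as np

def bert_style_mask(input_ids, mask_token_id, vocab_size, mlm_probability=0.15, seed=0):
    """Illustrative BERT-style masking (not the collator used by the script)."""
    rng = np.random.default_rng(seed)
    input_ids = np.array(input_ids)
    labels = np.full_like(input_ids, -100)  # -100 positions are ignored by the loss

    # Select ~15% of positions as prediction targets and keep their original ids as labels.
    selected = rng.random(input_ids.shape) < mlm_probability
    labels[selected] = input_ids[selected]

    # 80% of the selected positions are replaced by the mask token ...
    masked = selected & (rng.random(input_ids.shape) < 0.8)
    input_ids[masked] = mask_token_id

    # ... 10% by a random token, and the remaining 10% stay unchanged.
    randomized = selected & ~masked & (rng.random(input_ids.shape) < 0.5)
    random_tokens = rng.integers(0, vocab_size, size=input_ids.shape)
    input_ids[randomized] = random_tokens[randomized]

    return input_ids, labels
```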

The example script uses the 🤗 Datasets library. You can easily customize it if you need extra processing on your datasets.

Let's start by creating a folder to save the trained model and a symbolic link to the `run_mlm_flax.py` script.

```bash
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
```

### Train tokenizer

In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in How to train a new language model from scratch using Transformers and Tokenizers, we use a **ByteLevelBPETokenizer**. The tokenizer is trained on the complete Norwegian dataset of OSCAR and consequently saved in `${MODEL_DIR}`. This can take up to 10 minutes depending on your hardware.

```python
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

# Load the Norwegian OSCAR dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

# Instantiate the tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Train the tokenizer on the dataset
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save the tokenizer to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
```
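
As an optional sanity check, the saved `tokenizer.json` can be loaded back with the tokenizers library and used to encode a sentence (the example sentence below is made up):

```python
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("./norwegian-roberta-base/tokenizer.json")

# Encode a (made-up) Norwegian sentence and inspect the resulting sub-word tokens.
encoding = tokenizer.encode("Jeg heter Ola og bor i Oslo.")
print(encoding.tokens)
print(encoding.ids)
```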

### Create configuration

Next, we create the model's configuration file. This is as simple as loading the **roberta-base** configuration and saving it in the local model folder:

```python
from transformers import RobertaConfig

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

config = RobertaConfig.from_pretrained("roberta-base")
config.save_pretrained(model_dir)
```
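
Note that the vocabulary size of **roberta-base** (50265) matches the `vocab_size` used when training the tokenizer above; if you change one, keep the other in sync. A quick, optional consistency check could look like this:

```python
from tokenizers import Tokenizer
from transformers import RobertaConfig

model_dir = "./norwegian-roberta-base"

config = RobertaConfig.from_pretrained(model_dir)
tokenizer = Tokenizer.from_file(f"{model_dir}/tokenizer.json")

# The model's embedding matrix must cover every tokenizer id.
assert config.vocab_size == tokenizer.get_vocab_size(), (
    config.vocab_size, tokenizer.get_vocab_size(),
)
```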

### Train model

Next we can run the example script to pretrain the model:

```bash
./run_mlm_flax.py \
        --output_dir="./runs" \
        --model_type="roberta" \
        --config_name="${MODEL_DIR}" \
        --tokenizer_name="${MODEL_DIR}" \
        --dataset_name="oscar" \
        --dataset_config_name="unshuffled_deduplicated_no" \
        --max_seq_length="128" \
        --weight_decay="0.01" \
        --per_device_train_batch_size="128" \
        --per_device_eval_batch_size="128" \
        --learning_rate="3e-4" \
        --warmup_steps="1000" \
        --overwrite_output_dir \
        --pad_to_max_length \
        --num_train_epochs="18" \
        --adam_beta1="0.9" \
        --adam_beta2="0.98"
```

Training should converge at a loss of 1.78 and an accuracy of 0.64 after 18 epochs on a single TPU v3-8, which should take less than 18 hours. Training statistics can be accessed on tfhub.de.
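
Once training has finished, the checkpoint written to `--output_dir` can be loaded back for a quick fill-mask sanity check. Below is a minimal, optional sketch; it assumes the final model was saved to `./runs` in the usual `save_pretrained` format, and the example sentence is made up:

```python
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxRobertaForMaskedLM

model = FlaxRobertaForMaskedLM.from_pretrained("./runs")
tokenizer = AutoTokenizer.from_pretrained("./norwegian-roberta-base")

# Encode a (made-up) sentence containing a single <mask> token.
inputs = tokenizer("Jeg bor i <mask>.", return_tensors="np")
logits = model(**inputs).logits

# Take the highest-scoring token at the masked position.
mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).argmax()
predicted_id = jnp.argmax(logits[0, mask_index])
print(tokenizer.decode([int(predicted_id)]))
```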

For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a look at this Google Colab.

## Runtime evaluation

For comparison, we also ran the same masked language modeling training using PyTorch/XLA on a TPU v3-8 and using PyTorch on 8 V100 GPUs. The overall training times are reported below; for reproducibility, the training commands used for PyTorch/XLA and PyTorch are stated further below.

| Task  | TPU v3-8 (Flax) | TPU v3-8 (PyTorch/XLA) | 8 GPU (PyTorch) |
|-------|-----------------|------------------------|-----------------|
| MLM   | 15h32m          | 23h46m                 | 44h14m          |
| Cost* | $124.24         | $187.84                | $877.92         |

*All experiments were run on Google Cloud Platform. Prices are on-demand prices (not preemptible), obtained on May 12, 2021 for zone Iowa (us-central1) using the following tables: TPU pricing table ($8.00/h for a v3-8), GPU pricing table ($2.48/h per V100 GPU). GPU experiments were run without further optimizations besides JAX transformations and with full precision (fp32). "TPU v3-8" refers to 8 TPU cores on 4 chips (each chip has 2 cores), while "8 GPU" refers to 8 GPU chips.

### Script to run MLM with PyTorch/XLA on TPU v3-8

For comparison, one can run the same pre-training with PyTorch/XLA on TPU. To set up PyTorch/XLA on a Cloud TPU VM, please refer to this guide. Having created the tokenizer and configuration in `norwegian-roberta-base`, we create the following symbolic links:

```bash
ln -s ~/transformers/examples/pytorch/language-modeling/run_mlm.py ./
ln -s ~/transformers/examples/pytorch/xla_spawn.py ./
```

Then set the following environment variables:

```bash
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
unset LD_PRELOAD

export NUM_TPUS=8
export TOKENIZERS_PARALLELISM=0
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
```

Finally, start training as follows:

```bash
python3 xla_spawn.py --num_cores ${NUM_TPUS} run_mlm.py --output_dir="./runs" \
        --model_type="roberta" \
        --config_name="${MODEL_DIR}" \
        --tokenizer_name="${MODEL_DIR}" \
        --dataset_name="oscar" \
        --dataset_config_name="unshuffled_deduplicated_no" \
        --max_seq_length="128" \
        --weight_decay="0.01" \
        --per_device_train_batch_size="128" \
        --per_device_eval_batch_size="128" \
        --learning_rate="3e-4" \
        --warmup_steps="1000" \
        --overwrite_output_dir \
        --num_train_epochs="18" \
        --adam_beta1="0.9" \
        --adam_beta2="0.98" \
        --do_train \
        --do_eval \
        --logging_steps="500" \
        --evaluation_strategy="epoch" \
        --report_to="tensorboard" \
        --save_strategy="no"
```

### Script to compare pre-training with PyTorch on 8 V100 GPUs

For comparison, you can run the same pre-training with PyTorch on GPUs. Note that we have to make use of `gradient_accumulation_steps` because the maximum batch size that fits on a single V100 GPU is 32 instead of 128; with 4 accumulation steps across 8 GPUs, the effective batch size (32 × 4 × 8 = 1024) matches that of the TPU runs (128 × 8 = 1024). Having created the tokenizer and configuration in `norwegian-roberta-base`, we create the following symbolic link:

```bash
ln -s ~/transformers/examples/pytorch/language-modeling/run_mlm.py ./
```

Next, set some environment variables:

```bash
export NUM_GPUS=8
export TOKENIZERS_PARALLELISM=0
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
```

Then start training as follows:

```bash
python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
        --output_dir="./runs" \
        --model_type="roberta" \
        --config_name="${MODEL_DIR}" \
        --tokenizer_name="${MODEL_DIR}" \
        --dataset_name="oscar" \
        --dataset_config_name="unshuffled_deduplicated_no" \
        --max_seq_length="128" \
        --weight_decay="0.01" \
        --per_device_train_batch_size="32" \
        --per_device_eval_batch_size="32" \
        --gradient_accumulation_steps="4" \
        --learning_rate="3e-4" \
        --warmup_steps="1000" \
        --overwrite_output_dir \
        --num_train_epochs="18" \
        --adam_beta1="0.9" \
        --adam_beta2="0.98" \
        --do_train \
        --do_eval \
        --logging_steps="500" \
        --evaluation_strategy="steps" \
        --report_to="tensorboard" \
        --save_strategy="no"
```