Language model training examples

The following example showcases how to train a language model from scratch using the JAX/Flax backend.

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. Models written in JAX/Flax are immutable and updated in a purely functional way, which enables simple and efficient model parallelism.
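
As a quick illustration of this programming model (a toy sketch for orientation only, not part of the example scripts), any pure function can be compiled with jax.jit and then runs unchanged on CPU, GPU, or TPU:

import jax
import jax.numpy as jnp

@jax.jit  # trace the pure function once and compile it with XLA
def scaled_dot(params, x):
    return jnp.dot(x, params["w"]) * params["scale"]

params = {"w": jnp.ones((4, 2)), "scale": 0.5}
x = jnp.ones((3, 4))
print(scaled_dot(params, x).shape)  # (3, 2)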

Masked language modeling

In the following, we demonstrate how to train a bi-directional transformer model using the masked language modeling objective as introduced in BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. More specifically, we demonstrate how JAX/Flax can be leveraged to pre-train roberta-base in Norwegian on a single TPUv3-8 pod.
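
To make the objective concrete, the sketch below is illustrative only (the toy token ids, the mask token id, and the 15% masking probability mirror the BERT recipe and are not taken from the example script): a fraction of the input tokens is replaced by a mask token, and only those positions contribute to the loss.

import numpy as np

rng = np.random.default_rng(0)
input_ids = np.array([5, 120, 7, 98, 42, 13, 77, 6])  # toy token ids
mask_token_id = 4          # hypothetical id of the <mask> token
mlm_probability = 0.15     # fraction of tokens to mask, as in BERT

# choose positions to mask; unmasked positions are excluded from the loss
masked_positions = rng.random(input_ids.shape) < mlm_probability
labels = np.where(masked_positions, input_ids, -100)
masked_input_ids = np.where(masked_positions, mask_token_id, input_ids)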

The example script uses the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets.
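
For example, if you want to pre-filter the corpus before training, a small sketch along the following lines could be used (the 100-character length threshold is an arbitrary choice for illustration):

from datasets import load_dataset

dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

# drop very short documents before tokenization (threshold chosen for illustration)
dataset = dataset.filter(lambda example: len(example["text"]) > 100)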

Let's start by creating a folder to save the trained model and a symbolic link to the run_mlm_flax.py script.

export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py

Train tokenizer

In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in How to train a new language model from scratch using Transformers and Tokenizers, we use a ByteLevelBPETokenizer. The tokenizer is trained on the complete Norwegian dataset of OSCAR and then saved in ${MODEL_DIR}. Depending on your hardware, this can take up to 10 minutes.

from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")

# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(dataset), batch_size):
        yield dataset[i: i + batch_size]["text"]

# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

# Save files to disk
tokenizer.save(f"{model_dir}/tokenizer.json")
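
As an optional sanity check (not required by the training script), the saved tokenizer.json can be loaded back and used to encode a sample sentence:

from tokenizers import Tokenizer

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

tokenizer = Tokenizer.from_file(f"{model_dir}/tokenizer.json")
print(tokenizer.encode("Dette er en test.").tokens)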

Create configuration

Next, we create the model's configuration file. This is as simple as loading the configuration of **roberta-base** and saving it in the local model folder:

from transformers import RobertaConfig

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

config = RobertaConfig.from_pretrained("roberta-base")
config.save_pretrained(model_dir)
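
Note that roberta-base uses a vocabulary size of 50265, which is also the vocab_size the tokenizer was trained with above. If you change one, the other has to follow. A small illustrative check (assuming the tokenizer reached the requested vocabulary size):

from tokenizers import Tokenizer
from transformers import RobertaConfig

model_dir = "./norwegian-roberta-base"  # ${MODEL_DIR}

config = RobertaConfig.from_pretrained(model_dir)
tokenizer = Tokenizer.from_file(f"{model_dir}/tokenizer.json")

# the embedding matrix must cover every id the tokenizer can produce
assert config.vocab_size == tokenizer.get_vocab_size()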

Train model

Next, we can run the example script to pretrain the model:

./run_mlm_flax.py \
        --output_dir="./runs" \
        --model_type="roberta" \
        --config_name="${MODEL_DIR}" \
        --tokenizer_name="${MODEL_DIR}" \
        --dataset_name="oscar" \
        --dataset_config_name="unshuffled_deduplicated_no" \
        --max_seq_length="128" \
        --weight_decay="0.01" \
        --per_device_train_batch_size="128" \
        --per_device_eval_batch_size="128" \
        --learning_rate="3e-4" \
        --warmup_steps="1000" \
        --overwrite_output_dir \
        --pad_to_max_length \
        --num_train_epochs="18" \
        --adam_beta1="0.9" \
        --adam_beta2="0.98"

Training should converge at a loss and accuracy of 1.78 and 0.64, respectively, after 18 epochs on a single TPUv3-8. This should take less than 18 hours. Training statistics can be accessed on tfhub.de.
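
Once training has finished, the checkpoint written to --output_dir can be loaded back for a quick fill-mask test. The snippet below is a sketch under the assumption that a Flax checkpoint was saved to ./runs; the Norwegian sample sentence is only for illustration:

import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxRobertaForMaskedLM

model = FlaxRobertaForMaskedLM.from_pretrained("./runs")
tokenizer = AutoTokenizer.from_pretrained("./norwegian-roberta-base")

inputs = tokenizer("Dette er en <mask>.", return_tensors="np")
logits = model(**inputs).logits

# pick the most likely token for the masked position
mask_position = int((inputs["input_ids"][0] == tokenizer.mask_token_id).argmax())
predicted_id = int(jnp.argmax(logits[0, mask_position]))
print(tokenizer.decode([predicted_id]))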

For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a look at this TODO: (Patrick) google colab.

TODO(Patrick): Add comparison with PyTorch GPU/TPU