Add mlm pretraining xla torch readme (#12011)

* fix_torch_device_generate_test
* remove @
* upload
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review
* Update examples/flax/language-modeling/README.md
* add more info
* finish
* fix

Co-authored-by: Patrick von Platen <patrick@huggingface.co>

parent ecd6efe7cb
commit 16c0efca2c

@@ -123,7 +123,123 @@ This should take less than 18 hours.

Training statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg).

For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a
look at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/masked_language_modeling_flax.ipynb) Google Colab.

## Runtime evaluation

We also ran masked language modeling using PyTorch/XLA on a TPU v3-8, and PyTorch on 8 V100 GPUs. We report the
overall training time below.
For reproducibility, we state the training commands used for PyTorch/XLA and PyTorch further below.

| Task      | [TPU v3-8 (Flax)](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg/) | [TPU v3-8 (PyTorch/XLA)](https://tensorboard.dev/experiment/7Jq1kcQQRAmy12KOdXek7A/) | [8 GPU (PyTorch)](https://tensorboard.dev/experiment/PJneV8FQRxa2unPw1QnVHA) |
|-----------|---------|---------|---------|
| MLM       | 15h32m  | 23h46m  | 44h14m  |
| **COST*** | $124.24 | $187.84 | $877.92 |

*All experiments were run on Google Cloud Platform. Prices are on-demand prices
(not preemptible), obtained on May 12, 2021 for zone Iowa (us-central1) using
the following tables:
[TPU pricing table](https://cloud.google.com/tpu/pricing) ($8.00/h for v3-8),
[GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
V100 GPU). GPU experiments were run without further optimizations besides JAX
transformations. GPU experiments were run with full precision (fp32). "TPU v3-8"
refers to 8 TPU cores on 4 chips (each chip has 2 cores), while "8 GPU" refers to 8 GPU chips.
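
The cost row is simply the wall-clock runtime multiplied by the on-demand hourly price (and, for the GPU run, by the number of GPUs). A rough back-of-the-envelope check with the prices cited above follows; the exact figures in the table were presumably computed from unrounded runtimes, so small deviations are expected:

```bash
# Approximate cost = hours * hourly price (* number of devices for the GPU run).
# Prices as cited above: $8.00/h for a TPU v3-8, $2.48/h per V100 (on-demand, us-central1).
echo "scale=2; (15 + 32/60) * 8.00" | bc -l        # Flax on TPU v3-8   -> ~ $124
echo "scale=2; (44 + 14/60) * 2.48 * 8" | bc -l    # PyTorch on 8 V100s -> ~ $878
```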

### Script to run MLM with PyTorch/XLA on TPU v3-8

For comparison one can run the same pre-training with PyTorch/XLA on TPU. To set up PyTorch/XLA on Cloud TPU VMs, please
refer to [this](https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm) guide.
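
The symbolic links below assume that the `transformers` repository is checked out at `~/transformers`; if you have not cloned it yet, something along these lines (the path is only the assumption used below) sets that up:

```bash
# Assumption used by the symlinks below: the repository lives at ~/transformers.
git clone https://github.com/huggingface/transformers.git ~/transformers
```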

Having created the tokenizer and configuration in `norwegian-roberta-base`, we create the following symbolic links:

```bash
ln -s ~/transformers/examples/pytorch/language-modeling/run_mlm.py ./
ln -s ~/transformers/examples/pytorch/xla_spawn.py ./
```

Next, set the following environment variables:

```bash
# Point PyTorch/XLA's XRT runtime at the local TPU service of the Cloud TPU VM.
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
unset LD_PRELOAD

export NUM_TPUS=8
# Disable the fast tokenizers' parallelism to avoid fork-related warnings.
export TOKENIZERS_PARALLELISM=0
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
```

Then start training as follows:

```bash
python3 xla_spawn.py --num_cores ${NUM_TPUS} run_mlm.py --output_dir="./runs" \
    --model_type="roberta" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="128" \
    --weight_decay="0.01" \
    --per_device_train_batch_size="128" \
    --per_device_eval_batch_size="128" \
    --learning_rate="3e-4" \
    --warmup_steps="1000" \
    --overwrite_output_dir \
    --num_train_epochs="18" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98" \
    --do_train \
    --do_eval \
    --logging_steps="500" \
    --evaluation_strategy="epoch" \
    --report_to="tensorboard" \
    --save_strategy="no"
```
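
Since the run logs to TensorBoard (`--report_to="tensorboard"`) and writes its output under `./runs`, one way to inspect the training curves is to point TensorBoard at that directory; this assumes the `tensorboard` package is installed in your environment:

```bash
# Optional: browse the metrics the Trainer wrote under the output directory.
tensorboard --logdir ./runs --port 6006
```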

### Script to compare pre-training with PyTorch on 8 V100 GPUs

For comparison you can run the same pre-training with PyTorch on GPU. Note that we have to make use of `gradient_accumulation_steps`
because the maximum batch size that fits on a single V100 GPU is 32 instead of 128.
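
As a quick sanity check (illustrative only, not part of the training scripts), the effective batch size on 8 GPUs with gradient accumulation matches the TPU setting above:

```bash
# Effective batch size = per-device batch size * number of devices * accumulation steps.
echo $(( 32 * 8 * 4 ))   # 8 V100s with 4 accumulation steps  -> 1024
echo $(( 128 * 8 ))      # TPU v3-8 (8 cores, no accumulation) -> 1024
```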

Having created the tokenizer and configuration in `norwegian-roberta-base`, we create the following symbolic link:

```bash
ln -s ~/transformers/examples/pytorch/language-modeling/run_mlm.py ./
```

Next, set some environment variables:

```bash
export NUM_GPUS=8
export TOKENIZERS_PARALLELISM=0
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
```

Then start training as follows:

```bash
python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
    --output_dir="./runs" \
    --model_type="roberta" \
    --config_name="${MODEL_DIR}" \
    --tokenizer_name="${MODEL_DIR}" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_no" \
    --max_seq_length="128" \
    --weight_decay="0.01" \
    --per_device_train_batch_size="32" \
    --per_device_eval_batch_size="32" \
    --gradient_accumulation_steps="4" \
    --learning_rate="3e-4" \
    --warmup_steps="1000" \
    --overwrite_output_dir \
    --num_train_epochs="18" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98" \
    --do_train \
    --do_eval \
    --logging_steps="500" \
    --evaluation_strategy="steps" \
    --report_to="tensorboard" \
    --save_strategy="no"
```

@@ -101,7 +101,7 @@ overall training time below. For comparison we ran Pytorch's [run_glue.py](https
 *All experiments were run on Google Cloud Platform. Prices are on-demand prices
 (not preemptible), obtained on May 12, 2021 for zone Iowa (us-central1) using
 the following tables:
-[TPU pricing table](https://cloud.google.com/tpu/pricing) ($2.40/h for v3-8),
+[TPU pricing table](https://cloud.google.com/tpu/pricing) ($8.00/h for v3-8),
 [GPU pricing table](https://cloud.google.com/compute/gpus-pricing) ($2.48/h per
 V100 GPU). GPU experiments were run without further optimizations besides JAX
 transformations. GPU experiments were run with full precision (fp32). "TPU v3-8"