Mirror of https://github.com/huggingface/transformers.git
[TPU] Doc, fix xla_spawn.py, only preprocess dataset once (#4223)
* [TPU] Doc, fix xla_spawn.py, only preprocess dataset once
* Update examples/README.md
* [xla_spawn] Add `_mp_fn` to other Trainer scripts
* [TPU] Fix: eval dataloader was None
parent 274d850d34
commit 7b75aa9fa5
@@ -53,4 +53,28 @@ pip install -r ./examples/requirements.txt
 
 ## Running on TPUs
 
-Documentation to come.
+When using TensorFlow, TPUs are supported out of the box as a `tf.distribute.Strategy`.
+
+When using PyTorch, we support TPUs thanks to `pytorch/xla`. For more context and information on how to set up your TPU environment refer to Google's documentation and to the
+very detailed [pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
+
+In this repo, we provide a very simple launcher script named [xla_spawn.py](./xla_spawn.py) that lets you run our example scripts on multiple TPU cores without any boilerplate.
+Just pass a `--num_cores` flag to this script, then your regular training script with its arguments (this is similar to the `torch.distributed.launch` helper for torch.distributed).
+
+For example, for `run_glue`:
+
+```bash
+python examples/xla_spawn.py --num_cores 8 \
+  examples/text-classification/run_glue.py \
+  --model_name_or_path bert-base-cased \
+  --task_name mnli \
+  --data_dir ./data/glue_data/MNLI \
+  --output_dir ./models/tpu \
+  --overwrite_output_dir \
+  --do_train \
+  --do_eval \
+  --num_train_epochs 1 \
+  --save_steps 20000
+```
+
+Feedback and more use cases and benchmarks involving TPUs are welcome, please share with the community.
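
A note on what the example scripts need to expose for this to work: `xla_spawn.py` imports the training script as a module and hands its `_mp_fn` to `torch_xla`'s multiprocessing spawner, which is why the hunks below add a `_mp_fn` to each Trainer script. A minimal sketch of that hook, assuming `main()` is the script's regular entry point:

```python
# Minimal sketch of the per-core hook xla_spawn.py looks for in a training script.
# Assumes `main()` is the script's usual entry point, as in the example scripts.
def _mp_fn(index):
    # `index` is the process index supplied by torch_xla's spawner; the example
    # scripts ignore it and simply run main() once per TPU core.
    main()


if __name__ == "__main__":
    main()
```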
@@ -404,7 +404,7 @@ def main():
     logger.info("Training/evaluation parameters %s", args)
 
     # Prepare dataset for the GLUE task
-    eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True, local_rank=args.local_rank)
+    eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True)
     if args.data_subset > 0:
         eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
@@ -280,5 +280,10 @@ def main():
     return results
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
@@ -221,5 +221,10 @@ def main():
     return results
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
@@ -85,10 +85,12 @@ CoLA, SST-2. The following section provides details on how to run half-precision
 said, there shouldn’t be any issues in running half-precision training with the remaining GLUE tasks as well,
 since the data processor for each task inherits from the base class DataProcessor.
 
-## Running on TPUs
+## Running on TPUs in PyTorch
 
-You can accelerate your workloads on Google's TPUs. For information on how to setup your TPU environment refer to this
-[README](https://github.com/pytorch/xla/blob/master/README.md).
+**Update**: read the more up-to-date [Running on TPUs](../README.md#running-on-tpus) in the main README.md instead.
+
+Even when running PyTorch, you can accelerate your workloads on Google's TPUs, using `pytorch/xla`. For information on how to set up your TPU environment refer to the
+[pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
 
 The following are some examples of running the `*_tpu.py` finetuning scripts on TPUs. All steps for data preparation are
 identical to your normal GPU + Hugging Face setup.
@@ -101,7 +103,6 @@ export GLUE_DIR=/path/to/glue
 export TASK_NAME=MNLI
 
 python run_glue_tpu.py \
-  --model_type bert \
   --model_name_or_path bert-base-cased \
   --task_name $TASK_NAME \
   --do_train \
@@ -115,8 +116,7 @@ python run_glue_tpu.py \
   --overwrite_output_dir \
   --logging_steps 50 \
   --save_steps 200 \
-  --num_cores=8 \
-  --only_log_master
+  --num_cores=8
 ```
 
 ### MRPC
@@ -134,16 +134,8 @@ def main():
     )
 
     # Get datasets
-    train_dataset = (
-        GlueDataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank)
-        if training_args.do_train
-        else None
-    )
-    eval_dataset = (
-        GlueDataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
-        if training_args.do_eval
-        else None
-    )
+    train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
+    eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
 
     def compute_metrics(p: EvalPrediction) -> Dict:
         if output_mode == "classification":
@@ -181,9 +173,7 @@ def main():
         eval_datasets = [eval_dataset]
         if data_args.task_name == "mnli":
             mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
-            eval_datasets.append(
-                GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
-            )
+            eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))
 
         for eval_dataset in eval_datasets:
             result = trainer.evaluate(eval_dataset=eval_dataset)
@@ -292,5 +292,10 @@ def main():
     return results
 
 
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
 if __name__ == "__main__":
     main()
@@ -12,17 +12,13 @@ Inspired by https://github.com/pytorch/pytorch/blob/master/torch/distributed/lau
 import importlib
-import os
 import sys
 from argparse import REMAINDER, ArgumentParser
+from pathlib import Path
 
 import torch_xla.distributed.xla_multiprocessing as xmp
 
 
-def trim_suffix(s: str, suffix: str):
-    return s if not s.endswith(suffix) or len(suffix) == 0 else s[: -len(suffix)]
-
-
 def parse_args():
     """
     Helper function parsing the command line options
@@ -44,7 +40,7 @@ def parse_args():
         "training_script",
         type=str,
         help=(
-            "The full module name to the single TPU training "
+            "The full path to the single TPU training "
             "program/script to be launched in parallel, "
             "followed by all the arguments for the "
             "training script"
@@ -61,7 +57,9 @@ def main():
     args = parse_args()
 
     # Import training_script as a module.
-    mod_name = trim_suffix(os.path.basename(args.training_script), ".py")
+    script_fpath = Path(args.training_script)
+    sys.path.append(str(script_fpath.parent.resolve()))
+    mod_name = script_fpath.stem
     mod = importlib.import_module(mod_name)
 
     # Patch sys.argv
@@ -5,12 +5,12 @@ from dataclasses import dataclass, field
 from typing import List, Optional
 
 import torch
+from filelock import FileLock
 from torch.utils.data.dataset import Dataset
 
 from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
 from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_xlm_roberta import XLMRobertaTokenizer
-from ...trainer import torch_distributed_zero_first
 from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
 from ..processors.utils import InputFeatures
 
@@ -63,7 +63,6 @@ class GlueDataset(Dataset):
         tokenizer: PreTrainedTokenizer,
         limit_length: Optional[int] = None,
         evaluate=False,
-        local_rank=-1,
     ):
         self.args = args
         processor = glue_processors[args.task_name]()
@@ -75,9 +74,11 @@ class GlueDataset(Dataset):
                 "dev" if evaluate else "train", tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
             ),
         )
-        with torch_distributed_zero_first(local_rank):
-            # Make sure only the first process in distributed training processes the dataset,
-            # and the others will use the cache.
+
+        # Make sure only the first process in distributed training processes the dataset,
+        # and the others will use the cache.
+        lock_path = cached_features_file + ".lock"
+        with FileLock(lock_path):
 
             if os.path.exists(cached_features_file) and not args.overwrite_cache:
                 start = time.time()
@@ -109,13 +110,12 @@ class GlueDataset(Dataset):
                     label_list=label_list,
                     output_mode=self.output_mode,
                 )
-                if local_rank in [-1, 0]:
-                    start = time.time()
-                    torch.save(self.features, cached_features_file)
-                    # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
-                    logger.info(
-                        f"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
-                    )
+                start = time.time()
+                torch.save(self.features, cached_features_file)
+                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
+                logger.info(
+                    f"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
+                )
 
     def __len__(self):
         return len(self.features)
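
Taken together, the `GlueDataset` hunks above replace `torch_distributed_zero_first` with a file lock, so whichever process reaches the cache first builds the features and every other process blocks on the lock and then reads the cached file, meaning the dataset is only preprocessed once. A stripped-down sketch of that caching pattern with simplified, hypothetical names (not the actual class):

```python
import os

import torch
from filelock import FileLock


def load_or_build_features(cached_features_file, build_features, overwrite_cache=False):
    # Only one process can hold the lock at a time: the first one builds and saves
    # the features, later ones block here and then simply load the cached file.
    with FileLock(cached_features_file + ".lock"):
        if os.path.exists(cached_features_file) and not overwrite_cache:
            return torch.load(cached_features_file)
        features = build_features()
        torch.save(features, cached_features_file)
        return features
```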
@@ -6,7 +6,7 @@ import re
 import shutil
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -195,10 +195,12 @@ class Trainer:
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
 
+        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+
         sampler = get_tpu_sampler(eval_dataset) if is_tpu_available() else None
 
         data_loader = DataLoader(
-            eval_dataset if eval_dataset is not None else self.eval_dataset,
+            eval_dataset,
             sampler=sampler,
             batch_size=self.args.eval_batch_size,
             shuffle=False,
@@ -267,6 +269,16 @@ class Trainer:
             # keep track of model topology and gradients
             wandb.watch(self.model)
 
+    def num_examples(self, dataloader: Union[DataLoader, "pl.PerDeviceLoader"]) -> int:
+        """
+        Helper to get num of examples from a DataLoader, by accessing its Dataset.
+        """
+        if is_tpu_available():
+            assert isinstance(dataloader, pl.PerDeviceLoader)
+            return len(dataloader._loader._loader.dataset)
+        else:
+            return len(dataloader.dataset)
+
     def train(self, model_path: Optional[str] = None):
         """
         Main training entry point.
@@ -326,17 +338,15 @@ class Trainer:
 
         # Train!
         if is_tpu_available():
-            num_examples = len(train_dataloader._loader._loader.dataset)
             total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
         else:
-            num_examples = len(train_dataloader.dataset)
             total_train_batch_size = (
                 self.args.train_batch_size
                 * self.args.gradient_accumulation_steps
                 * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
             )
         logger.info("***** Running training *****")
-        logger.info(" Num examples = %d", num_examples)
+        logger.info(" Num examples = %d", self.num_examples(train_dataloader))
         logger.info(" Num Epochs = %d", num_train_epochs)
         logger.info(" Instantaneous batch size per device = %d", self.args.per_gpu_train_batch_size)
         logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
@@ -606,9 +616,13 @@ class Trainer:
             model = self.model
         model.to(self.args.device)
 
+        if is_tpu_available():
+            batch_size = dataloader._loader._loader.batch_size
+        else:
+            batch_size = dataloader.batch_size
         logger.info("***** Running %s *****", description)
-        logger.info(" Num examples = %d", len(dataloader.dataset))
-        logger.info(" Batch size = %d", dataloader.batch_size)
+        logger.info(" Num examples = %d", self.num_examples(dataloader))
+        logger.info(" Batch size = %d", batch_size)
         eval_losses: List[float] = []
         preds: np.ndarray = None
         label_ids: np.ndarray = None
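
For context on the `_loader._loader` unwrapping used by `num_examples` and the batch-size lookup above: on TPU the plain `DataLoader` is wrapped by `torch_xla`'s parallel loader, so the underlying dataset sits two attributes deep. A hedged sketch of how such a loader is typically built; the exact wrapping done elsewhere in `Trainer` is not part of this diff, and the private attribute layout may differ across `torch_xla` versions:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# A plain DataLoader over a toy dataset, standing in for the Trainer's loader.
dataset = TensorDataset(torch.arange(10).unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=2)

device = xm.xla_device()
# ParallelLoader wraps the plain DataLoader; per_device_loader(device) returns a
# pl.PerDeviceLoader bound to one TPU core. PerDeviceLoader._loader is therefore
# the ParallelLoader and ._loader._loader the original DataLoader, which is why
# num_examples() reaches two levels deep for .dataset (and .batch_size).
per_device_loader = pl.ParallelLoader(dataloader, [device]).per_device_loader(device)
assert len(per_device_loader._loader._loader.dataset) == len(dataset)
```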