[Flax] ViT training example (#12300)

* begin script

* clean example, add readme

* update readme

* remove decay mask

* remove masking

* update readme & make flake happy
This commit is contained in:
Suraj Patil 2021-07-05 18:23:03 +05:30 committed by GitHub
parent e799e0f1ed
commit f1c81d6b92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 576 additions and 0 deletions

View File

@ -0,0 +1,101 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Image Classification training examples
The following example showcases how to train/fine-tune `ViT` for image-classification using the JAX/Flax backend.
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"vit-base-patch16-imagenette"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create vit-base-patch16-imagenette
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/vit-base-patch16-imagenette
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```
cd vit-base-patch16-imagenette
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_image_classification_flax.py`.
```bash
export MODEL_DIR="./vit-base-patch16-imagenette
ln -s ~/transformers/examples/flax/summarization/run_image_classification_flax.py run_image_classification_flax.py
```
## Prepare the dataset
We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
### Download and extract the data.
```bash
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
tar -xvzf imagenette2.tgz
```
This will create a `imagenette2` dir with two subdirectories `train` and `val` each with multiple subdirectories per class. The training script expects the following directory structure
```bash
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
```
## Train the model
Next we can run the example script to fine-tune the model:
```bash
python run_image_classification.py \
--output_dir ${MODEL_DIR} \
--model_name_or_path google/vit-base-patch16-224-in21k \
--train_dir="imagenette2/train" \
--validation_dir="imagenette2/val" \
--num_train_epochs 5 \
--learning_rate 1e-3 \
--per_device_train_batch_size 128 --per_device_eval_batch_size 128 \
--overwrite_output_dir \
--preprocessing_num_workers 32 \
--push_to_hub
```
This should finish in ~7mins with 99% validation accuracy.

View File

@ -0,0 +1,8 @@
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.4
optax>=0.0.8
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.9.0+cpu
-f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.10.0+cpu

View File

@ -0,0 +1,467 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-training/Fine-tuning ViT for image classification .
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=vit
"""
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
# for dataset and preprocessing
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
AutoConfig,
FlaxAutoModelForImageClassification,
HfArgumentParser,
TrainingArguments,
is_tensorboard_available,
set_seed,
)
logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
},
)
model_type: Optional[str] = field(
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
train_dir: str = field(
metadata={"help": "Path to the root training directory which contains one subdirectory per class."}
)
validation_dir: str = field(
metadata={"help": "Path to the root validation directory which contains one subdirectory per class."},
)
image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."})
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty."
"Use --overwrite_output_dir to overcome."
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
transformers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# set seed for random transforms and torch dataloaders
set_seed(training_args.seed)
# Initialize datasets and pre-processing transforms
# We use torchvision here for faster pre-processing
# Note that here we are using some default pre-processing, for maximum accuray
# one should tune this part and carefully select what transformations to use.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_dataset = torchvision.datasets.ImageFolder(
data_args.train_dir,
transforms.Compose(
[
transforms.RandomResizedCrop(data_args.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
),
)
eval_dataset = torchvision.datasets.ImageFolder(
data_args.validation_dir,
transforms.Compose(
[
transforms.Resize(data_args.image_size),
transforms.CenterCrop(data_args.image_size),
transforms.ToTensor(),
normalize,
]
),
)
# Load pretrained model and tokenizer
if model_args.config_name:
config = AutoConfig.from_pretrained(
model_args.config_name,
num_labels=len(train_dataset.classes),
image_size=data_args.image_size,
cache_dir=model_args.cache_dir,
)
elif model_args.model_name_or_path:
config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
num_labels=len(train_dataset.classes),
image_size=data_args.image_size,
cache_dir=model_args.cache_dir,
)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
if model_args.model_name_or_path:
model = FlaxAutoModelForImageClassification.from_pretrained(
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
)
else:
model = FlaxAutoModelForImageClassification.from_config(
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
def collate_fn(examples):
pixel_values = torch.stack([example[0] for example in examples])
labels = torch.tensor([example[1] for example in examples])
batch = {"pixel_values": pixel_values, "labels": labels}
batch = {k: v.numpy() for k, v in batch.items()}
return batch
# Create data loaders
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=True,
num_workers=data_args.preprocessing_num_workers,
persistent_workers=True,
drop_last=True,
collate_fn=collate_fn,
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset,
batch_size=eval_batch_size,
shuffle=False,
num_workers=data_args.preprocessing_num_workers,
persistent_workers=True,
drop_last=True,
collate_fn=collate_fn,
)
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
rng, dropout_rng = jax.random.split(rng)
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
training_args.num_train_epochs,
training_args.warmup_steps,
training_args.learning_rate,
)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
def loss_fn(logits, labels):
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
return loss.mean()
# Define gradient update step fn
def train_step(state, batch):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = loss_fn(logits, labels)
return loss
grad_fn = jax.value_and_grad(compute_loss)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return new_state, metrics
# Define eval fn
def eval_step(params, batch):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
loss = loss_fn(logits, labels)
# summarize metrics
accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
metrics = {"loss": loss, "accuracy": accuracy}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return metrics
# Create parallel version of the train and eval step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
p_eval_step = jax.pmap(eval_step, "batch")
# Replicate the train state on each device
state = state.replicate()
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_epochs}")
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
logger.info(f" Total optimization steps = {total_train_steps}")
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
# Create sampling rng
rng, input_rng = jax.random.split(rng)
train_metrics = []
steps_per_epoch = len(train_dataset) // train_batch_size
train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
# train
for batch in train_loader:
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_step_progress_bar.update(1)
train_time += time.time() - train_start
train_metric = unreplicate(train_metric)
train_step_progress_bar.close()
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
eval_metrics = []
eval_steps = len(eval_dataset) // eval_batch_size
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
for batch in eval_loader:
# Model forward
batch = shard(batch)
metrics = p_eval_step(state.params, batch)
eval_metrics.append(metrics)
eval_step_progress_bar.update(1)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
# Print metrics and update progress bar
eval_step_progress_bar.close()
desc = (
f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | "
f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})"
)
epochs.write(desc)
epochs.desc = desc
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(train_dataset) // train_batch_size)
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}",
)
if __name__ == "__main__":
main()