[Flax] ViT training example (#12300)

* begin script
* clean example, add readme
* update readme
* remove decay mask
* remove masking
* update readme & make flake happy

commit f1c81d6b92 (parent e799e0f1ed)
examples/flax/vision/README.md (new file, 101 lines)
@@ -0,0 +1,101 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Image Classification training examples

The following example showcases how to train/fine-tune `ViT` for image classification using the JAX/Flax backend.

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way, which enables simple and efficient model parallelism.
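
To make the "purely functional update" point concrete, here is a minimal, illustrative sketch (it is *not* part of the example script): the loss is a pure function of its inputs, `jax.jit` compiles the update, and the optimizer returns new parameters instead of mutating the old ones. All names in it are made up for illustration.

```python
import jax
import jax.numpy as jnp
import optax

# A pure function: the output depends only on its inputs, no hidden state.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
tx = optax.sgd(learning_rate=0.1)
opt_state = tx.init(params)

@jax.jit  # traced once, compiled to fused XLA code for CPU/GPU/TPU
def update(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state, params)
    # apply_updates returns a *new* parameter pytree; `params` is never mutated
    return optax.apply_updates(params, updates), opt_state

x = jnp.ones((8, 3))
y = jnp.ones((8, 1))
params, opt_state = update(params, opt_state, x, y)
```
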
In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.

Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"vit-base-patch16-imagenette"`, but you can change the model name as you like.

You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:

```bash
huggingface-cli repo create vit-base-patch16-imagenette
```

Next we clone the model repository to add the model files.

```bash
git clone https://huggingface.co/<your-username>/vit-base-patch16-imagenette
```

To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.

```bash
cd vit-base-patch16-imagenette
git lfs track "*tfevents*"
```

Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.

Next, let's add a symbolic link to `run_image_classification.py`.

```bash
export MODEL_DIR="./vit-base-patch16-imagenette"
ln -s ~/transformers/examples/flax/vision/run_image_classification.py run_image_classification.py
```

## Prepare the dataset

We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).

### Download and extract the data

```bash
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
tar -xvzf imagenette2.tgz
```

This will create an `imagenette2` directory with two subdirectories, `train` and `val`, each containing one subdirectory per class. The training script expects the following directory structure:

```bash
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
```

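The training script reads this layout with `torchvision.datasets.ImageFolder`, which assigns each class subdirectory an integer label. A quick sanity check of how your data will be picked up (the path below assumes the `imagenette2` directory created in the previous step; the printed class names are just an example of what imagenette's WordNet-ID folder names look like):

```python
import torchvision

# ImageFolder maps each class subdirectory to an integer label,
# in alphabetical order of the folder names.
train_dataset = torchvision.datasets.ImageFolder("imagenette2/train")

print(len(train_dataset))          # number of training images
print(train_dataset.classes)       # e.g. ['n01440764', 'n02102040', ...]
print(train_dataset.class_to_idx)  # folder name -> label id used by the script
```
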
## Train the model

Next we can run the example script to fine-tune the model:

```bash
python run_image_classification.py \
    --output_dir ${MODEL_DIR} \
    --model_name_or_path google/vit-base-patch16-224-in21k \
    --train_dir="imagenette2/train" \
    --validation_dir="imagenette2/val" \
    --num_train_epochs 5 \
    --learning_rate 1e-3 \
    --per_device_train_batch_size 128 --per_device_eval_batch_size 128 \
    --overwrite_output_dir \
    --preprocessing_num_workers 32 \
    --push_to_hub
```

This should finish in ~7 minutes with 99% validation accuracy.
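
Once training has finished, the fine-tuned model saved to `$MODEL_DIR` (and pushed to the Hub if `--push_to_hub` was set) can be loaded back for inference. The snippet below is a minimal sketch: it assumes the output directory used above and an arbitrary local image `example.jpg` (both assumptions, not part of the example), and reuses the same resize/normalize preprocessing as the evaluation pipeline in the training script.

```python
import jax.numpy as jnp
from PIL import Image
import torchvision.transforms as transforms

from transformers import FlaxAutoModelForImageClassification

model_dir = "./vit-base-patch16-imagenette"  # assumption: the --output_dir used above
model = FlaxAutoModelForImageClassification.from_pretrained(model_dir)

# Same eval-time preprocessing as in run_image_classification.py
preprocess = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

image = Image.open("example.jpg").convert("RGB")
pixel_values = jnp.expand_dims(preprocess(image).numpy(), axis=0)  # (1, 3, 224, 224)

logits = model(pixel_values=pixel_values).logits
predicted_class = int(jnp.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_class])
```
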
examples/flax/vision/requirements.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.4
optax>=0.0.8
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.9.0+cpu
-f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.10.0+cpu
examples/flax/vision/run_image_classification.py (new file, 467 lines)
@@ -0,0 +1,467 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-training/Fine-tuning ViT for image classification.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=vit
"""

import logging
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional

# for dataset and preprocessing
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    FlaxAutoModelForImageClassification,
    HfArgumentParser,
    TrainingArguments,
    is_tensorboard_available,
    set_seed,
)


logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    train_dir: str = field(
        metadata={"help": "Path to the root training directory which contains one subdirectory per class."}
    )
    validation_dir: str = field(
        metadata={"help": "Path to the root validation directory which contains one subdirectory per class."},
    )
    image_size: Optional[int] = field(default=224, metadata={"help": "The size (resolution) of each image."})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn

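
# Overall flow of main() below: parse the three argument dataclasses, build the
# torchvision datasets and data loaders, load (or initialize) the Flax ViT model,
# create the learning rate schedule and AdamW optimizer, define pmap-ed
# train/eval steps, and run the epoch loop while logging to TensorBoard and
# saving/pushing checkpoints.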
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # set seed for random transforms and torch dataloaders
    set_seed(training_args.seed)

    # Initialize datasets and pre-processing transforms
    # We use torchvision here for faster pre-processing
    # Note that here we are using some default pre-processing, for maximum accuracy
    # one should tune this part and carefully select what transformations to use.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_dataset = torchvision.datasets.ImageFolder(
        data_args.train_dir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(data_args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    eval_dataset = torchvision.datasets.ImageFolder(
        data_args.validation_dir,
        transforms.Compose(
            [
                transforms.Resize(data_args.image_size),
                transforms.CenterCrop(data_args.image_size),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    # Load pretrained config and model
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.model_name_or_path:
        model = FlaxAutoModelForImageClassification.from_pretrained(
            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )
    else:
        model = FlaxAutoModelForImageClassification.from_config(
            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )

    # Store some constants
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    def collate_fn(examples):
        pixel_values = torch.stack([example[0] for example in examples])
        labels = torch.tensor([example[1] for example in examples])

        batch = {"pixel_values": pixel_values, "labels": labels}
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed. "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # Create AdamW optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
        metrics = {"loss": loss, "accuracy": accuracy}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        steps_per_epoch = len(train_dataset) // train_batch_size
        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_steps = len(eval_dataset) // eval_batch_size
        eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = (
            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | "
            f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})"
        )
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics (cur_step is the number of optimization steps completed so far)
        if has_tensorboard and jax.process_index() == 0:
            cur_step = (epoch + 1) * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )


if __name__ == "__main__":
    main()