#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-training/Fine-tuning ViT for image classification.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=vit
"""

import logging
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional

# for dataset and preprocessing
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    FlaxAutoModelForImageClassification,
    HfArgumentParser,
    TrainingArguments,
    is_tensorboard_available,
    set_seed,
)


logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    # Resolved with getattr(jnp, dtype) in main(), so it must name a jax.numpy dtype.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    train_dir: str = field(
        metadata={"help": "Path to the root training directory which contains one subdirectory per class."}
    )
    validation_dir: str = field(
        metadata={"help": "Path to the root validation directory which contains one subdirectory per class."},
    )
    image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    # Not read anywhere in this script.
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    # Used as the DataLoader `num_workers`; None means loading happens in the main process.
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )


class TrainState(train_state.TrainState):
    """Flax TrainState that additionally carries the dropout PRNG key."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """Replicate the state onto every local device, giving each device its own dropout key."""
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))


def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    """Write the cumulative train time, per-step train metrics and eval metrics to TensorBoard.

    Args:
        summary_writer: a flax ``SummaryWriter``.
        train_metrics: list of replicated per-step metric dicts collected during one epoch.
        eval_metrics: dict of scalar eval metrics for the same epoch.
        train_time: cumulative training wall-clock time in seconds.
        step: global step reached at the end of the epoch; the per-step train values are
            logged at ``step - len(vals) + 1 .. step``.
    """
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Returns a linear warmup, linear_decay learning rate function.

    The rate rises linearly from 0 to ``learning_rate`` over ``num_warmup_steps`` steps, then
    decays linearly to 0 over the remaining steps of ``num_train_epochs`` full epochs.
    """
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn


def collate_fn(examples):
    """Stack ``(image_tensor, label)`` pairs into a dict of NumPy arrays.

    Defined at module level (not inside ``main``) so DataLoader worker processes can pickle it
    when the ``spawn`` start method is used.

    Returns:
        ``{"pixel_values": ndarray, "labels": ndarray}`` with the batch on the leading axis.
    """
    pixel_values = torch.stack([example[0] for example in examples])
    labels = torch.tensor([example[1] for example in examples])
    batch = {"pixel_values": pixel_values, "labels": labels}
    return {k: v.numpy() for k, v in batch.items()}


def _truncate_dataset(dataset, max_samples, seed):
    """Return a seeded random subset of at most ``max_samples`` examples (or ``dataset`` unchanged).

    A random subset is used instead of a prefix because ImageFolder lists its samples class by
    class, so a prefix would only contain the first few classes.

    Raises:
        ValueError: if ``max_samples`` is negative.
    """
    if max_samples is None or max_samples >= len(dataset):
        return dataset
    if max_samples < 0:
        raise ValueError(f"The maximum number of samples must be non-negative, got {max_samples}.")
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=generator)[:max_samples].tolist()
    return torch.utils.data.Subset(dataset, indices)


def _create_data_loader(dataset, batch_size, shuffle, num_workers):
    """Build a DataLoader that yields NumPy batches of exactly ``batch_size`` examples.

    ``drop_last=True`` keeps every batch the same size so it can be split evenly by ``shard``
    for ``pmap``; a trailing partial batch is therefore skipped.
    """
    num_workers = num_workers or 0  # None (the CLI default) means "load in the main process"
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when there are no worker processes.
        persistent_workers=num_workers > 0,
        drop_last=True,
        collate_fn=collate_fn,
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    # Set the verbosity to info of the Transformers logger (on main process only):
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    logger.info(f"Training/evaluation parameters {training_args}")

    # set seed for random transforms and torch dataloaders
    set_seed(training_args.seed)

    # Initialize datasets and pre-processing transforms
    # We use torchvision here for faster pre-processing
    # Note that here we are using some default pre-processing, for maximum accuracy
    # one should tune this part and carefully select what transformations to use.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_dataset = torchvision.datasets.ImageFolder(
        data_args.train_dir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(data_args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    eval_dataset = torchvision.datasets.ImageFolder(
        data_args.validation_dir,
        transforms.Compose(
            [
                transforms.Resize(data_args.image_size),
                transforms.CenterCrop(data_args.image_size),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    # The label count comes from the full training directory, before any truncation below.
    num_labels = len(train_dataset.classes)

    # Load pretrained model and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            num_labels=num_labels,
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            num_labels=num_labels,
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
        )
    else:
        if model_args.model_type not in CONFIG_MAPPING:
            raise ValueError(
                "When training from scratch, --model_type must be one of: "
                + ", ".join(MODEL_TYPES)
                + f" (got {model_args.model_type!r})."
            )
        # Without num_labels the classification head would not match the number of classes.
        config = CONFIG_MAPPING[model_args.model_type](num_labels=num_labels, image_size=data_args.image_size)
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.model_name_or_path:
        model = FlaxAutoModelForImageClassification.from_pretrained(
            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )
    else:
        model = FlaxAutoModelForImageClassification.from_config(
            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )

    # Honour --max_train_samples / --max_eval_samples.
    train_dataset = _truncate_dataset(train_dataset, data_args.max_train_samples, training_args.seed)
    eval_dataset = _truncate_dataset(eval_dataset, data_args.max_eval_samples, training_args.seed)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    eval_steps = len(eval_dataset) // eval_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # With drop_last=True a dataset smaller than one global batch yields no batches at all,
    # which would otherwise crash later with an obscure error.
    if steps_per_epoch == 0:
        raise ValueError(
            f"The training set ({len(train_dataset)} examples) is smaller than the total train batch size "
            f"({train_batch_size}); reduce --per_device_train_batch_size."
        )
    if eval_steps == 0:
        raise ValueError(
            f"The evaluation set ({len(eval_dataset)} examples) is smaller than the total eval batch size "
            f"({eval_batch_size}); reduce --per_device_eval_batch_size."
        )

    # Create data loaders
    train_loader = _create_data_loader(
        train_dataset, train_batch_size, shuffle=True, num_workers=data_args.preprocessing_num_workers
    )
    eval_loader = _create_data_loader(
        eval_dataset, eval_batch_size, shuffle=False, num_workers=data_args.preprocessing_num_workers
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule; use the same integer epoch count the training loop runs.
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        num_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients across devices so every replica applies the same update.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
        metrics = {"loss": loss, "accuracy": accuracy}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        train_metrics = []
        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        # Metrics of the last step of the epoch (guaranteed to exist: steps_per_epoch > 0).
        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = (
            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | "
            f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})"
        )
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            # Global step reached at the END of this epoch, so this epoch's per-step train
            # metrics land on steps epoch*steps_per_epoch+1 .. (epoch+1)*steps_per_epoch.
            cur_step = (epoch + 1) * steps_per_epoch
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )


if __name__ == "__main__":
    main()