#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for multiple choice.
"""
# You can also adapt this script on your own multiple choice task. Pointers for this are left as comments.

import logging
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import Optional

import datasets
import numpy as np
import tensorflow as tf
from datasets import load_dataset

import transformers
from transformers import (
    CONFIG_NAME,
    TF2_WEIGHTS_NAME,
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser,
    TFAutoModelForMultipleChoice,
    TFTrainingArguments,
    create_optimizer,
    set_seed,
)
from transformers.utils import check_min_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.13.0.dev0")

logger = logging.getLogger(__name__)


# region Helper classes and functions
class SavePretrainedCallback(tf.keras.callbacks.Callback):
    """Keras callback that calls `save_pretrained()` on the model after every epoch.

    Hugging Face models expose `save_pretrained()`, which writes both the weights and the
    metadata needed to reload them as a pretrained model later on.
    """

    def __init__(self, output_dir, **kwargs):
        """Remember the directory to save into; extra keyword arguments are accepted and ignored."""
        super().__init__()
        self.output_dir = output_dir

    def on_epoch_end(self, epoch, logs=None):
        """Write the current model to `self.output_dir`."""
        self.model.save_pretrained(self.output_dir)


def convert_dataset_for_tensorflow(
    dataset, non_label_column_names, batch_size, dataset_mode="variable_batch", shuffle=True, drop_remainder=True
):
    """Build a batched `tf.data.Dataset` out of a Hugging Face dataset.

    Every column that is neither listed in `non_label_column_names` nor called "label" is treated
    as a model input. `dataset_mode` selects the padding strategy:

    - "variable_batch": each batch is only padded to the longest sequence inside that batch.
    - "constant_batch": every batch is padded to the longest sequence of the whole dataset. This is
      most useful when training on TPU, as a new graph compilation is required for each sequence
      length.

    Yields `(features, label)` pairs when the dataset has a "label" column, and bare `features`
    dicts otherwise. Raises `ValueError` for any other `dataset_mode`.
    """
    feature_keys = list(set(dataset.features.keys()).difference(non_label_column_names + ["label"]))

    if dataset_mode not in ("variable_batch", "constant_batch"):
        raise ValueError("Unknown dataset mode!")

    # Sequences have different lengths, so the columns are held as ragged tensors until batching.
    ragged_columns = {key: tf.ragged.constant(dataset[key]) for key in feature_keys}
    if dataset_mode == "constant_batch":
        # Fixed target shape: the batch size followed by the dataset-wide maximum of every other dimension.
        target_shapes = {
            key: tf.concat(([batch_size], column.bounding_shape()[1:]), axis=0)
            for key, column in ragged_columns.items()
        }
    else:
        # A `None` shape lets `to_tensor` pad to the longest entry of the batch at hand.
        target_shapes = dict.fromkeys(feature_keys)

    def densify_ragged_batch(features, label=None):
        dense = {name: ragged.to_tensor(shape=target_shapes[name]) for name, ragged in features.items()}
        if label is not None:
            return dense, label
        return dense

    if "label" in dataset.features:
        labels = tf.convert_to_tensor(np.array(dataset["label"]))
        slices = (ragged_columns, labels)
    else:
        slices = ragged_columns
    tf_dataset = tf.data.Dataset.from_tensor_slices(slices)

    if shuffle:
        tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))

    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    tf_dataset = tf_dataset.with_options(options)
    tf_dataset = tf_dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
    return tf_dataset.map(densify_ragged_batch)


# endregion
# region Arguments
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization. If passed, sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to the maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )

    def __post_init__(self):
        """Check that any data file passed on the command line is a csv or json file.

        Raises:
            ValueError: if `train_file` or `validation_file` is set and has another extension.
        """
        # Raise explicitly rather than `assert`: assertions are stripped under `python -O`, which would
        # let an unsupported file type through to `load_dataset`.
        for arg_name in ("train_file", "validation_file"):
            file_path = getattr(self, arg_name)
            if file_path is not None and file_path.split(".")[-1] not in ("csv", "json"):
                raise ValueError(f"`{arg_name}` should be a csv or a json file.")


# endregion


def main():
    """Fine-tune and/or evaluate a multiple-choice model according to the command-line arguments.

    Steps, in order: parse arguments, configure logging, detect a checkpoint in the output directory,
    load the datasets (local csv/json files or SWAG from the hub), load config and tokenizer, tokenize,
    build and compile the model, then train (`--do_train`), evaluate (`--do_eval`) and optionally push
    the model to the hub (`--push_to_hub`).

    Raises:
        ValueError: if the output directory is non-empty, holds no checkpoint and
            `--overwrite_output_dir` was not passed, or if `--do_train` / `--do_eval` is requested
            without the matching dataset split.
    """
    # region Argument parsing
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    output_dir = Path(training_args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # endregion

    # region Logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    # endregion

    # region Checkpoints
    checkpoint = None
    if len(os.listdir(training_args.output_dir)) > 0 and not training_args.overwrite_output_dir:
        if (output_dir / CONFIG_NAME).is_file() and (output_dir / TF2_WEIGHTS_NAME).is_file():
            checkpoint = output_dir
            logger.info(
                f"Checkpoint detected, resuming training from checkpoint in {training_args.output_dir}. To avoid this"
                " behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
        else:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to continue regardless."
            )
    # endregion

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # region Load datasets
    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or let the script download the swag dataset from the hub at https://huggingface.co/datasets/.

    # For CSV/JSON files, the columns must match the names configured below (`sent1`, `sent2`,
    # `ending0` ... `ending3`) plus a `label` column. You can easily tweak this behavior (see below).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.train_file is not None or data_args.validation_file is not None:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # Pick the loader from whichever file was supplied: `train_file` is None when only a validation
        # file is passed (evaluation-only runs), so it cannot be relied upon here.
        reference_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = reference_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    else:
        # Downloading and loading the swag dataset from the hub.
        raw_datasets = load_dataset("swag", "regular", cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # When using your own dataset or a different dataset from swag, you will probably need to change this.
    ending_names = [f"ending{i}" for i in range(4)]
    context_name = "sent1"
    question_header_name = "sent2"
    # endregion

    # region Load model config and tokenizer
    if checkpoint is not None:
        config_path = training_args.output_dir
    elif model_args.config_name:
        config_path = model_args.config_name
    else:
        config_path = model_args.model_name_or_path

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        config_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    # endregion

    # region Dataset preprocessing
    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
        if max_seq_length > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
            )
            max_seq_length = 1024
    else:
        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Number of candidate answers per example; derived from `ending_names` so both stay in sync.
    num_choices = len(ending_names)

    def preprocess_function(examples):
        # Pair the context with every candidate ending: one (context, header + ending) pair per choice.
        first_sentences = [[context] * num_choices for context in examples[context_name]]
        question_headers = examples[question_header_name]
        second_sentences = [
            [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)
        ]

        # Flatten out
        first_sentences = list(chain(*first_sentences))
        second_sentences = list(chain(*second_sentences))

        # Tokenize
        tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True, max_length=max_seq_length)
        # Un-flatten
        data = {
            k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
            for k, v in tokenized_examples.items()
        }
        return data

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        # Columns present before tokenization (minus the label) are raw inputs, not model features.
        non_label_columns = [feature for feature in train_dataset.features if feature not in ("label", "labels")]
        if data_args.max_train_samples is not None:
            # Clamp so that a cap larger than the dataset keeps everything instead of failing in `select`.
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
            )

    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if not training_args.do_train:
            non_label_columns = [feature for feature in eval_dataset.features if feature not in ("label", "labels")]
        if data_args.max_eval_samples is not None:
            # Same clamping as for the training split.
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
            )
    # endregion

    with training_args.strategy.scope():
        # region Build model
        if checkpoint is None:
            model_path = model_args.model_name_or_path
        else:
            model_path = checkpoint
        model = TFAutoModelForMultipleChoice.from_pretrained(
            model_path,
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        # Batch sizes are given per device; the dataset is batched with the global size.
        num_replicas = training_args.strategy.num_replicas_in_sync
        total_train_batch_size = training_args.per_device_train_batch_size * num_replicas
        total_eval_batch_size = training_args.per_device_eval_batch_size * num_replicas
        if training_args.do_train:
            total_train_steps = (len(train_dataset) // total_train_batch_size) * int(training_args.num_train_epochs)
            # `create_optimizer` also returns the learning-rate schedule, which is not needed separately here.
            optimizer, _ = create_optimizer(
                init_lr=training_args.learning_rate, num_train_steps=int(total_train_steps), num_warmup_steps=0
            )
        else:
            optimizer = "adam"  # Just put anything in here, since we're not using it anyway
        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
        )
        # endregion

        # region Training
        if training_args.do_train:
            tf_train_dataset = convert_dataset_for_tensorflow(
                train_dataset, non_label_column_names=non_label_columns, batch_size=total_train_batch_size
            )
            if training_args.do_eval:
                validation_data = convert_dataset_for_tensorflow(
                    eval_dataset, non_label_column_names=non_label_columns, batch_size=total_eval_batch_size
                )
            else:
                validation_data = None
            model.fit(
                tf_train_dataset,
                validation_data=validation_data,
                epochs=int(training_args.num_train_epochs),
                callbacks=[SavePretrainedCallback(output_dir=training_args.output_dir)],
            )
        # endregion

        # region Evaluation
        if training_args.do_eval and not training_args.do_train:
            # Do a standalone evaluation pass
            tf_eval_dataset = convert_dataset_for_tensorflow(
                eval_dataset, non_label_column_names=non_label_columns, batch_size=total_eval_batch_size
            )
            model.evaluate(tf_eval_dataset)
        # endregion

        # region Push to hub
        if training_args.push_to_hub:
            model.push_to_hub(
                finetuned_from=model_args.model_name_or_path,
                tasks="multiple-choice",
                dataset_tags="swag",
                dataset_args="regular",
                dataset="SWAG",
                language="en",
            )
        # endregion


if __name__ == "__main__":
    main()