mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
526 lines
21 KiB
Python
Executable File
526 lines
21 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
|
|
on a text file or a dataset without using HuggingFace Trainer.
|
|
|
|
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
|
https://huggingface.co/models?filter=causal-lm
|
|
"""
|
|
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
|
|
|
import argparse
|
|
import logging
|
|
import math
|
|
import os
|
|
import random
|
|
from pathlib import Path
|
|
|
|
import datasets
|
|
import torch
|
|
from datasets import load_dataset
|
|
from torch.utils.data import DataLoader
|
|
from tqdm.auto import tqdm
|
|
|
|
import transformers
|
|
from accelerate import Accelerator, DistributedType
|
|
from huggingface_hub import Repository
|
|
from transformers import (
|
|
CONFIG_MAPPING,
|
|
MODEL_MAPPING,
|
|
AdamW,
|
|
AutoConfig,
|
|
AutoModelForCausalLM,
|
|
AutoTokenizer,
|
|
SchedulerType,
|
|
default_data_collator,
|
|
get_scheduler,
|
|
set_seed,
|
|
)
|
|
from transformers.file_utils import get_full_repo_name
|
|
from transformers.utils.versions import require_version
|
|
|
|
|
|
# Module-level logger; its level is adjusted per process inside `main`.
logger = logging.getLogger(__name__)

# Fail early with an actionable message if the installed `datasets` is too old for this script.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)

# Config classes known to the model mapping, and the `model_type` string of each one.
# `MODEL_TYPES` is the set of values accepted by the `--model_type` command-line option.
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
|
|
|
|
|
|
def parse_args():
    """Parse and validate the command-line arguments of the script.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ValueError: if neither a dataset name nor a training/validation file is given, if a data file
            does not have a csv/json/txt extension, if `--push_to_hub` is passed without `--output_dir`,
            or if nothing is provided from which a model configuration can be built.
    """

    def _str_to_bool(value):
        """Convert a command-line string such as "true" or "0" to a bool.

        `type=bool` must not be used with argparse: `bool("False")` is `True`, so any non-empty value
        would enable the option.
        """
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y"):
            return True
        if normalized in ("", "0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task")

    # Data sources.
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
    )
    parser.add_argument(
        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
    )
    parser.add_argument(
        "--validation_split_percentage",
        # Explicit type so the value is an int whether it comes from the default or the command line.
        type=int,
        default=5,
        help="The percentage of the train set used as validation set in case there's no validation split",
    )

    # Model, config and tokenizer.
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        # Not required: a model can be trained from scratch with `--model_type` (or `--config_name`)
        # together with `--tokenizer_name`.
        required=False,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )

    # Optimization hyper-parameters.
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )

    # Output, reproducibility and preprocessing.
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--block_size",
        type=int,
        default=None,
        help="Optional input sequence length after tokenization. The training dataset will be truncated in block of this size for training. Default to the model max input length for single sentence inputs (take into account special tokens).",
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument(
        "--overwrite_cache",
        # Accepts a bare `--overwrite_cache` as well as an explicit `--overwrite_cache true/false`.
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Overwrite the cached training and evaluation sets",
    )
    parser.add_argument(
        "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files."
    )

    # Hub integration.
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    args = parser.parse_args()

    # Sanity checks. Explicit exceptions are raised instead of `assert`, which is stripped under `python -O`.
    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
        raise ValueError("Need either a dataset name or a training/validation file.")

    for option in ("train_file", "validation_file"):
        data_file = getattr(args, option)
        if data_file is not None and data_file.split(".")[-1] not in ("csv", "json", "txt"):
            raise ValueError(f"`{option}` should be a csv, json or txt file.")

    if args.push_to_hub and args.output_dir is None:
        raise ValueError("Need an `output_dir` to create a repo when `--push_to_hub` is passed.")

    # Without a checkpoint, the configuration must come from `--config_name` or `--model_type`.
    if args.model_name_or_path is None and args.config_name is None and args.model_type is None:
        raise ValueError(
            "Need a `--model_name_or_path`, or a `--config_name`/`--model_type` to train a model from scratch."
        )

    return args
|
|
|
|
|
|
def main():
    """Fine-tune (or train from scratch) a causal language model with a hand-written training loop.

    The function parses the command-line arguments, loads and tokenizes the dataset, groups the tokens
    in blocks of `block_size`, then trains the model with `accelerate`, evaluating perplexity on the
    validation split after every epoch. When `--output_dir` is given, the model and tokenizer are saved
    at the end of training (and pushed to the Hub when `--push_to_hub` is passed).
    """
    # Local import: only needed by `group_texts` below.
    from itertools import chain

    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator()
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation. `repo` only exists on the main process, which is also the only
    # process that uses it further down.
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
        if "validation" not in raw_datasets.keys():
            # No validation split: carve the first N% of the train split out as validation data.
            raw_datasets["validation"] = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                split=f"train[:{args.validation_split_percentage}%]",
            )
            raw_datasets["train"] = load_dataset(
                args.dataset_name,
                args.dataset_config_name,
                split=f"train[{args.validation_split_percentage}%:]",
            )
    else:
        data_files = {}
        dataset_args = {}
        if args.train_file is not None:
            data_files["train"] = args.train_file
        if args.validation_file is not None:
            data_files["validation"] = args.validation_file
        # The loader is chosen from the extension of the training file.
        extension = args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks
        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{args.validation_split_percentage}%]",
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{args.validation_split_percentage}%:]",
                **dataset_args,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if args.config_name:
        config = AutoConfig.from_pretrained(args.config_name)
    elif args.model_name_or_path:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        config = CONFIG_MAPPING[args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
    elif args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if args.model_name_or_path:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(config)

    # Make the embedding matrix match the tokenizer vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    column_names = raw_datasets["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        """Tokenize the text column of a batch of examples."""
        return tokenizer(examples[text_column_name])

    with accelerator.main_process_first():
        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )

    if args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        """Concatenate a batch of tokenized examples and split the result in chunks of `block_size`."""
        # Concatenate all texts. `chain.from_iterable` is linear, whereas `sum(lists, [])` re-copies the
        # accumulated list for every example.
        concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        # The labels are a copy of the inputs.
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    with accelerator.main_process_first():
        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )

    train_dataset = lm_datasets["train"]
    eval_dataset = lm_datasets["validation"]

    # Log a few random samples from the training set (fewer than 3 if the set is that small, since
    # `random.sample` raises when asked for more items than the population holds).
    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # DataLoaders creation:
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
    )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
    )

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
    if accelerator.distributed_type == DistributedType.TPU:
        model.tie_weights()

    # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
    # shorter in multiprocess)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0

    for epoch in range(args.num_train_epochs):
        model.train()
        last_step = len(train_dataloader) - 1
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            # Scale the loss so that the accumulated gradient is an average over the micro-batches.
            loss = outputs.loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            # Update once `gradient_accumulation_steps` micro-batches have been accumulated, or at the end
            # of the epoch. Testing `step + 1` (not `step`) avoids an update after the very first
            # micro-batch and keeps the number of updates equal to `num_update_steps_per_epoch`.
            if (step + 1) % args.gradient_accumulation_steps == 0 or step == last_step:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps >= args.max_train_steps:
                break

        # Evaluation: gather the per-batch loss from every process and derive the perplexity.
        model.eval()
        losses = []
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            # The batch loss is repeated once per example slot so that it can be gathered and truncated below.
            losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        losses = losses[: len(eval_dataset)]
        try:
            perplexity = math.exp(torch.mean(losses))
        except OverflowError:
            perplexity = float("inf")

        logger.info(f"epoch {epoch}: perplexity: {perplexity}")

        # Intermediate checkpoint pushed after every epoch but the last (the final push happens below).
        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                )

    # Final save (and push) of the trained model and tokenizer.
    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
        if accelerator.is_main_process:
            tokenizer.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
|
|
|
|
# Run the training only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|