import logging
from argparse import Namespace
from pathlib import Path

import datasets
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter

import transformers
import wandb
from accelerate import Accelerator
from arguments import TrainingArguments
from huggingface_hub import Repository
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed


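# ConstantLengthDataset packs a stream of raw text examples into fixed-length token
# sequences so that every training batch is fully filled. With the defaults used below
# (seq_length=1024, chars_per_token=3.6, num_of_sequences=1024), it buffers roughly
# 1024 * 3.6 * 1024 ≈ 3.8M characters before tokenizing, an estimate of the text needed
# to fill 1024 sequences of 1024 tokens.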
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant-length chunks of tokens from a stream of text files.

    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True, the iterator is restarted when the dataset is exhausted; otherwise iteration stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in the buffer.
        chars_per_token (float): Number of characters per token used to estimate the number of tokens in the text buffer.
    """

    def __init__(
        self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
    ):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.input_characters = seq_length * chars_per_token * num_of_sequences
        self.epoch = 0
        self.infinite = infinite

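    # __iter__ fills a character buffer from the text stream, tokenizes the whole buffer in
    # one call, joins the tokenized examples with the BOS token as separator, and yields
    # only full chunks of seq_length tokens (a trailing partial chunk is dropped).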
    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.input_characters:
                    break
                try:
                    buffer.append(next(iterator)["content"])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    if self.infinite:
                        iterator = iter(self.dataset)
                        self.epoch += 1
                        logger.info(f"Dataset epoch: {self.epoch}")
                    else:
                        more_examples = False
                        break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)


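# setup_logging writes a per-process debug log file plus console output; experiment
# tracking (wandb and TensorBoard) is initialized on the main process only, so multi-GPU
# runs do not create duplicate experiment runs. It reads the module-level `accelerator`
# defined further below, so it must be called after that object exists.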
def setup_logging(args):
    project_name = args.model_ckpt.split("/")[-1]
    logger = logging.getLogger(__name__)
    log_dir = Path(args.save_dir) / "log/"
    log_dir.mkdir(exist_ok=True)
    filename = f"debug_{accelerator.process_index}.log"
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
    )
    if accelerator.is_main_process:  # we only want to set up logging once
        wandb.init(project=project_name, config=args)
        run_name = wandb.run.name
        tb_writer = SummaryWriter()
        tb_writer.add_hparams(vars(args), {"0": 0})
        logger.setLevel(logging.INFO)
        datasets.utils.logging.set_verbosity_info()
        transformers.utils.logging.set_verbosity_info()
    else:
        tb_writer = None
        run_name = ""
        logger.setLevel(logging.ERROR)
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    return logger, tb_writer, run_name


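# create_dataloaders streams both splits directly from the Hub (no full local download)
# and wraps them in ConstantLengthDataset. A minimal sketch of the same pattern, assuming
# a streaming dataset with a "content" text column; the dataset name is a placeholder,
# the real ones come from args.dataset_name_train / args.dataset_name_valid:
#
#   data = load_dataset("<org>/<code-dataset>", split="train", streaming=True)
#   packed = ConstantLengthDataset(tokenizer, data, infinite=True, seq_length=1024)
#   loader = DataLoader(packed, batch_size=8)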
def create_dataloaders(args):
    ds_kwargs = {"streaming": True}
    train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
    train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
    train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
    valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
    return train_dataloader, eval_dataloader


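# Weight decay is applied only to the weight matrices; biases and LayerNorm weights are
# excluded, the usual convention when training transformer language models.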
def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]


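# Metrics always go to the process log; wandb and TensorBoard are only updated on the main process.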
def log_metrics(step, metrics):
    logger.info(f"Step {step}: {metrics}")
    if accelerator.is_main_process:
        wandb.log(metrics)
        [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]


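# evaluate averages the cross-entropy loss over validation batches gathered from all
# processes and reports perplexity = exp(loss). The scalar batch loss is repeated
# valid_batch_size times before gathering so that every example carries the same weight
# in the cross-process mean.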
def evaluate(args):
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        loss = outputs.loss.repeat(args.valid_batch_size)
        losses.append(accelerator.gather(loss))
        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    loss = torch.mean(torch.cat(losses))
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()


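# Everything below is the script body and runs on every process started by the launcher.
# A typical invocation looks roughly like the sketch below; the exact flags are defined
# by TrainingArguments in arguments.py, and the script name and paths are placeholders:
#
#   accelerate launch <this_script>.py --model_ckpt <org>/<model-repo> --save_dir ./model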
# Accelerator
accelerator = Accelerator()
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

# Settings
parser = HfArgumentParser(TrainingArguments)
args = parser.parse_args()

args = Namespace(**vars(args), **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)

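# The model repository is cloned on the main process only. Cloning a repo whose weights
# are stored with Git LFS requires git-lfs to be installed (see the installation section
# of the project README).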
# Clone model repository
if accelerator.is_main_process:
    hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)

# Logging
logger, tb_writer, run_name = setup_logging(args)
logger.info(accelerator.state)

# Checkout new branch on repo
if accelerator.is_main_process:
    hf_repo.git_checkout(run_name, create_branch_ok=True)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.save_dir)
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(args.save_dir)

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(args)

# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps,
    num_training_steps=args.max_train_steps,
)


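# Small helper so the current learning rate can be included in the logged metrics.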
def get_lr():
    return optimizer.param_groups[0]["lr"]


# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

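# Training loop: the loss is scaled by gradient_accumulation_steps and gradients are
# accumulated over that many micro-batches before a clipped (max norm 1.0) optimizer step.
# Every save_checkpoint_steps steps the model is evaluated, saved, and pushed to the Hub;
# training stops after max_train_steps optimizer steps.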
# Train model
model.train()
completed_steps = 0
for step, batch in enumerate(train_dataloader, start=1):
    loss = model(batch, labels=batch, use_cache=False).loss
    log_metrics(
        step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
    )
    loss = loss / args.gradient_accumulation_steps
    accelerator.backward(loss)
    if step % args.gradient_accumulation_steps == 0:
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
    if step % args.save_checkpoint_steps == 0:
        logger.info("Evaluating and saving model checkpoint")
        eval_loss, perplexity = evaluate(args)
        log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
        if accelerator.is_main_process:
            hf_repo.push_to_hub(commit_message=f"step {step}")
        model.train()
    if completed_steps >= args.max_train_steps:
        break

# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
if accelerator.is_main_process:
    hf_repo.push_to_hub(commit_message="final model")