mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix CodeParrot training script (#17291)
* average loss over batches and accumulated steps for tracking * fix layernorm weight decay * use AdamW from Pytorch instead of Transformers * add shuffling of sequences inside the batches * add shuffling of sequences inside the batches * add logging dir and reformat code * fix lr tracking * remove Mistral scaling * keep Mistral scaling * reformat code * fix error * fix error * use shuffling function from Pytorch * remove argument for shuffling batch sequences as it isn't optional * update package versions and install accelerate from source * remove unused package * Update loss average over accumulated steps Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> * Update loss average over accumulated steps Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> * use one shuffle buffer argument * compute avg_loss in one line Co-authored-by: Loubna ben allal <loubnabenallal@gmail.com> Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
This commit is contained in:
parent
b9bb417324
commit
b48ac1a094
@ -1,7 +1,7 @@
|
|||||||
transformers==4.15.0
|
transformers==4.19.0
|
||||||
datasets==1.16.0
|
datasets==1.16.0
|
||||||
accelerate==0.6.2
|
|
||||||
wandb==0.12.0
|
wandb==0.12.0
|
||||||
tensorboard==2.6.0
|
tensorboard==2.6.0
|
||||||
torch==1.9.0
|
torch==1.11.0
|
||||||
huggingface-hub==0.1.0
|
huggingface-hub==0.1.0
|
||||||
|
git+https://github.com/huggingface/accelerate.git@3c45b6f760ad8745be9ebc9bbb26f5b04dea4abe
|
@ -24,7 +24,7 @@ class TrainingArguments:
|
|||||||
valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
|
valid_batch_size: Optional[int] = field(default=2, metadata={"help": "Batch size for evaluation."})
|
||||||
weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
|
weight_decay: Optional[float] = field(default=0.1, metadata={"help": "Value of weight decay."})
|
||||||
shuffle_buffer: Optional[int] = field(
|
shuffle_buffer: Optional[int] = field(
|
||||||
default=1000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
|
default=10000, metadata={"help": "Size of buffer used to shuffle streaming dataset."}
|
||||||
)
|
)
|
||||||
learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate fo training."})
|
learning_rate: Optional[float] = field(default=2e-4, metadata={"help": "Learning rate fo training."})
|
||||||
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate."})
|
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "Learning rate."})
|
||||||
|
@ -7,14 +7,16 @@ from pathlib import Path
|
|||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from torch.optim import AdamW
|
||||||
from torch.utils.data import IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data.dataloader import DataLoader
|
||||||
|
from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import Accelerator, DistributedType
|
from accelerate import Accelerator, DistributedType
|
||||||
from arguments import TrainingArguments
|
from arguments import TrainingArguments
|
||||||
from huggingface_hub import Repository
|
from huggingface_hub import Repository
|
||||||
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
||||||
|
|
||||||
|
|
||||||
class ConstantLengthDataset(IterableDataset):
|
class ConstantLengthDataset(IterableDataset):
|
||||||
@ -25,9 +27,9 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
dataset (dataset.Dataset): Dataset with text files.
|
dataset (dataset.Dataset): Dataset with text files.
|
||||||
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
|
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
|
||||||
seq_length (int): Length of token sequences to return.
|
seq_length (int): Length of token sequences to return.
|
||||||
num_of_sequences: Number of token sequences to keep in buffer.
|
num_of_sequences (int): Number of token sequences to keep in buffer.
|
||||||
chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
|
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
|
||||||
tokenized: If true we use a pretokenized dataset.
|
tokenized (bool): If true we use a pretokenized dataset.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -88,6 +90,9 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
self.current_size += 1
|
self.current_size += 1
|
||||||
yield torch.tensor(input_ids)
|
yield torch.tensor(input_ids)
|
||||||
|
|
||||||
|
def shuffle(self, buffer_size=1000):
|
||||||
|
return ShufflerIterDataPipe(self, buffer_size=buffer_size)
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(args):
|
def setup_logging(args):
|
||||||
project_name = args.model_ckpt.split("/")[-1]
|
project_name = args.model_ckpt.split("/")[-1]
|
||||||
@ -126,12 +131,13 @@ def create_dataloaders(args):
|
|||||||
valid_dataset = ConstantLengthDataset(
|
valid_dataset = ConstantLengthDataset(
|
||||||
tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
|
tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
|
||||||
)
|
)
|
||||||
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
|
train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer)
|
||||||
|
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
|
||||||
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
|
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
|
||||||
return train_dataloader, eval_dataloader
|
return train_dataloader, eval_dataloader
|
||||||
|
|
||||||
|
|
||||||
def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
|
def get_grouped_params(model, args, no_decay=["bias", "ln_1.weight", "ln_2.weight", "ln_f.weight"]):
|
||||||
params_with_wd, params_without_wd = [], []
|
params_with_wd, params_without_wd = [], []
|
||||||
for n, p in model.named_parameters():
|
for n, p in model.named_parameters():
|
||||||
if any(nd in n for nd in no_decay):
|
if any(nd in n for nd in no_decay):
|
||||||
@ -184,14 +190,14 @@ def evaluate(args):
|
|||||||
return loss.item(), perplexity.item()
|
return loss.item(), perplexity.item()
|
||||||
|
|
||||||
|
|
||||||
# Accelerator
|
|
||||||
accelerator = Accelerator(log_with=["wandb", "tensorboard"])
|
|
||||||
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
|
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
parser = HfArgumentParser(TrainingArguments)
|
parser = HfArgumentParser(TrainingArguments)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Accelerator
|
||||||
|
accelerator = Accelerator(log_with=["wandb", "tensorboard"], logging_dir=f"{args.save_dir}/log")
|
||||||
|
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
|
||||||
|
|
||||||
args = Namespace(**vars(args), **acc_state)
|
args = Namespace(**vars(args), **acc_state)
|
||||||
samples_per_step = accelerator.state.num_processes * args.train_batch_size
|
samples_per_step = accelerator.state.num_processes * args.train_batch_size
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
@ -256,13 +262,14 @@ if args.resume_from_checkpoint:
|
|||||||
model.train()
|
model.train()
|
||||||
completed_steps = 0
|
completed_steps = 0
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
|
loss_tracking = 0
|
||||||
for step, batch in enumerate(train_dataloader, start=1):
|
for step, batch in enumerate(train_dataloader, start=1):
|
||||||
if args.resume_from_checkpoint and step < resume_step:
|
if args.resume_from_checkpoint and step < resume_step:
|
||||||
continue # we need to skip steps until we reach the resumed step
|
continue # we need to skip steps until we reach the resumed step
|
||||||
loss = model(batch, labels=batch, use_cache=False).loss
|
loss = model(batch, labels=batch, use_cache=False).loss
|
||||||
log_metrics(
|
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||||
step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
|
loss_tracking += avg_loss.item() / args.gradient_accumulation_steps
|
||||||
)
|
log_metrics(step, {"samples": step * samples_per_step, "loss_per_step/train": loss.item()})
|
||||||
loss = loss / args.gradient_accumulation_steps
|
loss = loss / args.gradient_accumulation_steps
|
||||||
if step % args.gradient_accumulation_steps != 0:
|
if step % args.gradient_accumulation_steps != 0:
|
||||||
# Prevent backward from doing gradient all_reduce in every step
|
# Prevent backward from doing gradient all_reduce in every step
|
||||||
@ -272,16 +279,27 @@ for step, batch in enumerate(train_dataloader, start=1):
|
|||||||
else:
|
else:
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
else:
|
else:
|
||||||
|
lr = get_lr()
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
completed_steps += 1
|
|
||||||
elapsed_time = time.time() - t_start
|
elapsed_time = time.time() - t_start
|
||||||
tflops = compute_tflops(elapsed_time, accelerator, args)
|
tflops = compute_tflops(elapsed_time, accelerator, args)
|
||||||
log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
|
log_metrics(
|
||||||
|
step,
|
||||||
|
{
|
||||||
|
"steps": completed_steps,
|
||||||
|
"loss/train": loss_tracking,
|
||||||
|
"lr": lr,
|
||||||
|
"tflops": tflops,
|
||||||
|
"time_per_iteration": elapsed_time,
|
||||||
|
},
|
||||||
|
)
|
||||||
t_start = time.time()
|
t_start = time.time()
|
||||||
|
loss_tracking = 0
|
||||||
|
completed_steps += 1
|
||||||
if step % args.save_checkpoint_steps == 0:
|
if step % args.save_checkpoint_steps == 0:
|
||||||
logger.info("Evaluating and saving model checkpoint")
|
logger.info("Evaluating and saving model checkpoint")
|
||||||
eval_loss, perplexity = evaluate(args)
|
eval_loss, perplexity = evaluate(args)
|
||||||
|
@ -10,7 +10,11 @@ args = parser.parse_args()
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
|
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
|
||||||
|
|
||||||
# Config: "scale_attn_by_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
|
# Config: "scale_attn_by_layer_idx" and "reorder_and_upcast_attn" are Mistral stability tweaks
|
||||||
config_kwargs = {"vocab_size": len(tokenizer), "scale_attn_by_layer_idx": True, "reorder_and_upcast_attn": True}
|
config_kwargs = {
|
||||||
|
"vocab_size": len(tokenizer),
|
||||||
|
"scale_attn_by_inverse_layer_idx": True,
|
||||||
|
"reorder_and_upcast_attn": True,
|
||||||
|
}
|
||||||
|
|
||||||
# Load model config (GPT-2 large in this case)
|
# Load model config (GPT-2 large in this case)
|
||||||
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
|
config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user