Mirror of https://github.com/huggingface/transformers.git

Wrap code into main function for accelerate launcher to find
Commit ff11df1c81 (parent b82db50adf)
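
For readers unfamiliar with the pattern the commit title refers to, the following is a minimal, hypothetical sketch (names and messages are illustrative, not taken from this diff): all script-level training code moves into a single main() function that is invoked behind an if __name__ == "__main__" guard, which is the shape this commit gives the training script so that the accelerate launcher has an entry point to find. Such a script is then typically started with accelerate launch rather than plain python.

# Minimal sketch of the wrapping pattern; not the repository's code.
from accelerate import Accelerator


def main():
    accelerator = Accelerator()  # one Accelerator per launched process
    accelerator.print("running inside main()")  # prints on the main process only


if __name__ == "__main__":
    main()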
@@ -59,7 +59,7 @@ class ConstantLengthDataset(IterableDataset):
                     else:
                         more_examples = False
                         break
-            tokenized_inputs = tokenizer(buffer, truncation=False)["input_ids"]
+            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
             all_token_ids = []
             for tokenized_input in tokenized_inputs:
                 all_token_ids.extend(tokenized_input + [self.concat_token_id])
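
The one-line change above stops __iter__ from relying on a module-level tokenizer name and uses the tokenizer the dataset already stores on itself. This matters because of the second hunk below: once the script body is wrapped into main(), tokenizer becomes a local variable of main(), so the bare name inside the class would raise a NameError at iteration time. The constructor itself is not part of this diff; the sketch below is a hedged reconstruction, assuming the usual signature in which the tokenizer and a separator token id are stored as attributes.

# Hypothetical sketch of the parts of __init__ that the change above relies on;
# the real signature and defaults live outside the lines shown in this diff.
from torch.utils.data import IterableDataset


class ConstantLengthDataset(IterableDataset):
    def __init__(self, tokenizer, dataset, infinite=False, seq_length=1024):
        self.tokenizer = tokenizer                     # later used as self.tokenizer(buffer, ...)
        self.concat_token_id = tokenizer.bos_token_id  # assumed separator between concatenated samples
        self.dataset = dataset
        self.infinite = infinite
        self.seq_length = seq_length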

@@ -68,173 +68,176 @@ class ConstantLengthDataset(IterableDataset):
                 if len(input_ids) == self.seq_length:
                     yield torch.tensor(input_ids)

The remainder of this hunk wraps everything that follows the ConstantLengthDataset class into a new main() function. The bodies of the helper functions are unchanged; they are simply nested inside main() together with the script-level training code (one extra level of indentation), and an if __name__ == "__main__" entry point is added at the end of the file. After the change, the rest of the script reads:
def main():
    def setup_logging(args):
        project_name = args.model_ckpt.split("/")[-1]
        logger = logging.getLogger(__name__)
        log_dir = Path(args.save_dir) / "log/"
        log_dir.mkdir(exist_ok=True)
        filename = f"debug_{accelerator.process_index}.log"
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO,
            handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
        )
        if accelerator.is_main_process: # we only want to setup logging once
            wandb.init(project=project_name, config=args)
            run_name = wandb.run.name
            tb_writer = SummaryWriter()
            tb_writer.add_hparams(vars(args), {"0": 0})
            logger.setLevel(logging.INFO)
            datasets.utils.logging.set_verbosity_info()
            transformers.utils.logging.set_verbosity_info()
        else:
            tb_writer = None
            run_name = ""
            logger.setLevel(logging.ERROR)
            datasets.utils.logging.set_verbosity_error()
            transformers.utils.logging.set_verbosity_error()
        return logger, tb_writer, run_name
    def create_dataloaders(args):
        ds_kwargs = {"streaming": True}
        train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
        valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
        train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
        valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
        train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
        eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
        return train_dataloader, eval_dataloader
    def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
        params_with_wd, params_without_wd = [], []
        for n, p in model.named_parameters():
            if any(nd in n for nd in no_decay):
                params_without_wd.append(p)
            else:
                params_with_wd.append(p)
        return [
            {"params": params_with_wd, "weight_decay": args.weight_decay},
            {"params": params_without_wd, "weight_decay": 0.0},
        ]
    def log_metrics(step, metrics):
        logger.info(f"Step {step}: {metrics}")
        if accelerator.is_main_process:
            wandb.log(metrics)
            [tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]

    def evaluate(args):
        model.eval()
        losses = []
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                outputs = model(batch, labels=batch)
            loss = outputs.loss.repeat(args.valid_batch_size)
            losses.append(accelerator.gather(loss))
            if args.max_eval_steps > 0 and step >= args.max_eval_steps:
                break
        loss = torch.mean(torch.cat(losses))
        try:
            perplexity = torch.exp(loss)
        except OverflowError:
            perplexity = float("inf")
        return loss.item(), perplexity.item()
    # Accelerator
    accelerator = Accelerator()
    acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

    # Settings
    parser = HfArgumentParser(TrainingArguments)
    args = parser.parse_args()

    args = Namespace(**vars(args), **acc_state)
    samples_per_step = accelerator.state.num_processes * args.train_batch_size
    set_seed(args.seed)

    # Clone model repository
    if accelerator.is_main_process:
        hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)

    # Logging
    logger, tb_writer, run_name = setup_logging(args)
    logger.info(accelerator.state)

    # Checkout new branch on repo
    if accelerator.is_main_process:
        hf_repo.git_checkout(run_name, create_branch_ok=True)

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(args.save_dir)
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    tokenizer = AutoTokenizer.from_pretrained(args.save_dir)

    # Load dataset and dataloader
    train_dataloader, eval_dataloader = create_dataloaders(args)

    # Prepare the optimizer and learning rate scheduler
    optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    def get_lr():
        return optimizer.param_groups[0]["lr"]

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # Train model
    model.train()
    completed_steps = 0
    for step, batch in enumerate(train_dataloader, start=1):
        loss = model(batch, labels=batch, use_cache=False).loss
        log_metrics(
            step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
        )
        loss = loss / args.gradient_accumulation_steps
        accelerator.backward(loss)
        if step % args.gradient_accumulation_steps == 0:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            completed_steps += 1
        if step % args.save_checkpoint_steps == 0:
            logger.info("Evaluating and saving model checkpoint")
            eval_loss, perplexity = evaluate(args)
            log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                hf_repo.push_to_hub(commit_message=f"step {step}")
            model.train()
        if completed_steps >= args.max_train_steps:
            break

    # Evaluate and save the last checkpoint
    logger.info("Evaluating and saving model after training")
    eval_loss, perplexity = evaluate(args)
    log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        hf_repo.push_to_hub(commit_message="final model")

if __name__ == "__main__":
    main()
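
One detail worth noting about the wrapped version: the nested helpers keep working even though they reference names such as logger, accelerator, tokenizer, or eval_dataloader that are only assigned later inside main(). Python resolves those free variables in main()'s scope at call time, not at definition time, as the toy example below illustrates (illustrative code only, not from the repository):

def main():
    def report():
        print(counter)  # 'counter' is looked up in main()'s scope when report() runs

    counter = 3         # assigned after report() is defined, but before it is called
    report()            # prints 3


if __name__ == "__main__":
    main()

This is the same mechanism that lets create_dataloaders use the tokenizer created further down in main() and lets log_metrics use the logger and tb_writer returned by setup_logging.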