mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
update run_glue
This commit is contained in:
parent
3b7cb7bf44
commit
4ce237c880
@ -69,7 +69,11 @@ def train(args, train_dataset, model):
|
||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
if args.max_steps > 0:
|
||||
num_train_optimization_steps = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||
else:
|
||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||
|
||||
# Prepare optimizer
|
||||
param_optimizer = list(model.named_parameters())
|
||||
@ -91,10 +95,8 @@ def train(args, train_dataset, model):
|
||||
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, t_total=num_train_optimization_steps)
|
||||
|
||||
else:
|
||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_optimization_steps)
|
||||
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion,
|
||||
t_total=num_train_optimization_steps)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
@ -113,7 +115,7 @@ def train(args, train_dataset, model):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
ouputs = model(**inputs)
|
||||
loss = ouputs[0]
|
||||
@ -140,14 +142,16 @@ def train(args, train_dataset, model):
|
||||
if not args.fp16:
|
||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
break
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
break
|
||||
|
||||
return global_step, tr_loss / global_step
|
||||
|
||||
|
||||
def evalutate(args, eval_task, eval_output_dir, dataset, model):
|
||||
""" Evaluate the model """
|
||||
if os.path.exists(eval_output_dir) and os.listdir(eval_output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(eval_output_dir))
|
||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(eval_output_dir)
|
||||
|
||||
@ -166,13 +170,13 @@ def evalutate(args, eval_task, eval_output_dir, dataset, model):
|
||||
out_label_ids = None
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids,
|
||||
token_type_ids=segment_ids,
|
||||
attention_mask=input_mask,
|
||||
labels=label_ids)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'attention_mask': batch[1],
|
||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'labels': batch[3]}
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
@ -276,6 +280,8 @@ def main():
|
||||
help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||
help="Total number of training epochs to perform.")
|
||||
parser.add_argument("--max_steps", default=-1, type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
|
||||
parser.add_argument("--no_cuda", action='store_true',
|
||||
@ -299,6 +305,9 @@ def main():
|
||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
@ -320,8 +329,8 @@ def main():
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
|
||||
device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
|
||||
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||
|
||||
# Setup seeds
|
||||
random.seed(args.seed)
|
||||
@ -375,8 +384,6 @@ def main():
|
||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
# Create output directory if needed
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user