Mirror of https://github.com/huggingface/transformers.git
fix run_glue test

commit 92a782b108 (parent ccb6947dc1)
@@ -53,6 +53,15 @@ MODEL_CLASSES = {
     'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
 }
 
+
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+
 def train(args, train_dataset, model, tokenizer):
     """ Train the model """
     if args.local_rank in [-1, 0]:
@@ -97,6 +106,7 @@ def train(args, train_dataset, model, tokenizer):
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
@@ -371,12 +381,8 @@ def main():
     logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
 
-    # Setup seeds
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
+    # Set seed
+    set_seed(args)
 
     # Prepare GLUE task
     args.task_name = args.task_name.lower()
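The three hunks above move all RNG seeding in run_glue into the new set_seed helper and call it a second time right before the epoch loop. As a minimal usage sketch, assuming an argparse-style args object with seed and n_gpu attributes (the names used in the diff); this is an illustration of the pattern, not the script itself:

import argparse
import random

import numpy as np
import torch


def set_seed(args):
    # Seed every RNG the training loop touches so repeated runs match.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


# Hypothetical invocation mirroring the pattern in the diff:
args = argparse.Namespace(seed=42, n_gpu=torch.cuda.device_count())
set_seed(args)   # once during setup, as in main()
# ... build dataloader, model, optimizer ...
set_seed(args)   # again just before the epoch loop, as in train()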
@@ -167,14 +167,14 @@ class AdamW(Optimizer):
 
                 # Decay the first and second moment running average coefficient
                 # In-place operations to update the averages at the same time
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
                 denom = exp_avg_sq.sqrt().add_(group['eps'])
 
                 step_size = group['lr']
                 if group['correct_bias']: # No bias correction for Bert
-                    bias_correction1 = 1 - beta1 ** state['step']
-                    bias_correction2 = 1 - beta2 ** state['step']
+                    bias_correction1 = 1.0 - beta1 ** state['step']
+                    bias_correction2 = 1.0 - beta2 ** state['step']
                     step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
 
                 p.data.addcdiv_(-step_size, exp_avg, denom)
@@ -187,7 +187,7 @@ class AdamW(Optimizer):
                 # with the m/v parameters. This is equivalent to adding the square
                 # of the weights to the loss with plain (non-momentum) SGD.
                 # Add weight decay at the end (fixed version)
-                if group['weight_decay'] > 0:
+                if group['weight_decay'] > 0.0:
                     p.data.add_(-group['lr'] * group['weight_decay'], p.data)
 
         return loss
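The two optimizer hunks only change integer literals to explicit floats (1 to 1.0, 0 to 0.0) in AdamW's update; the arithmetic itself is the usual bias-corrected Adam step. A plain-float sketch of that step with made-up values, for orientation only (not the library code):

import math

beta1, beta2, lr, eps = 0.9, 0.999, 1e-3, 1e-6
step = 1                        # first optimizer step
m, v, grad, param = 0.0, 0.0, 0.5, 1.0

# Exponential moving averages, matching exp_avg / exp_avg_sq in the diff.
m = beta1 * m + (1.0 - beta1) * grad
v = beta2 * v + (1.0 - beta2) * grad * grad

# Bias correction, applied only when correct_bias is True.
bias_correction1 = 1.0 - beta1 ** step
bias_correction2 = 1.0 - beta2 ** step
step_size = lr * math.sqrt(bias_correction2) / bias_correction1

param -= step_size * m / (math.sqrt(v) + eps)
print(param)  # ~0.999, i.e. roughly param - lr on the first step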