mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
update optimizer run_squad
This commit is contained in:
parent
f514cbbf30
commit
25f73add07
@ -800,13 +800,17 @@ def main():
|
||||
if n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
|
||||
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
|
||||
],
|
||||
lr=args.learning_rate, schedule='warmup_linear',
|
||||
no_decay = ['bias', 'gamma', 'beta']
|
||||
optimizer_parameters = [
|
||||
{'params': [p for n, p in model.named_parameters() if n not in no_decay], 'weight_decay_rate': 0.01},
|
||||
{'params': [p for n, p in model.named_parameters() if n in no_decay], 'weight_decay_rate': 0.0}
|
||||
]
|
||||
|
||||
optimizer = BERTAdam(optimizer_parameters,
|
||||
lr=args.learning_rate,
|
||||
warmup=args.warmup_proportion,
|
||||
t_total=num_train_steps)
|
||||
|
||||
|
||||
global_step = 0
|
||||
if args.do_train:
|
||||
train_features = convert_examples_to_features(
|
||||
|
Loading…
Reference in New Issue
Block a user