mirror of https://github.com/huggingface/transformers.git
Move torch.nn.utils.clip_grad_norm_
parent 164c794eb3
commit af1ee9e648
@@ -204,13 +204,16 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
             if args.fp16:
                 with amp.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
-                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
             else:
                 loss.backward()
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
 
             tr_loss += loss.item()
             if (step + 1) % args.gradient_accumulation_steps == 0:
+                if args.fp16:
+                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+                else:
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
                 optimizer.step()
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
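The point of the move: with gradient accumulation, gradients from several backward passes are summed into .grad before a single optimizer step, so the norm should be clipped once on the accumulated gradients right before optimizer.step(), not after every backward(). Below is a minimal sketch of that pattern, not the repository's exact code; model, optimizer, scheduler, dataloader, loss_fn, and args are assumed placeholders, and an fp16/apex path would clip amp.master_params(optimizer) instead of model.parameters().

import torch

def train_steps(args, model, optimizer, scheduler, dataloader, loss_fn):
    """Minimal gradient-accumulation loop; clipping happens once per update."""
    model.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        loss = loss_fn(model(inputs), labels) / args.gradient_accumulation_steps
        loss.backward()  # gradients from each micro-batch accumulate in .grad

        if (step + 1) % args.gradient_accumulation_steps == 0:
            # Clip the accumulated gradients once, right before the update;
            # clipping after every backward() would only see partial gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()   # update learning rate schedule
            model.zero_grad()  # reset gradients for the next accumulation window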