add model.zero_grad()

This commit is contained in:
thomwolf 2018-11-03 17:40:12 +01:00
parent a4086c5de5
commit cb76c1ddd3
2 changed files with 2 additions and 0 deletions

View File

@ -531,6 +531,7 @@ def main():
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
total_tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
model.zero_grad()
loss.backward()
optimizer.step()
global_step += 1

View File

@ -856,6 +856,7 @@ def main():
logger.info("HHHHH Forward")
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
model.zero_grad()
logger.info("HHHHH Backward")
loss.backward()
logger.info("HHHHH Loading data")