mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
typo fix in output tuple
This commit is contained in:
parent
d92a7f7721
commit
48d4a5317c
@ -520,7 +520,7 @@ def main():
|
||||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
loss = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
if n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
|
Loading…
Reference in New Issue
Block a user