This commit is contained in:
thomwolf 2019-07-16 21:22:19 +02:00
parent c5b3d86a91
commit e848b54730

View File

@ -114,7 +114,7 @@ def main():
mems = None
for idx, (data, target, seq_len) in enumerate(eval_iter):
ret = model(data, target, mems)
loss, mems = ret
loss, _, mems = ret
loss = loss.mean()
total_loss += seq_len * loss.item()
total_len += seq_len