This commit is contained in:
Patrick von Platen 2021-03-22 10:32:21 +03:00 committed by GitHub
parent 82b8d8c7b0
commit 0f226f78ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -401,7 +401,7 @@ def evaluate(batch):
with torch.no_grad():
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)
pred_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(pred_ids)
return batch