mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 10:41:07 +06:00
push (#10846)
This commit is contained in:
parent
82b8d8c7b0
commit
0f226f78ce
@ -401,7 +401,7 @@ def evaluate(batch):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
|
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
|
||||||
|
|
||||||
pred_ids = torch.argmax(logits, dim=-1)
|
pred_ids = torch.argmax(logits, dim=-1)
|
||||||
batch["pred_strings"] = processor.batch_decode(pred_ids)
|
batch["pred_strings"] = processor.batch_decode(pred_ids)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user