mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Update collating_graphormer.py (#23862)
This commit is contained in:
parent
62ba64b90a
commit
0623f08e99
@ -129,6 +129,6 @@ class GraphormerDataCollator:
|
||||
else: # binary classification
|
||||
batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
|
||||
else: # multi task classification, left to float to keep the NaNs
|
||||
batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], dim=0))
|
||||
batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))
|
||||
|
||||
return batch
|
||||
|
Loading…
Reference in New Issue
Block a user