Update collating_graphormer.py (#23862)

This commit is contained in:
Clémentine Fourrier 2023-05-30 16:23:20 +02:00 committed by GitHub
parent 62ba64b90a
commit 0623f08e99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -129,6 +129,6 @@ class GraphormerDataCollator:
else: # binary classification
batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
else: # multi task classification, left to float to keep the NaNs
batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], dim=0))
batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))
return batch