(feat): Moving labels to same device as logits for Deit (#22679)

This commit is contained in:
Shikhar Chauhan 2023-04-10 14:04:57 +02:00 committed by GitHub
parent 870d91fb89
commit 98597725f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -764,6 +764,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"