mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
(feat): Moving labels to same device as logits for Deit (#22679)
This commit is contained in:
parent
870d91fb89
commit
98597725f1
@ -764,6 +764,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
if self.config.problem_type is None:
|
||||
if self.num_labels == 1:
|
||||
self.config.problem_type = "regression"
|
||||
|
Loading…
Reference in New Issue
Block a user