Update trainer.mdx class_weights example (#23787)

class_weights tensor should follow model's device
This commit is contained in:
amitportnoy 2023-05-26 15:36:33 +03:00 committed by GitHub
parent 4d9b76a80f
commit d61d747627
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -61,7 +61,7 @@ class CustomTrainer(Trainer):
outputs = model(**inputs)
logits = outputs.get("logits")
# compute custom loss (suppose one has 3 labels with different weights)
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
```