mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Update trainer.mdx class_weights example (#23787)
class_weights tensor should follow model's device
This commit is contained in:
parent
4d9b76a80f
commit
d61d747627
@ -61,7 +61,7 @@ class CustomTrainer(Trainer):
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits")
|
||||
# compute custom loss (suppose one has 3 labels with different weights)
|
||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0]))
|
||||
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
|
||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
```
|
||||
|
Loading…
Reference in New Issue
Block a user