mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fixed document (#13414)
This commit is contained in:
parent
330d83fdbd
commit
41cd52a768
@ -64,9 +64,9 @@ classification:
|
||||
|
||||
class MultilabelTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels = inputs.pop("labels")
|
||||
labels = inputs.get("labels")
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.logits
|
||||
logits = outputs.get('logits')
|
||||
loss_fct = nn.BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits.view(-1, self.model.config.num_labels),
|
||||
labels.float().view(-1, self.model.config.num_labels))
|
||||
|
Loading…
Reference in New Issue
Block a user