fixed document (#13414)

This commit is contained in:
Mohan Zhang 2021-09-08 11:48:00 -04:00 committed by GitHub
parent 330d83fdbd
commit 41cd52a768
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -64,9 +64,9 @@ classification:
class MultilabelTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
labels = inputs.get("labels")
outputs = model(**inputs)
logits = outputs.logits
logits = outputs.get('logits')
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits.view(-1, self.model.config.num_labels),
labels.float().view(-1, self.model.config.num_labels))