Fix doctest CI (#21166)

* fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-01-18 16:54:24 +01:00 committed by GitHub
parent 8ad06b7c13
commit 32525428e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -256,7 +256,7 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze() > 0.5]
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
@ -264,7 +264,9 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
... )
>>> labels = torch.nn.functional.one_hot(torch.tensor(predicted_class_ids), num_classes=num_labels).to(torch.float)
>>> labels = torch.sum(
... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
... ).to(torch.float)
>>> loss = model(**inputs, labels=labels).loss
```
"""