mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix doctest CI (#21166)
* fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
8ad06b7c13
commit
32525428e1
@ -256,7 +256,7 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
||||
>>> with torch.no_grad():
|
||||
... logits = model(**inputs).logits
|
||||
|
||||
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze() > 0.5]
|
||||
>>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]
|
||||
|
||||
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
||||
>>> num_labels = len(model.config.id2label)
|
||||
@ -264,7 +264,9 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
||||
... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
|
||||
... )
|
||||
|
||||
>>> labels = torch.nn.functional.one_hot(torch.tensor(predicted_class_ids), num_classes=num_labels).to(torch.float)
|
||||
>>> labels = torch.sum(
|
||||
... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
|
||||
... ).to(torch.float)
|
||||
>>> loss = model(**inputs, labels=labels).loss
|
||||
```
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user