Adding doctest for zero-shot-classification pipeline. (#20268)

* Adding doctest for `zero-shot-classification` pipeline.

* Removing nested_simplify.
This commit is contained in:
Nicolas Patry 2022-11-16 17:15:01 +01:00 committed by GitHub
parent 69715f2ee0
commit e06657a798
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -46,13 +46,36 @@ class ZeroShotClassificationArgumentHandler(ArgumentHandler):
class ZeroShotClassificationPipeline(ChunkPipeline):
"""
NLI-based zero-shot classification pipeline using a `ModelForSequenceClassification` trained on NLI (natural
language inference) tasks.
language inference) tasks. Equivalent of `text-classification` pipelines, but these models don't require a
hardcoded number of potential classes, they can be chosen at runtime. It usually means it's slower but it is
**much** more flexible.
Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
pair and passed to the pretrained model. Then, the logit for *entailment* is taken as the logit for the candidate
label being valid. Any NLI model can be used, but the id of the *entailment* label must be included in the model
config's :attr:*~transformers.PretrainedConfig.label2id*.
Example:
```python
>>> from transformers import pipeline
>>> oracle = pipeline(model="facebook/bart-large-mnli")
>>> answers = oracle(
... "I have a problem with my iphone that needs to be resolved asap!!",
... candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"],
... )
{'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'], 'scores': [0.504, 0.479, 0.013, 0.003, 0.002]}
>>> oracle(
... "I have a problem with my iphone that needs to be resolved asap!!",
... candidate_labels=["english", "german"],
... )
{'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['english', 'german'], 'scores': [0.814, 0.186]}
```
[Learn more about the basics of using a pipeline in the [pipeline tutorial]](../pipeline_tutorial)
This NLI pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"zero-shot-classification"`.