mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adding doctest for zero-shot-classification
pipeline. (#20268)
* Adding doctest for `zero-shot-classification` pipeline. * Removing nested_simplify.
This commit is contained in:
parent
69715f2ee0
commit
e06657a798
@ -46,13 +46,36 @@ class ZeroShotClassificationArgumentHandler(ArgumentHandler):
|
||||
class ZeroShotClassificationPipeline(ChunkPipeline):
|
||||
"""
|
||||
NLI-based zero-shot classification pipeline using a `ModelForSequenceClassification` trained on NLI (natural
|
||||
language inference) tasks.
|
||||
language inference) tasks. Equivalent of `text-classification` pipelines, but these models don't require a
|
||||
hardcoded number of potential classes, they can be chosen at runtime. It usually means it's slower but it is
|
||||
**much** more flexible.
|
||||
|
||||
Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
|
||||
pair and passed to the pretrained model. Then, the logit for *entailment* is taken as the logit for the candidate
|
||||
label being valid. Any NLI model can be used, but the id of the *entailment* label must be included in the model
|
||||
config's :attr:*~transformers.PretrainedConfig.label2id*.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> oracle = pipeline(model="facebook/bart-large-mnli")
|
||||
>>> answers = oracle(
|
||||
... "I have a problem with my iphone that needs to be resolved asap!!",
|
||||
... candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"],
|
||||
... )
|
||||
{'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'], 'scores': [0.504, 0.479, 0.013, 0.003, 0.002]}
|
||||
|
||||
>>> oracle(
|
||||
... "I have a problem with my iphone that needs to be resolved asap!!",
|
||||
... candidate_labels=["english", "german"],
|
||||
... )
|
||||
{'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['english', 'german'], 'scores': [0.814, 0.186]}
|
||||
```
|
||||
|
||||
[Learn more about the basics of using a pipeline in the [pipeline tutorial]](../pipeline_tutorial)
|
||||
|
||||
This NLI pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||
`"zero-shot-classification"`.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user