<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Pipelines for inference

The [`pipeline`] makes it simple to use any model from the [Hub](https://huggingface.co/models) for inference on any language, computer vision, speech, and multimodal tasks. Even if you don't have experience with a specific modality or aren't familiar with the underlying code behind the models, you can still use them for inference with the [`pipeline`]! This tutorial will teach you to:

* Use a [`pipeline`] for inference.
* Use a specific tokenizer or model.
* Use a [`pipeline`] for audio, vision, and multimodal tasks.

<Tip>

Take a look at the [`pipeline`] documentation for a complete list of supported tasks and available parameters.

</Tip>

## Pipeline usage

While each task has an associated [`pipeline`], it is simpler to use the general [`pipeline`] abstraction, which contains all the task-specific pipelines. The [`pipeline`] automatically loads a default model and a preprocessing class capable of inference for your task.

1. Start by creating a [`pipeline`] and specifying an inference task:

```py
>>> from transformers import pipeline

>>> generator = pipeline(task="automatic-speech-recognition")
```

2. Pass your input to the [`pipeline`]. For speech recognition, the input is an audio file, given here as a URL:

```py
>>> generator("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': 'I HAVE A DREAM BUT ONE DAY THIS NATION WILL RISE UP LIVE UP THE TRUE MEANING OF ITS TREES'}
```

Not the result you had in mind? Check out some of the [most downloaded automatic speech recognition models](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&sort=downloads) on the Hub to see if you can get a better transcription.
Let's try [openai/whisper-large](https://huggingface.co/openai/whisper-large):

```py
>>> generator = pipeline(model="openai/whisper-large")
>>> generator("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': ' I have a dream that one day this nation will rise up and live out the true meaning of its creed.'}
```

Now this result looks more accurate!
We encourage you to check out the Hub for models in different languages, models specialized in your field, and more.
You can compare model results directly from your browser on the Hub to see whether a model fits your use case or
handles corner cases better than the others.
And if you don't find a model for your use case, you can always start [training](training) your own!

If you have several inputs, you can pass them as a list:

```py
generator(
    [
        "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac",
        "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
    ]
)
```

If you want to iterate over a whole dataset, or use a pipeline for inference in a webserver, check out the dedicated sections:

* [Using pipelines on a dataset](#using-pipelines-on-a-dataset)
* [Using pipelines for a webserver](./pipeline_webserver)

## Parameters

[`pipeline`] supports many parameters; some are task-specific, and some are general to all pipelines.
In general, you can specify parameters anywhere you want: when you create the pipeline, or when you call it.

```py
# `my_parameter` is a placeholder, not a real pipeline parameter.
generator = pipeline(model="openai/whisper-large", my_parameter=1)
out = generator(...)  # This will use `my_parameter=1`.
out = generator(..., my_parameter=2)  # This will override and use `my_parameter=2`.
out = generator(...)  # This will go back to using `my_parameter=1`.
```
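The precedence rule above can be mimicked with a tiny stand-in class. `ToyPipeline` is hypothetical and is not how 🤗 Transformers implements pipelines. It only reproduces the described behavior: parameters given at creation act as defaults, and parameters given at call time override them for that one call.

```python
from typing import Any, Dict


class ToyPipeline:
    """Hypothetical stand-in that mimics pipeline parameter precedence."""

    def __init__(self, **default_params: Any) -> None:
        # Parameters passed when the pipeline is created become defaults.
        self.default_params: Dict[str, Any] = dict(default_params)

    def __call__(self, inputs: str, **call_params: Any) -> Dict[str, Any]:
        # Call-time parameters win for this call only;
        # self.default_params itself is never modified.
        params = {**self.default_params, **call_params}
        return {"inputs": inputs, "params": params}


generator = ToyPipeline(my_parameter=1)
print(generator("a.flac")["params"])                  # {'my_parameter': 1}
print(generator("a.flac", my_parameter=2)["params"])  # {'my_parameter': 2}
print(generator("a.flac")["params"])                  # {'my_parameter': 1}
```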

Let's check out three important ones:

### Device

If you use `device=n` (for example, `device=0` for the first GPU), the pipeline automatically puts the model on the specified device.
This will work regardless of whether you are using PyTorch or TensorFlow.

```py
generator = pipeline(model="openai/whisper-large", device=0)
```

If the model is too large for a single GPU, you can set `device_map="auto"` to allow 🤗 [Accelerate](https://huggingface.co/docs/accelerate) to automatically determine how to load and store the model weights.

```py
#!pip install accelerate
generator = pipeline(model="openai/whisper-large", device_map="auto")
```

### Batch size

By default, pipelines will not batch inference, for reasons explained in detail [here](https://huggingface.co/docs/transformers/main_classes/pipelines#pipeline-batching): batching is not necessarily faster, and can actually be considerably slower in some cases.

But if it works in your use case, you can use:

```py
generator = pipeline(model="openai/whisper-large", device=0, batch_size=2)
audio_filenames = [f"audio_{i}.flac" for i in range(10)]
texts = generator(audio_filenames)
```

This runs the pipeline on the 10 provided audio files and passes them to the model in batches of 2
(the model is on a GPU, where batching is more likely to help), without requiring any further code from you.
The output should always match what you would have received without batching. Batching is only meant as a way to get more speed out of a pipeline.

Pipelines can also alleviate some of the complexities of batching because, for some pipelines, a single item (like a long audio file) needs to be chunked into multiple parts to be processed by a model. The pipeline performs this [*chunk batching*](./main_classes/pipelines#pipeline-chunk-batching) for you.
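Chunk batching can be pictured as two steps. First, one long input is cut into overlapping windows. Second, the windows are grouped into batches. The sketch below does this on a plain Python list standing in for audio samples. The function names, the single-overlap arithmetic, and the sizes are assumptions for illustration only. The ASR pipeline exposes chunking through parameters such as `chunk_length_s`, and its real chunking and merging logic is more involved. The overlap gives each window some context at its edges.

```python
from typing import Iterator, List, Sequence


def chunk_with_overlap(samples: Sequence[int], chunk_len: int, stride: int) -> List[Sequence[int]]:
    """Split `samples` into windows of `chunk_len`; neighbours overlap by `stride` items."""
    if not 0 <= stride < chunk_len:
        raise ValueError("stride must be in [0, chunk_len)")
    step = chunk_len - stride
    chunks = []
    start = 0
    while True:
        chunks.append(samples[start:start + chunk_len])
        if start + chunk_len >= len(samples):  # last window reached the end
            break
        start += step
    return chunks


def batches(items: Sequence, batch_size: int) -> Iterator[list]:
    """Yield consecutive groups of at most `batch_size` items."""
    for i in range(0, len(items), batch_size):
        yield list(items[i:i + batch_size])


audio = list(range(10))  # stand-in for 10 audio samples
chunks = chunk_with_overlap(audio, chunk_len=4, stride=1)
for batch in batches(chunks, batch_size=2):
    print(batch)  # each batch would be one forward pass of the model
```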

### Task-specific parameters

All tasks provide task-specific parameters which allow for additional flexibility and options to help you get your job done.
For instance, the [`transformers.AutomaticSpeechRecognitionPipeline.__call__`] method has a `return_timestamps` parameter which sounds promising for subtitling videos:

```py
>>> # Not using whisper, as it cannot provide timestamps.
>>> generator = pipeline(model="facebook/wav2vec2-large-960h-lv60-self", return_timestamps="word")
>>> generator("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/mlk.flac")
{'text': 'I HAVE A DREAM BUT ONE DAY THIS NATION WILL RISE UP AND LIVE OUT THE TRUE MEANING OF ITS CREED',
 'chunks': [
    {'text': 'I', 'timestamp': (1.22, 1.24)},
    {'text': 'HAVE', 'timestamp': (1.42, 1.58)},
    {'text': 'A', 'timestamp': (1.66, 1.68)},
    {'text': 'DREAM', 'timestamp': (1.76, 2.14)},
    {'text': 'BUT', 'timestamp': (3.68, 3.8)},
    {'text': 'ONE', 'timestamp': (3.94, 4.06)},
    {'text': 'DAY', 'timestamp': (4.16, 4.3)},
    {'text': 'THIS', 'timestamp': (6.36, 6.54)},
    {'text': 'NATION', 'timestamp': (6.68, 7.1)},
    {'text': 'WILL', 'timestamp': (7.32, 7.56)},
    {'text': 'RISE', 'timestamp': (7.8, 8.26)},
    {'text': 'UP', 'timestamp': (8.38, 8.48)},
    {'text': 'AND', 'timestamp': (10.08, 10.18)},
    {'text': 'LIVE', 'timestamp': (10.26, 10.48)},
    {'text': 'OUT', 'timestamp': (10.58, 10.7)},
    {'text': 'THE', 'timestamp': (10.82, 10.9)},
    {'text': 'TRUE', 'timestamp': (10.98, 11.18)},
    {'text': 'MEANING', 'timestamp': (11.26, 11.58)},
    {'text': 'OF', 'timestamp': (11.66, 11.7)},
    {'text': 'ITS', 'timestamp': (11.76, 11.88)},
    {'text': 'CREED', 'timestamp': (12.0, 12.38)}
 ]}
```

As you can see, the model inferred the text and also output **when** the various words were pronounced
in the sentence.
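Since `return_timestamps` was suggested for subtitling, here is a sketch that turns word-level `chunks` like the ones above into SubRip (`.srt`) cues. The helper names are hypothetical. Grouping a fixed number of words per cue is an assumption made for illustration; real subtitling would more likely split on pauses or line length.

```python
from typing import Dict, List


def srt_time(seconds: float) -> str:
    """Format seconds as an SRT timestamp, HH:MM:SS,mmm."""
    millis = int(round(seconds * 1000))
    hours, millis = divmod(millis, 3_600_000)
    minutes, millis = divmod(millis, 60_000)
    secs, millis = divmod(millis, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def chunks_to_srt(chunks: List[Dict], words_per_cue: int = 4) -> str:
    """Group word-level chunks into numbered SRT cues separated by blank lines."""
    cues = []
    for i in range(0, len(chunks), words_per_cue):
        group = chunks[i:i + words_per_cue]
        start = group[0]["timestamp"][0]  # start of the first word in the cue
        end = group[-1]["timestamp"][1]   # end of the last word in the cue
        text = " ".join(chunk["text"] for chunk in group)
        cues.append(f"{len(cues) + 1}\n{srt_time(start)} --> {srt_time(end)}\n{text}\n")
    return "\n".join(cues)


# The first few word chunks, copied from the output above.
word_chunks = [
    {"text": "I", "timestamp": (1.22, 1.24)},
    {"text": "HAVE", "timestamp": (1.42, 1.58)},
    {"text": "A", "timestamp": (1.66, 1.68)},
    {"text": "DREAM", "timestamp": (1.76, 2.14)},
    {"text": "BUT", "timestamp": (3.68, 3.8)},
    {"text": "ONE", "timestamp": (3.94, 4.06)},
    {"text": "DAY", "timestamp": (4.16, 4.3)},
]
print(chunks_to_srt(word_chunks))
```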

There are many parameters available for each task, so check out each task's API reference to see what you can tinker with!
For instance, the [`~transformers.AutomaticSpeechRecognitionPipeline`] has a `chunk_length_s` parameter which is helpful for working on really long audio files (for example, subtitling entire movies or hour-long videos) that a model typically cannot handle on its own.

If you can't find a parameter that would really help you out, feel free to [request it](https://github.com/huggingface/transformers/issues/new?assignees=&labels=feature&template=feature-request.yml)!

## Using pipelines on a dataset

The pipeline can also run inference on a large dataset. The easiest way to do this, and the one we recommend, is to use an iterator:

```py
from transformers import pipeline


def data():
    for i in range(1000):
        yield f"My example {i}"


pipe = pipeline(model="gpt2", device=0)
generated_characters = 0
for out in pipe(data()):
    # Text generation returns a list of generated sequences for each input.
    generated_characters += len(out[0]["generated_text"])
```

The iterator `data()` yields each example, and the pipeline automatically
recognizes that the input is iterable and starts fetching data while
it continues to process earlier items on the GPU (this uses [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) under the hood).
This is important because you don't have to allocate memory for the whole dataset,
and you can feed the GPU as fast as possible.
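The overlap between fetching and processing can be illustrated without PyTorch. The sketch below is a simplified stand-in for what a `DataLoader` does, not the code the pipeline uses. A background thread pulls items from any iterator into a bounded queue while the consumer processes earlier ones, so only about `buffer_size` items are held in memory at once. `prefetch` is a hypothetical helper name.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def prefetch(iterable: Iterable[T], buffer_size: int = 4) -> Iterator[T]:
    """Yield items from `iterable` while a background thread reads ahead.

    Roughly `buffer_size` fetched-but-unconsumed items are buffered at a time.
    """
    buffer: queue.Queue = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in iterable:
                buffer.put(("item", item))  # blocks while the buffer is full
        except Exception as exc:  # hand producer errors to the consumer
            buffer.put(("error", exc))
        else:
            buffer.put(("done", None))

    # Daemon thread: it will not keep the process alive if the consumer stops early.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buffer.get()
        if kind == "done":
            return
        if kind == "error":
            raise payload
        yield payload


def data():
    for i in range(1000):
        yield f"My example {i}"


total_characters = 0
for example in prefetch(data(), buffer_size=8):
    total_characters += len(example)  # stand-in for running the model on `example`
print(total_characters)
```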

Since batching could speed things up, it may be useful to try tuning the `batch_size` parameter here.

The simplest way to iterate over a dataset is to just load one from 🤗 [Datasets](https://github.com/huggingface/datasets/):

```py
# KeyDataset is a util that will just output the item we're interested in.
from transformers.pipelines.pt_utils import KeyDataset
from transformers import pipeline
from datasets import load_dataset

pipe = pipeline(model="hf-internal-testing/tiny-random-wav2vec2", device=0)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:10]")

for out in pipe(KeyDataset(dataset, "audio")):
    print(out)
```
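`KeyDataset` picks one field out of each example so the pipeline only receives what it needs. Here is a dependency-free sketch of the same idea. It assumes the wrapped dataset supports `len()` and integer indexing returning dict-like rows, as a 🤗 Datasets `Dataset` does. `SimpleKeyDataset` is a hypothetical name, not part of 🤗 Transformers.

```python
from typing import Any, Mapping, Sequence


class SimpleKeyDataset:
    """Hypothetical sketch: expose only `key` from each example of `dataset`."""

    def __init__(self, dataset: Sequence[Mapping[str, Any]], key: str) -> None:
        self.dataset = dataset
        self.key = key

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Any:
        # Return just the field the pipeline needs, e.g. the "audio" column.
        return self.dataset[index][self.key]


rows = [
    {"audio": "clip_0.flac", "text": "HELLO"},
    {"audio": "clip_1.flac", "text": "WORLD"},
]
keyed = SimpleKeyDataset(rows, "audio")
print(len(keyed), list(keyed))  # 2 ['clip_0.flac', 'clip_1.flac']
```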

## Using pipelines for a webserver

<Tip>
Creating an inference engine is a complex topic which deserves its own
page.
</Tip>

Check out the dedicated [webserver guide](./pipeline_webserver).

## Vision pipeline

Using a [`pipeline`] for vision tasks is practically identical.

Specify your task or model and pass your image to the classifier. The image can be a link or a local path to the image. For example, what species of cat is shown below?

![pipeline-cat-chonk](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg)

```py
>>> from transformers import pipeline

>>> vision_classifier = pipeline(model="google/vit-base-patch16-224")
>>> preds = vision_classifier(
...     images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
... )
>>> preds = [{"score": round(pred["score"], 4), "label": pred["label"]} for pred in preds]
>>> preds
[{'score': 0.4335, 'label': 'lynx, catamount'}, {'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}, {'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}, {'score': 0.0239, 'label': 'Egyptian cat'}, {'score': 0.0229, 'label': 'tiger cat'}]
```
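The classifier returns a list of dicts with `score` and `label` keys, highest score first. This checkpoint's labels are ImageNet class names, which often list several synonyms separated by commas. The hypothetical helper below keeps predictions at or above a score threshold and shortens each label to its first synonym. The 0.03 default threshold is an arbitrary choice for illustration.

```python
from typing import Dict, List


def confident_labels(preds: List[Dict], threshold: float = 0.03) -> List[str]:
    """Return the first synonym of every label whose score is at least `threshold`."""
    return [
        pred["label"].split(",")[0].strip()
        for pred in preds
        if pred["score"] >= threshold
    ]


# Predictions copied from the output above.
preds = [
    {"score": 0.4335, "label": "lynx, catamount"},
    {"score": 0.0348, "label": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor"},
    {"score": 0.0324, "label": "snow leopard, ounce, Panthera uncia"},
    {"score": 0.0239, "label": "Egyptian cat"},
    {"score": 0.0229, "label": "tiger cat"},
]
print(confident_labels(preds))  # ['lynx', 'cougar', 'snow leopard']
```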

## Text pipeline

Using a [`pipeline`] for NLP tasks is practically identical.

```py
>>> from transformers import pipeline

>>> # This model is a `zero-shot-classification` model.
>>> # It classifies text, and you are free to choose any labels you like.
>>> classifier = pipeline(model="facebook/bart-large-mnli")
>>> classifier(
...     "I have a problem with my iphone that needs to be resolved asap!!",
...     candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"],
... )
{'sequence': 'I have a problem with my iphone that needs to be resolved asap!!',
 'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'],
 'scores': [0.504, 0.479, 0.013, 0.003, 0.002]}
```
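The zero-shot output pairs `labels` and `scores` by position, from most to least likely. A small hypothetical helper can zip them back together and keep only labels above a cutoff, for example to decide how to route the message. The 0.1 default cutoff is an arbitrary choice for illustration.

```python
from typing import Dict


def labels_above(result: Dict, cutoff: float = 0.1) -> Dict[str, float]:
    """Map each label to its score, keeping only scores >= cutoff."""
    return {
        label: score
        for label, score in zip(result["labels"], result["scores"])
        if score >= cutoff
    }


# Output copied from the example above.
result = {
    "sequence": "I have a problem with my iphone that needs to be resolved asap!!",
    "labels": ["urgent", "phone", "computer", "not urgent", "tablet"],
    "scores": [0.504, 0.479, 0.013, 0.003, 0.002],
}
print(labels_above(result))  # {'urgent': 0.504, 'phone': 0.479}
```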

## Multimodal pipeline

The [`pipeline`] supports more than one modality. For example, a visual question answering (VQA) task combines text and image. Feel free to use any image link you like and a question you want to ask about the image. The image can be a URL or a local path to the image.

For example, if you use this [invoice image](https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png):

```py
>>> from transformers import pipeline

>>> vqa = pipeline(model="impira/layoutlm-document-qa")
>>> vqa(
...     image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png",
...     question="What is the invoice number?",
... )
[{'score': 0.635722279548645, 'answer': '1110212019', 'start': 22, 'end': 22}]
```