Adding audio-classification example in the doc. (#20235)

* Adding `audio-classification` example in the doc.

* Adding `>>>` to get the real test.

* Removing assert.

* Fixup.
This commit is contained in:
Nicolas Patry 2022-11-16 09:51:03 +01:00 committed by GitHub
parent a00b7e85ea
commit 860ea8a574
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -16,6 +16,8 @@ from typing import Union
import numpy as np
import requests
from ..utils import add_end_docstrings, is_torch_available, logging
from .base import PIPELINE_INIT_ARGS, Pipeline
@ -69,6 +71,24 @@ class AudioClassificationPipeline(Pipeline):
raw waveform or an audio file. In case of an audio file, ffmpeg should be installed to support multiple audio
formats.
Example:
```python
>>> from transformers import pipeline
>>> classifier = pipeline(model="superb/wav2vec2-base-superb-ks")
>>> result = classifier("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
>>> # Simplify results, different torch versions might alter the scores slightly.
>>> from transformers.testing_utils import nested_simplify
>>> nested_simplify(result)
[{'score': 0.997, 'label': '_unknown_'}, {'score': 0.002, 'label': 'left'}, {'score': 0.0, 'label': 'yes'}, {'score': 0.0, 'label': 'down'}, {'score': 0.0, 'label': 'stop'}]
```
[Using pipelines in a webserver or with a dataset](../pipeline_tutorial)
This pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"audio-classification"`.
@ -126,8 +146,13 @@ class AudioClassificationPipeline(Pipeline):
def preprocess(self, inputs):
if isinstance(inputs, str):
with open(inputs, "rb") as f:
inputs = f.read()
if inputs.startswith("http://") or inputs.startswith("https://"):
# Only treat the input as a URL when it starts with an actual protocol prefix; otherwise a local file
# named like http_huggingface_co.png could never be opened.
inputs = requests.get(inputs).content
else:
with open(inputs, "rb") as f:
inputs = f.read()
if isinstance(inputs, bytes):
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)