Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
[Pipeline] Add zero shot audio classification pipeline (#21600)
* add pipeline
* update init
* add zero shot to init
* update inits and correct checkpoints
* update base to support input features
* add tests
* Update src/transformers/pipelines/zero_shot_audio_classification.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Update src/transformers/pipelines/zero_shot_audio_classification.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* update pipeline code
* use tiny checkpoint
* nits and expected value with tiny model
* style
* last nit on tests values
* fix styling
* fix collate fn that was casting to float
* update

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 2ea1ef9090
commit cc44e72d14
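For orientation, here is a minimal usage sketch of the task this commit adds, assembled from the docstring and slow test in the diff below (the checkpoint, dataset, and candidate labels are taken from there; treat it as a sketch rather than the canonical example):

```python
from datasets import load_dataset
from transformers import pipeline

# Zero-shot audio classification with a CLAP checkpoint (same one the slow test uses).
classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")

dataset = load_dataset("ashraq/esc50")
audio = dataset["train"]["audio"][-1]["array"]  # 1-D numpy waveform

# Returns one {"score", "label"} dict per candidate label, sorted by score.
print(classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"]))
```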
@@ -314,6 +314,12 @@ Pipelines available for audio tasks include the following.
     - __call__
     - all
 
+### ZeroShotAudioClassificationPipeline
+
+[[autodoc]] ZeroShotAudioClassificationPipeline
+    - __call__
+    - all
+
 ## Computer vision
 
 Pipelines available for computer vision tasks include the following.
@@ -560,6 +560,7 @@ _import_structure = {
         "TranslationPipeline",
         "VideoClassificationPipeline",
         "VisualQuestionAnsweringPipeline",
+        "ZeroShotAudioClassificationPipeline",
         "ZeroShotClassificationPipeline",
         "ZeroShotImageClassificationPipeline",
         "ZeroShotObjectDetectionPipeline",
@@ -4108,6 +4109,7 @@ if TYPE_CHECKING:
         TranslationPipeline,
         VideoClassificationPipeline,
         VisualQuestionAnsweringPipeline,
+        ZeroShotAudioClassificationPipeline,
         ZeroShotClassificationPipeline,
         ZeroShotImageClassificationPipeline,
         ZeroShotObjectDetectionPipeline,
@@ -825,16 +825,15 @@ class ClapAudioEncoder(nn.Module):
         self.config = config
         self.patch_embed = ClapAudioPatchEmbed(config)
         self.enable_fusion = config.enable_fusion
-        grid_size = self.patch_embed.grid_size
         self.patch_stride = self.patch_embed.patch_stride
         self.spec_size = config.spec_size
-        self.freq_ratio = self.spec_size // config.num_mel_bins
+        self.freq_ratio = config.spec_size // config.num_mel_bins
 
         self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1))
-        self.freq_ratio = config.spec_size // config.num_mel_bins
 
         drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
 
+        grid_size = self.patch_embed.grid_size
         self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)]
 
         self.layers = nn.ModuleList(
@@ -78,6 +78,7 @@ from .token_classification import (
 )
 from .video_classification import VideoClassificationPipeline
 from .visual_question_answering import VisualQuestionAnsweringPipeline
+from .zero_shot_audio_classification import ZeroShotAudioClassificationPipeline
 from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
 from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
 from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline
@@ -299,6 +300,17 @@ SUPPORTED_TASKS = {
         },
         "type": "multimodal",
     },
+    "zero-shot-audio-classification": {
+        "impl": ZeroShotAudioClassificationPipeline,
+        "tf": (TFAutoModel,) if is_tf_available() else (),
+        "pt": (AutoModel,) if is_torch_available() else (),
+        "default": {
+            "model": {
+                "pt": ("laion/clap-htsat-fused", "f39917b"),
+            }
+        },
+        "type": "multimodal",
+    },
     "conversational": {
         "impl": ConversationalPipeline,
         "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
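Given the registration above, loading the task without an explicit model should fall back to the pinned default checkpoint; a small sketch of that behaviour (the attribute access is illustrative and assumes the usual `PreTrainedModel` API):

```python
from transformers import pipeline

# No model argument: the task default registered above is used,
# i.e. "laion/clap-htsat-fused" pinned at revision "f39917b" (PyTorch only).
classifier = pipeline("zero-shot-audio-classification")
print(classifier.model.name_or_path)  # expected: laion/clap-htsat-fused
```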
@@ -534,6 +546,7 @@ def pipeline(
             - `"visual-question-answering"`: will return a [`VisualQuestionAnsweringPipeline`].
             - `"zero-shot-classification"`: will return a [`ZeroShotClassificationPipeline`].
             - `"zero-shot-image-classification"`: will return a [`ZeroShotImageClassificationPipeline`].
+            - `"zero-shot-audio-classification"`: will return a [`ZeroShotAudioClassificationPipeline`].
             - `"zero-shot-object-detection"`: will return a [`ZeroShotObjectDetectionPipeline`].
 
         model (`str` or [`PreTrainedModel`] or [`TFPreTrainedModel`], *optional*):
@@ -81,6 +81,9 @@ def _pad(items, key, padding_value, padding_side):
             # This is probable image so padding shouldn't be necessary
             # B, C, H, W
             return torch.cat([item[key] for item in items], dim=0)
+        elif dim == 4 and key == "input_features":
+            # this is probably a mel spectrogram batched
+            return torch.cat([item[key] for item in items], dim=0)
         max_length = max(item[key].shape[1] for item in items)
         min_length = min(item[key].shape[1] for item in items)
         dtype = items[0][key].dtype
@@ -154,7 +157,7 @@ def pad_collate_fn(tokenizer, feature_extractor):
         for key in keys:
             if key in {"input_ids"}:
                 # ImageGPT uses a feature extractor
-                if feature_extractor is not None:
+                if tokenizer is None and feature_extractor is not None:
                     _padding_value = f_padding_value
                 else:
                     _padding_value = t_padding_value
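The two `base.py` tweaks above keep batching working for this pipeline: 4-D `input_features` (batched mel spectrograms of identical shape) are concatenated rather than padded, and the feature extractor's padding value is only used when there is no tokenizer. A minimal sketch of the concatenation case (shapes are illustrative, not CLAP's exact ones):

```python
import torch

# Two preprocessed items carrying mel-spectrogram features of the same (B, C, H, W) shape.
items = [{"input_features": torch.zeros(1, 4, 1001, 64)} for _ in range(2)]

# dim == 4 and key == "input_features": no padding needed, just stack along the batch axis.
batch = torch.cat([item["input_features"] for item in items], dim=0)
print(batch.shape)  # torch.Size([2, 4, 1001, 64])
```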
src/transformers/pipelines/zero_shot_audio_classification.py (new file, 145 lines)
@@ -0,0 +1,145 @@
from typing import Union

import numpy as np
import requests

from ..utils import (
    add_end_docstrings,
    is_torch_available,
    logging,
)
from .audio_classification import ffmpeg_read
from .base import PIPELINE_INIT_ARGS, ChunkPipeline


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotAudioClassificationPipeline(ChunkPipeline):
    """
    Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
    provide an audio and a set of `candidate_labels`.

    Example:
    ```python
    >>> from transformers import pipeline
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("ashraq/esc50")
    >>> audio = next(iter(dataset["train"]["audio"]))["array"]
    >>> classifier = pipeline(task="zero-shot-audio-classification", model="laion-ai/clap-htsat-tiny")
    >>> classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
    ```


    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). This audio
    classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-audio-classification"`. See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-audio-classification).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if self.framework != "pt":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")
        # No specific FOR_XXX available yet

    def __call__(self, audios: Union[np.ndarray, bytes, str], **kwargs):
        """
        Assign labels to the audio(s) passed as inputs.

        Args:
            audios (`str`, `List[str]`, `np.array` or `List[np.array]`):
                The pipeline handles three types of inputs:
                - A string containing a http link pointing to an audio
                - A string containing a local path to an audio
                - An audio loaded in numpy
            candidate_labels (`List[str]`):
                The candidate labels for this audio
            hypothesis_template (`str`, *optional*, defaults to `"This is a sound of {}"`):
                The sentence used in conjunction with *candidate_labels* to attempt the audio classification by
                replacing the placeholder with the candidate_labels. Then likelihood is estimated by using
                logits_per_audio
        Return:
            A list of dictionaries containing result, one dictionary per proposed label. The dictionaries contain the
            following keys:
            - **label** (`str`) -- The label identified by the model. It is one of the suggested `candidate_label`.
            - **score** (`float`) -- The score attributed by the model for that label (between 0 and 1).
        """
        return super().__call__(audios, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        if "candidate_labels" in kwargs:
            preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
        if "hypothesis_template" in kwargs:
            preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]

        return preprocess_params, {}, {}

    def preprocess(self, audio, candidate_labels=None, hypothesis_template="This is a sound of {}."):
        if isinstance(audio, str):
            if audio.startswith("http://") or audio.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                audio = requests.get(audio).content
            else:
                with open(audio, "rb") as f:
                    audio = f.read()

        if isinstance(audio, bytes):
            audio = ffmpeg_read(audio, self.feature_extractor.sampling_rate)

        if not isinstance(audio, np.ndarray):
            raise ValueError("We expect a numpy ndarray as input")
        if len(audio.shape) != 1:
            raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")

        n = len(candidate_labels)
        for i, candidate_label in enumerate(candidate_labels):
            audios = self.feature_extractor(
                audio, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
            )
            sequence = hypothesis_template.format(candidate_label)
            inputs = self.tokenizer(sequence, return_tensors=self.framework)
            inputs["input_features"] = audios.input_features
            yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}

    def _forward(self, model_inputs):
        is_last = model_inputs.pop("is_last")
        candidate_label = model_inputs.pop("candidate_label")
        outputs = self.model(**model_inputs)

        # Clap does crossproduct scoring by default, so we're only
        # interested in the results where audio and text are in the same
        # batch position.
        diag = torch.diagonal
        logits_per_audio = diag(outputs.logits_per_audio)

        model_outputs = {
            "is_last": is_last,
            "candidate_label": candidate_label,
            "logits_per_audio": logits_per_audio,
        }
        return model_outputs

    def postprocess(self, model_outputs):
        candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
        if self.framework == "pt":
            logits = torch.cat([output["logits_per_audio"] for output in model_outputs])
            probs = logits.softmax(dim=0)
            scores = probs.tolist()
        else:
            raise ValueError("`tf` framework not supported.")

        result = [
            {"score": score, "label": candidate_label}
            for score, candidate_label in sorted(zip(scores, candidate_labels), key=lambda x: -x[0])
        ]
        return result
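To summarize the flow of the new file in one place: `preprocess` yields one tokenized hypothesis per candidate label paired with the same audio features, `_forward` keeps only the diagonal of CLAP's audio-text logit matrix, and `postprocess` softmaxes those per-label logits against each other. A tiny sketch of that last step (the logit values are made up):

```python
import torch

# One logits_per_audio value per candidate label, gathered across the chunk-pipeline passes.
logits = torch.tensor([2.3, -0.7])  # e.g. ["Sound of a dog", "Sound of vaccum cleaner"]

# Softmax across candidate labels (dim=0), not across a batch dimension.
scores = logits.softmax(dim=0).tolist()
print(scores)  # roughly [0.95, 0.05]; postprocess then sorts labels by score
```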
(new test file, 95 lines)
@@ -0,0 +1,95 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from datasets import load_dataset

from transformers.pipelines import pipeline
from transformers.testing_utils import nested_simplify, require_torch, slow

from .test_pipelines_common import PipelineTestCaseMeta


@require_torch
class ZeroShotAudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    # Deactivating auto tests since we don't have a good MODEL_FOR_XX mapping,
    # and only CLAP would be there for now.
    # model_mapping = {CLAPConfig: CLAPModel}

    @require_torch
    def test_small_model_pt(self):
        audio_classifier = pipeline(
            task="zero-shot-audio-classification", model="hf-internal-testing/tiny-clap-htsat-unfused"
        )
        dataset = load_dataset("ashraq/esc50")
        audio = dataset["train"]["audio"][-1]["array"]
        output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
        self.assertEqual(
            nested_simplify(output),
            [{"score": 0.501, "label": "Sound of a dog"}, {"score": 0.499, "label": "Sound of vaccum cleaner"}],
        )

    @unittest.skip("No models are available in TF")
    def test_small_model_tf(self):
        pass

    @slow
    @require_torch
    def test_large_model_pt(self):
        audio_classifier = pipeline(
            task="zero-shot-audio-classification",
            model="laion/clap-htsat-unfused",
        )
        # This is an audio of a dog
        dataset = load_dataset("ashraq/esc50")
        audio = dataset["train"]["audio"][-1]["array"]
        output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])

        self.assertEqual(
            nested_simplify(output),
            [
                {"score": 0.999, "label": "Sound of a dog"},
                {"score": 0.001, "label": "Sound of vaccum cleaner"},
            ],
        )

        output = audio_classifier([audio] * 5, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
        self.assertEqual(
            nested_simplify(output),
            [
                [
                    {"score": 0.999, "label": "Sound of a dog"},
                    {"score": 0.001, "label": "Sound of vaccum cleaner"},
                ],
            ]
            * 5,
        )
        output = audio_classifier(
            [audio] * 5, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"], batch_size=5
        )
        self.assertEqual(
            nested_simplify(output),
            [
                [
                    {"score": 0.999, "label": "Sound of a dog"},
                    {"score": 0.001, "label": "Sound of vaccum cleaner"},
                ],
            ]
            * 5,
        )

    @unittest.skip("No models are available in TF")
    def test_large_model_tf(self):
        pass