[Pipeline] Add zero shot audio classificatoin pipeline (#21600)

* add pipeline

* update init

* add zero shot to init

* update inits and correct checkpoints

* update base to support input features

* add tests

* Update src/transformers/pipelines/zero_shot_audio_classification.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/pipelines/zero_shot_audio_classification.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* update pieline code

* use tiny checkpoint

* nits and expected value with tiny model

* style

* last nit on tests values

* fix styling

* fix collate fn that was casting t float

* update

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
Arthur 2023-02-27 11:43:44 +01:00 committed by GitHub
parent 2ea1ef9090
commit cc44e72d14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 267 additions and 4 deletions

View File

@ -314,6 +314,12 @@ Pipelines available for audio tasks include the following.
- __call__
- all
### ZeroShotAudioClassificationPipeline
[[autodoc]] ZeroShotAudioClassificationPipeline
- __call__
- all
## Computer vision
Pipelines available for computer vision tasks include the following.

View File

@ -560,6 +560,7 @@ _import_structure = {
"TranslationPipeline",
"VideoClassificationPipeline",
"VisualQuestionAnsweringPipeline",
"ZeroShotAudioClassificationPipeline",
"ZeroShotClassificationPipeline",
"ZeroShotImageClassificationPipeline",
"ZeroShotObjectDetectionPipeline",
@ -4108,6 +4109,7 @@ if TYPE_CHECKING:
TranslationPipeline,
VideoClassificationPipeline,
VisualQuestionAnsweringPipeline,
ZeroShotAudioClassificationPipeline,
ZeroShotClassificationPipeline,
ZeroShotImageClassificationPipeline,
ZeroShotObjectDetectionPipeline,

View File

@ -825,16 +825,15 @@ class ClapAudioEncoder(nn.Module):
self.config = config
self.patch_embed = ClapAudioPatchEmbed(config)
self.enable_fusion = config.enable_fusion
grid_size = self.patch_embed.grid_size
self.patch_stride = self.patch_embed.patch_stride
self.spec_size = config.spec_size
self.freq_ratio = self.spec_size // config.num_mel_bins
self.freq_ratio = config.spec_size // config.num_mel_bins
self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1))
self.freq_ratio = config.spec_size // config.num_mel_bins
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
grid_size = self.patch_embed.grid_size
self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)]
self.layers = nn.ModuleList(

View File

@ -78,6 +78,7 @@ from .token_classification import (
)
from .video_classification import VideoClassificationPipeline
from .visual_question_answering import VisualQuestionAnsweringPipeline
from .zero_shot_audio_classification import ZeroShotAudioClassificationPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline
@ -299,6 +300,17 @@ SUPPORTED_TASKS = {
},
"type": "multimodal",
},
"zero-shot-audio-classification": {
"impl": ZeroShotAudioClassificationPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
"default": {
"model": {
"pt": ("laion/clap-htsat-fused", "f39917b"),
}
},
"type": "multimodal",
},
"conversational": {
"impl": ConversationalPipeline,
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
@ -534,6 +546,7 @@ def pipeline(
- `"visual-question-answering"`: will return a [`VisualQuestionAnsweringPipeline`].
- `"zero-shot-classification"`: will return a [`ZeroShotClassificationPipeline`].
- `"zero-shot-image-classification"`: will return a [`ZeroShotImageClassificationPipeline`].
- `"zero-shot-audio-classification"`: will return a [`ZeroShotAudioClassificationPipeline`].
- `"zero-shot-object-detection"`: will return a [`ZeroShotObjectDetectionPipeline`].
model (`str` or [`PreTrainedModel`] or [`TFPreTrainedModel`], *optional*):

View File

@ -81,6 +81,9 @@ def _pad(items, key, padding_value, padding_side):
# This is probable image so padding shouldn't be necessary
# B, C, H, W
return torch.cat([item[key] for item in items], dim=0)
elif dim == 4 and key == "input_features":
# this is probably a mel spectrogram batched
return torch.cat([item[key] for item in items], dim=0)
max_length = max(item[key].shape[1] for item in items)
min_length = min(item[key].shape[1] for item in items)
dtype = items[0][key].dtype
@ -154,7 +157,7 @@ def pad_collate_fn(tokenizer, feature_extractor):
for key in keys:
if key in {"input_ids"}:
# ImageGPT uses a feature extractor
if feature_extractor is not None:
if tokenizer is None and feature_extractor is not None:
_padding_value = f_padding_value
else:
_padding_value = t_padding_value

View File

@ -0,0 +1,145 @@
from typing import Union
import numpy as np
import requests
from ..utils import (
add_end_docstrings,
is_torch_available,
logging,
)
from .audio_classification import ffmpeg_read
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotAudioClassificationPipeline(ChunkPipeline):
"""
Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
provide an audio and a set of `candidate_labels`.
Example:
```python
>>> from transformers import pipeline
>>> from datasets import load_dataset
>>> dataset = load_dataset("ashraq/esc50")
>>> audio = next(iter(dataset["train"]["audio"]))["array"]
>>> classifier = pipeline(task="zero-shot-audio-classification", model="laion-ai/clap-htsat-tiny")
>>> classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) This audio
classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"zero-shot-audio-classification"`. See the list of available models on
[huggingface.co/models](https://huggingface.co/models?filter=zero-shot-audio-classification).
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
# No specific FOR_XXX available yet
def __call__(self, audios: Union[np.ndarray, bytes, str], **kwargs):
"""
Assign labels to the audio(s) passed as inputs.
Args:
audios (`str`, `List[str]`, `np.array` or `List[np.array]`):
The pipeline handles three types of inputs:
- A string containing a http link pointing to an audio
- A string containing a local path to an audio
- An audio loaded in numpy
candidate_labels (`List[str]`):
The candidate labels for this audio
hypothesis_template (`str`, *optional*, defaults to `"This is a sound of {}"`):
The sentence used in cunjunction with *candidate_labels* to attempt the audio classification by
replacing the placeholder with the candidate_labels. Then likelihood is estimated by using
logits_per_audio
Return:
A list of dictionaries containing result, one dictionary per proposed label. The dictionaries contain the
following keys:
- **label** (`str`) -- The label identified by the model. It is one of the suggested `candidate_label`.
- **score** (`float`) -- The score attributed by the model for that label (between 0 and 1).
"""
return super().__call__(audios, **kwargs)
def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
if "candidate_labels" in kwargs:
preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
if "hypothesis_template" in kwargs:
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
return preprocess_params, {}, {}
def preprocess(self, audio, candidate_labels=None, hypothesis_template="This is a sound of {}."):
if isinstance(audio, str):
if audio.startswith("http://") or audio.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
audio = requests.get(audio).content
else:
with open(audio, "rb") as f:
audio = f.read()
if isinstance(audio, bytes):
audio = ffmpeg_read(audio, self.feature_extractor.sampling_rate)
if not isinstance(audio, np.ndarray):
raise ValueError("We expect a numpy ndarray as input")
if len(audio.shape) != 1:
raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")
n = len(candidate_labels)
for i, candidate_label in enumerate(candidate_labels):
audios = self.feature_extractor(
audio, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
sequence = hypothesis_template.format(candidate_label)
inputs = self.tokenizer(sequence, return_tensors=self.framework)
inputs["input_features"] = audios.input_features
yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}
def _forward(self, model_inputs):
is_last = model_inputs.pop("is_last")
candidate_label = model_inputs.pop("candidate_label")
outputs = self.model(**model_inputs)
# Clap does crossproduct scoring by default, so we're only
# interested in the results where audio and text and in the same
# batch position.
diag = torch.diagonal
logits_per_audio = diag(outputs.logits_per_audio)
model_outputs = {
"is_last": is_last,
"candidate_label": candidate_label,
"logits_per_audio": logits_per_audio,
}
return model_outputs
def postprocess(self, model_outputs):
candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
if self.framework == "pt":
logits = torch.cat([output["logits_per_audio"] for output in model_outputs])
probs = logits.softmax(dim=0)
scores = probs.tolist()
else:
raise ValueError("`tf` framework not supported.")
result = [
{"score": score, "label": candidate_label}
for score, candidate_label in sorted(zip(scores, candidate_labels), key=lambda x: -x[0])
]
return result

View File

@ -0,0 +1,95 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from datasets import load_dataset
from transformers.pipelines import pipeline
from transformers.testing_utils import nested_simplify, require_torch, slow
from .test_pipelines_common import PipelineTestCaseMeta
@require_torch
class ZeroShotAudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
# Deactivating auto tests since we don't have a good MODEL_FOR_XX mapping,
# and only CLAP would be there for now.
# model_mapping = {CLAPConfig: CLAPModel}
@require_torch
def test_small_model_pt(self):
audio_classifier = pipeline(
task="zero-shot-audio-classification", model="hf-internal-testing/tiny-clap-htsat-unfused"
)
dataset = load_dataset("ashraq/esc50")
audio = dataset["train"]["audio"][-1]["array"]
output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
self.assertEqual(
nested_simplify(output),
[{"score": 0.501, "label": "Sound of a dog"}, {"score": 0.499, "label": "Sound of vaccum cleaner"}],
)
@unittest.skip("No models are available in TF")
def test_small_model_tf(self):
pass
@slow
@require_torch
def test_large_model_pt(self):
audio_classifier = pipeline(
task="zero-shot-audio-classification",
model="laion/clap-htsat-unfused",
)
# This is an audio of a dog
dataset = load_dataset("ashraq/esc50")
audio = dataset["train"]["audio"][-1]["array"]
output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
self.assertEqual(
nested_simplify(output),
[
{"score": 0.999, "label": "Sound of a dog"},
{"score": 0.001, "label": "Sound of vaccum cleaner"},
],
)
output = audio_classifier([audio] * 5, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
self.assertEqual(
nested_simplify(output),
[
[
{"score": 0.999, "label": "Sound of a dog"},
{"score": 0.001, "label": "Sound of vaccum cleaner"},
],
]
* 5,
)
output = audio_classifier(
[audio] * 5, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"], batch_size=5
)
self.assertEqual(
nested_simplify(output),
[
[
{"score": 0.999, "label": "Sound of a dog"},
{"score": 0.001, "label": "Sound of vaccum cleaner"},
],
]
* 5,
)
@unittest.skip("No models are available in TF")
def test_large_model_tf(self):
pass