diff --git a/src/transformers/pipelines/zero_shot_audio_classification.py b/src/transformers/pipelines/zero_shot_audio_classification.py
index c534763e87e..b1604aa92c1 100644
--- a/src/transformers/pipelines/zero_shot_audio_classification.py
+++ b/src/transformers/pipelines/zero_shot_audio_classification.py
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import UserDict
 from typing import Union
 
 import numpy as np
@@ -5,22 +20,17 @@ import requests
 
 from ..utils import (
     add_end_docstrings,
-    is_torch_available,
     logging,
 )
 from .audio_classification import ffmpeg_read
-from .base import PIPELINE_INIT_ARGS, ChunkPipeline
-
-
-if is_torch_available():
-    import torch
+from .base import PIPELINE_INIT_ARGS, Pipeline
 
 
 logger = logging.get_logger(__name__)
 
 
 @add_end_docstrings(PIPELINE_INIT_ARGS)
-class ZeroShotAudioClassificationPipeline(ChunkPipeline):
+class ZeroShotAudioClassificationPipeline(Pipeline):
     """
     Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
     provide an audio and a set of `candidate_labels`.
@@ -101,38 +111,37 @@ class ZeroShotAudioClassificationPipeline(ChunkPipeline):
         if len(audio.shape) != 1:
             raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")
 
-        n = len(candidate_labels)
-        for i, candidate_label in enumerate(candidate_labels):
-            audios = self.feature_extractor(
-                audio, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-            )
-            sequence = hypothesis_template.format(candidate_label)
-            inputs = self.tokenizer(sequence, return_tensors=self.framework)
-            inputs["input_features"] = audios.input_features
-            yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}
+        inputs = self.feature_extractor(
+            [audio], sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
+        )
+        inputs["candidate_labels"] = candidate_labels
+        sequences = [hypothesis_template.format(x) for x in candidate_labels]
+        text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
+        inputs["text_inputs"] = [text_inputs]
+        return inputs
 
     def _forward(self, model_inputs):
-        is_last = model_inputs.pop("is_last")
-        candidate_label = model_inputs.pop("candidate_label")
-        outputs = self.model(**model_inputs)
+        candidate_labels = model_inputs.pop("candidate_labels")
+        text_inputs = model_inputs.pop("text_inputs")
+        if isinstance(text_inputs[0], UserDict):
+            text_inputs = text_inputs[0]
+        else:
+            # Batching case.
+            text_inputs = text_inputs[0][0]
 
-        # Clap does crossproduct scoring by default, so we're only
-        # interested in the results where audio and text and in the same
-        # batch position.
-        diag = torch.diagonal
-        logits_per_audio = diag(outputs.logits_per_audio)
+        outputs = self.model(**text_inputs, **model_inputs)
 
         model_outputs = {
-            "is_last": is_last,
-            "candidate_label": candidate_label,
-            "logits_per_audio": logits_per_audio,
+            "candidate_labels": candidate_labels,
+            "logits": outputs.logits_per_audio,
         }
         return model_outputs
 
     def postprocess(self, model_outputs):
-        candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
+        candidate_labels = model_outputs.pop("candidate_labels")
+        logits = model_outputs["logits"][0]
+
         if self.framework == "pt":
-            logits = torch.cat([output["logits_per_audio"] for output in model_outputs])
             probs = logits.softmax(dim=0)
             scores = probs.tolist()
         else:
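
A minimal usage sketch of the pipeline after this refactor, assuming the `laion/clap-htsat-unfused` checkpoint and an ESC-50 clip for the input audio (both are illustrative choices, not part of the diff): candidate labels are now scored in one batched CLAP forward pass instead of one chunk per label.

```python
# Usage sketch; checkpoint and dataset names are assumptions for illustration.
from datasets import load_dataset
from transformers import pipeline

dataset = load_dataset("ashraq/esc50", split="train")
audio = dataset[0]["audio"]["array"]  # single-channel numpy array, as preprocess() expects

classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")

# With ChunkPipeline removed, preprocess() builds one batched audio/text input,
# _forward() scores every candidate label in a single model call, and
# postprocess() softmaxes logits_per_audio over the label dimension.
results = classifier(audio, candidate_labels=["Sound of a dog", "Sound of vacuum cleaner"])
print(results)  # list of {"score": ..., "label": ...} dicts, sorted downstream by score
```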