faster forward following what is done for images (#21906)

* faster forward following what is done for images

* add missing licence
This commit is contained in:
Arthur 2023-03-03 06:18:18 +01:00 committed by GitHub
parent 37e0974afc
commit dcec3277cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import UserDict
from typing import Union
import numpy as np
@ -5,22 +20,17 @@ import requests
from ..utils import (
add_end_docstrings,
is_torch_available,
logging,
)
from .audio_classification import ffmpeg_read
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
if is_torch_available():
import torch
from .base import PIPELINE_INIT_ARGS, Pipeline
logger = logging.get_logger(__name__)
@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotAudioClassificationPipeline(ChunkPipeline):
class ZeroShotAudioClassificationPipeline(Pipeline):
"""
Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
provide an audio and a set of `candidate_labels`.
@ -101,38 +111,36 @@ class ZeroShotAudioClassificationPipeline(ChunkPipeline):
if len(audio.shape) != 1:
raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")
n = len(candidate_labels)
for i, candidate_label in enumerate(candidate_labels):
audios = self.feature_extractor(
audio, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
sequence = hypothesis_template.format(candidate_label)
inputs = self.tokenizer(sequence, return_tensors=self.framework)
inputs["input_features"] = audios.input_features
yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}
inputs = self.feature_extractor(
[audio], sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
inputs["candidate_labels"] = candidate_labels
sequences = [hypothesis_template.format(x) for x in candidate_labels]
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
inputs["text_inputs"] = [text_inputs]
def _forward(self, model_inputs):
is_last = model_inputs.pop("is_last")
candidate_label = model_inputs.pop("candidate_label")
outputs = self.model(**model_inputs)
candidate_labels = model_inputs.pop("candidate_labels")
text_inputs = model_inputs.pop("text_inputs")
if isinstance(text_inputs[0], UserDict):
text_inputs = text_inputs[0]
else:
# Batching case.
text_inputs = text_inputs[0][0]
# Clap does crossproduct scoring by default, so we're only
# interested in the results where audio and text and in the same
# batch position.
diag = torch.diagonal
logits_per_audio = diag(outputs.logits_per_audio)
outputs = self.model(**text_inputs, **model_inputs)
model_outputs = {
"is_last": is_last,
"candidate_label": candidate_label,
"logits_per_audio": logits_per_audio,
"candidate_label": candidate_labels,
"logits_per_audio": outputs.logits_per_audio,
}
return model_outputs
def postprocess(self, model_outputs):
candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
candidate_labels = model_outputs.pop("candidate_labels")
logits = model_outputs["logits"][0]
if self.framework == "pt":
logits = torch.cat([output["logits_per_audio"] for output in model_outputs])
probs = logits.softmax(dim=0)
scores = probs.tolist()
else: