mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
faster forward following what is done for images (#21906)
* faster forward following what is done for images

* add missing licence
This commit is contained in:
parent 37e0974afc
commit dcec3277cd
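For orientation, here is a minimal usage sketch of the pipeline this commit speeds up. It is illustrative only: laion/clap-htsat-unfused is the CLAP checkpoint the pipeline's own docstring uses, and dog_bark.wav is a hypothetical local file.

    from transformers import pipeline

    # Zero-shot audio classification: one clip is scored against every
    # candidate label using CLAP's joint audio/text embedding space.
    classifier = pipeline(
        task="zero-shot-audio-classification",
        model="laion/clap-htsat-unfused",
    )
    output = classifier(
        "dog_bark.wav",  # hypothetical path; URLs and raw numpy waveforms also work
        candidate_labels=["Sound of a dog", "Sound of vacuum cleaner"],
    )
    print(output)  # e.g. [{'score': 0.99, 'label': 'Sound of a dog'}, ...]

The diff below does not alter this interface; it only reduces the number of model forward passes per call.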
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from collections import UserDict
 from typing import Union
 
 import numpy as np
@@ -5,22 +20,17 @@ import requests
 
 from ..utils import (
     add_end_docstrings,
-    is_torch_available,
     logging,
 )
 from .audio_classification import ffmpeg_read
-from .base import PIPELINE_INIT_ARGS, ChunkPipeline
-
-
-if is_torch_available():
-    import torch
+from .base import PIPELINE_INIT_ARGS, Pipeline
 
 
 logger = logging.get_logger(__name__)
 
 
 @add_end_docstrings(PIPELINE_INIT_ARGS)
-class ZeroShotAudioClassificationPipeline(ChunkPipeline):
+class ZeroShotAudioClassificationPipeline(Pipeline):
     """
     Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
     provide an audio and a set of `candidate_labels`.
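The base-class switch above is the core of the commit: a ChunkPipeline's preprocess is a generator that yields one work item per candidate label (tagged with is_last so results can be regrouped later), while a plain Pipeline's preprocess returns a single dict that flows through _forward and postprocess exactly once. A schematic of the two contracts, as simplified stand-ins rather than the real transformers base classes:

    # Simplified sketch, not the actual transformers code.
    class ChunkPipelineStyle:
        def preprocess(self, audio, candidate_labels):
            # Generator: one item, and later one model forward, per label.
            n = len(candidate_labels)
            for i, label in enumerate(candidate_labels):
                yield {"is_last": i == n - 1, "candidate_label": label, "audio": audio}

    class PipelineStyle:
        def preprocess(self, audio, candidate_labels):
            # Single dict: all labels travel through one model forward together.
            return {"candidate_labels": candidate_labels, "audio": audio}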
@@ -101,38 +111,36 @@ class ZeroShotAudioClassificationPipeline(ChunkPipeline):
         if len(audio.shape) != 1:
             raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")
 
-        n = len(candidate_labels)
-        for i, candidate_label in enumerate(candidate_labels):
-            audios = self.feature_extractor(
-                audio, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-            )
-            sequence = hypothesis_template.format(candidate_label)
-            inputs = self.tokenizer(sequence, return_tensors=self.framework)
-            inputs["input_features"] = audios.input_features
-            yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}
+        inputs = self.feature_extractor(
+            [audio], sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
+        )
+        inputs["candidate_labels"] = candidate_labels
+        sequences = [hypothesis_template.format(x) for x in candidate_labels]
+        text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
+        inputs["text_inputs"] = [text_inputs]
+        return inputs
 
     def _forward(self, model_inputs):
-        is_last = model_inputs.pop("is_last")
-        candidate_label = model_inputs.pop("candidate_label")
-        outputs = self.model(**model_inputs)
+        candidate_labels = model_inputs.pop("candidate_labels")
+        text_inputs = model_inputs.pop("text_inputs")
+        if isinstance(text_inputs[0], UserDict):
+            text_inputs = text_inputs[0]
+        else:
+            # Batching case.
+            text_inputs = text_inputs[0][0]
 
-        # Clap does crossproduct scoring by default, so we're only
-        # interested in the results where audio and text and in the same
-        # batch position.
-        diag = torch.diagonal
-        logits_per_audio = diag(outputs.logits_per_audio)
+        outputs = self.model(**text_inputs, **model_inputs)
 
         model_outputs = {
-            "is_last": is_last,
-            "candidate_label": candidate_label,
-            "logits_per_audio": logits_per_audio,
+            "candidate_labels": candidate_labels,
+            "logits": outputs.logits_per_audio,
         }
         return model_outputs
 
     def postprocess(self, model_outputs):
-        candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
+        candidate_labels = model_outputs.pop("candidate_labels")
+        logits = model_outputs["logits"][0]
         if self.framework == "pt":
-            logits = torch.cat([output["logits_per_audio"] for output in model_outputs])
             probs = logits.softmax(dim=0)
             scores = probs.tolist()
         else:
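Why the new path is faster, as a hedged sketch; the variable names and the dummy waveform below are assumptions for illustration, not part of the commit. The old ChunkPipeline ran one text-plus-audio forward per candidate label and kept the diagonal of CLAP's cross-product similarity matrix; it also re-encoded the same clip through the audio tower on every iteration. The new Pipeline tokenizes all labels together with padding and runs a single forward, whose logits_per_audio row already holds one score per label.

    import numpy as np
    import torch
    from transformers import AutoFeatureExtractor, AutoTokenizer, ClapModel

    model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
    tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
    extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")

    labels = ["Sound of a dog", "Sound of vacuum cleaner"]
    template = "This is a sound of {}."
    audio = np.zeros(48_000, dtype=np.float32)  # dummy 1-second waveform at 48 kHz
    audio_inputs = extractor([audio], sampling_rate=extractor.sampling_rate, return_tensors="pt")

    # Old strategy: N forwards; each pairs the same clip with one label, so
    # only the diagonal of the (n_audio x n_text) logit matrix is wanted.
    old_scores = torch.cat(
        [
            torch.diagonal(
                model(
                    **tokenizer(template.format(label), return_tensors="pt"),
                    **audio_inputs,
                ).logits_per_audio
            )
            for label in labels
        ]
    )

    # New strategy: one forward over all padded labels; logits_per_audio is
    # a (1, N) row of scores for the single clip.
    text_inputs = tokenizer([template.format(label) for label in labels], return_tensors="pt", padding=True)
    new_scores = model(**text_inputs, **audio_inputs).logits_per_audio[0]

    # Up to numerical noise, the two strategies should agree.
    print(torch.allclose(old_scores, new_scores, atol=1e-4))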