Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
faster forward following what is done for images (#21906)
* faster forward following what is done for images
* add missing licence
This commit is contained in:
parent 37e0974afc
commit dcec3277cd
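
In short: ZeroShotAudioClassificationPipeline is rewritten from a ChunkPipeline that looped over the candidate labels and ran one forward pass per label into a plain Pipeline that tokenizes all hypothesis texts together and scores them against the audio in a single forward pass, mirroring the earlier speed-up of zero-shot image classification. The missing Apache license header is also added. For orientation, a minimal usage sketch of the pipeline this commit touches; the checkpoint name and the silent dummy clip are assumptions, not part of the commit:

from transformers import pipeline
import numpy as np

# Any CLAP checkpoint usable with this pipeline should work; this one is only an assumed example.
classifier = pipeline(
    task="zero-shot-audio-classification",
    model="laion/clap-htsat-unfused",
)

# One second of silence at 48 kHz stands in for a real single-channel recording.
audio = np.zeros(48_000, dtype=np.float32)
result = classifier(audio, candidate_labels=["Sound of a dog", "Sound of vacuum cleaner"])
print(result)  # a list of {"score": ..., "label": ...} dicts, highest score first
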
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import UserDict
 from typing import Union
 
 import numpy as np
@@ -5,22 +20,17 @@ import requests
 
 from ..utils import (
     add_end_docstrings,
-    is_torch_available,
     logging,
 )
 from .audio_classification import ffmpeg_read
-from .base import PIPELINE_INIT_ARGS, ChunkPipeline
+from .base import PIPELINE_INIT_ARGS, Pipeline
 
 
-if is_torch_available():
-    import torch
-
-
 logger = logging.get_logger(__name__)
 
 
 @add_end_docstrings(PIPELINE_INIT_ARGS)
-class ZeroShotAudioClassificationPipeline(ChunkPipeline):
+class ZeroShotAudioClassificationPipeline(Pipeline):
     """
     Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
     provide an audio and a set of `candidate_labels`.
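
The base-class swap above is what changes the execution model: a ChunkPipeline's preprocess is a generator and the model runs once per yielded item, while a plain Pipeline's preprocess returns a single dict and the model runs once per input. A simplified control-flow sketch (hypothetical driver functions, not the actual Pipeline/ChunkPipeline internals):

def run_chunk_pipeline(pipe, audio, candidate_labels):
    # Old behaviour: N candidate labels -> N forward passes.
    model_outputs = []
    for model_inputs in pipe.preprocess(audio, candidate_labels=candidate_labels):
        model_outputs.append(pipe._forward(model_inputs))
    return pipe.postprocess(model_outputs)

def run_pipeline(pipe, audio, candidate_labels):
    # New behaviour: every hypothesis text is batched with the audio -> one forward pass.
    model_inputs = pipe.preprocess(audio, candidate_labels=candidate_labels)
    model_outputs = pipe._forward(model_inputs)
    return pipe.postprocess(model_outputs)
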
@@ -101,38 +111,36 @@ class ZeroShotAudioClassificationPipeline(ChunkPipeline):
         if len(audio.shape) != 1:
             raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")
 
-        n = len(candidate_labels)
-        for i, candidate_label in enumerate(candidate_labels):
-            audios = self.feature_extractor(
-                audio, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-            )
-            sequence = hypothesis_template.format(candidate_label)
-            inputs = self.tokenizer(sequence, return_tensors=self.framework)
-            inputs["input_features"] = audios.input_features
-            yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}
+        inputs = self.feature_extractor(
+            [audio], sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
+        )
+        inputs["candidate_labels"] = candidate_labels
+        sequences = [hypothesis_template.format(x) for x in candidate_labels]
+        text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
+        inputs["text_inputs"] = [text_inputs]
+        return inputs
 
     def _forward(self, model_inputs):
-        is_last = model_inputs.pop("is_last")
-        candidate_label = model_inputs.pop("candidate_label")
-        outputs = self.model(**model_inputs)
-
-        # Clap does crossproduct scoring by default, so we're only
-        # interested in the results where audio and text and in the same
-        # batch position.
-        diag = torch.diagonal
-        logits_per_audio = diag(outputs.logits_per_audio)
+        candidate_labels = model_inputs.pop("candidate_labels")
+        text_inputs = model_inputs.pop("text_inputs")
+        if isinstance(text_inputs[0], UserDict):
+            text_inputs = text_inputs[0]
+        else:
+            # Batching case.
+            text_inputs = text_inputs[0][0]
+
+        outputs = self.model(**text_inputs, **model_inputs)
 
         model_outputs = {
-            "is_last": is_last,
-            "candidate_label": candidate_label,
-            "logits_per_audio": logits_per_audio,
+            "candidate_labels": candidate_labels,
+            "logits": outputs.logits_per_audio,
         }
         return model_outputs
 
     def postprocess(self, model_outputs):
-        candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
+        candidate_labels = model_outputs.pop("candidate_labels")
+        logits = model_outputs["logits"][0]
 
         if self.framework == "pt":
-            logits = torch.cat([output["logits_per_audio"] for output in model_outputs])
             probs = logits.softmax(dim=0)
             scores = probs.tolist()
         else:
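
The removed torch.diagonal and the new logits[0] read in postprocess both follow from CLAP's cross-product scoring: logits_per_audio has shape (n_audio, n_text), one similarity score per audio/text pair. In the old per-label loop, the pipeline's internal batching could collate several (audio, text) pairs into one forward call, producing a square matrix in which only the diagonal paired each audio with its own hypothesis text. The new code scores a single clip against all N hypothesis texts at once, so the single row of logits already holds the N scores that postprocess turns into probabilities. A small shape sketch with made-up numbers:

import torch

# Three candidate labels -> logits_per_audio has the assumed shape (1, 3); the values are made up.
logits_per_audio = torch.tensor([[0.2, 1.5, -0.3]])
scores = logits_per_audio[0].softmax(dim=0)  # what postprocess does: logits[0].softmax(dim=0)
print(scores.tolist())                       # three probabilities summing to 1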