fix low-precision audio classification pipeline (#35435)
* fix low-precision audio classification pipeline
* add test
* fix format
* fix torch import
* fix torch import
* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
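The underlying bug: constructing the pipeline with torch_dtype=torch.float16 cast the model but not the extracted features, so the fp16 model received fp32 inputs (typically surfacing as a dtype-mismatch error at forward time). A minimal usage sketch of the fixed path, reusing the model name from the test added below; that it reproduces the original failure on any given setup is an assumption:

import numpy as np
import torch

from transformers import pipeline

# With this fix, preprocess() casts the extracted features to the pipeline's
# torch_dtype, so an fp16 model receives fp16 inputs.
classifier = pipeline(
    "audio-classification",
    model="anton-l/wav2vec2-random-tiny-classifier",
    torch_dtype=torch.float16,
)

# Dummy audio array, as in the new test below.
print(classifier(np.ones((8000,)), top_k=4))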
This commit is contained in:
parent 641238eb76
commit f19135afc7
@@ -212,6 +212,8 @@ class AudioClassificationPipeline(Pipeline):
         processed = self.feature_extractor(
             inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
         )
+        if self.torch_dtype is not None:
+            processed = processed.to(dtype=self.torch_dtype)
         return processed

     def _forward(self, model_inputs):
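The new cast goes through BatchFeature.to, which in current transformers releases applies the dtype only to floating-point tensors, leaving integer fields such as attention masks untouched. A minimal sketch of that behavior; the exact BatchFeature semantics are an assumption based on the released implementation, not part of this diff:

import torch
from transformers.feature_extraction_utils import BatchFeature

features = BatchFeature(
    {
        "input_values": torch.zeros(1, 8000, dtype=torch.float32),
        "attention_mask": torch.ones(1, 8000, dtype=torch.long),
    }
)

# Only floating-point tensors are cast; integer tensors keep their dtype.
features = features.to(dtype=torch.float16)
assert features["input_values"].dtype == torch.float16
assert features["attention_mask"].dtype == torch.long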
@@ -17,7 +17,11 @@ import unittest
 import numpy as np
 from huggingface_hub import AudioClassificationOutputElement

-from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+from transformers import (
+    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
+    TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
+    is_torch_available,
+)
 from transformers.pipelines import AudioClassificationPipeline, pipeline
 from transformers.testing_utils import (
     compare_pipeline_output_to_hub_spec,
@@ -32,6 +36,10 @@ from transformers.testing_utils import (
 from .test_pipelines_common import ANY


+if is_torch_available():
+    import torch
+
+
 @is_pipeline_test
 class AudioClassificationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
@@ -127,6 +135,33 @@ class AudioClassificationPipelineTests(unittest.TestCase):
         output = audio_classifier(audio_dict, top_k=4)
         self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])

+    @require_torch
+    def test_small_model_pt_fp16(self):
+        model = "anton-l/wav2vec2-random-tiny-classifier"
+
+        audio_classifier = pipeline("audio-classification", model=model, torch_dtype=torch.float16)
+
+        audio = np.ones((8000,))
+        output = audio_classifier(audio, top_k=4)
+
+        EXPECTED_OUTPUT = [
+            {"score": 0.0839, "label": "no"},
+            {"score": 0.0837, "label": "go"},
+            {"score": 0.0836, "label": "yes"},
+            {"score": 0.0835, "label": "right"},
+        ]
+        EXPECTED_OUTPUT_PT_2 = [
+            {"score": 0.0845, "label": "stop"},
+            {"score": 0.0844, "label": "on"},
+            {"score": 0.0841, "label": "right"},
+            {"score": 0.0834, "label": "left"},
+        ]
+        self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
+
+        audio_dict = {"array": np.ones((8000,)), "sampling_rate": audio_classifier.feature_extractor.sampling_rate}
+        output = audio_classifier(audio_dict, top_k=4)
+        self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
+
     @require_torch
     @slow
     def test_large_model_pt(self):
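A note on the assertion style: the test accepts two candidate outputs, and the _PT_2 suffix suggests the randomly initialized tiny model ranks labels differently under different PyTorch versions; nested_simplify rounds the floats first so the comparison is deterministic. A minimal sketch of the helper's effect, assuming the standard transformers.testing_utils.nested_simplify:

from transformers.testing_utils import nested_simplify

# nested_simplify recursively rounds floats inside lists/dicts so pipeline
# outputs can be compared against hand-written expectations.
raw = [{"score": 0.08391234, "label": "no"}, {"score": 0.08372001, "label": "go"}]
assert nested_simplify(raw, decimals=4) == [
    {"score": 0.0839, "label": "no"},
    {"score": 0.0837, "label": "go"},
]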