* added condition for top_k doc mismatch fix
* initialization of test file for top_k changes
* added test for returning all labels
* added test for few labels
* tests/test_audio_classification_top_k.py
* final fix
* ruff fix

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
parent 694aaa7fbc
commit 0de15c988b
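In short: passing `top_k=None` to the audio-classification pipeline now returns a score for every label instead of silently falling back to the default of 5, and an oversized `top_k` is capped at the model's label count. A minimal sketch of the new call semantics (model name borrowed from the tests below):

```python
from transformers import pipeline

# top_k=None -> one {"score", "label"} dict per label the model knows about
clf = pipeline("audio-classification", model="superb/wav2vec2-base-superb-ks", top_k=None)
# len(clf(some_audio)) == clf.model.config.num_labels
```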
src/transformers/pipelines/audio_classification.py:

```diff
@@ -91,8 +91,11 @@ class AudioClassificationPipeline(Pipeline):
     """
 
     def __init__(self, *args, **kwargs):
-        # Default, might be overriden by the model.config.
-        kwargs["top_k"] = kwargs.get("top_k", 5)
+        # Only set default top_k if explicitly provided
+        if "top_k" in kwargs and kwargs["top_k"] is None:
+            kwargs["top_k"] = None
+        elif "top_k" not in kwargs:
+            kwargs["top_k"] = 5
         super().__init__(*args, **kwargs)
 
         if self.framework != "pt":
@@ -141,12 +144,16 @@ class AudioClassificationPipeline(Pipeline):
         return super().__call__(inputs, **kwargs)
 
     def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
         # No parameters on this pipeline right now
         postprocess_params = {}
-        if top_k is not None:
+
+        # If top_k is None, use all labels
+        if top_k is None:
+            postprocess_params["top_k"] = self.model.config.num_labels
+        else:
             if top_k > self.model.config.num_labels:
                 top_k = self.model.config.num_labels
             postprocess_params["top_k"] = top_k
 
         if function_to_apply is not None:
             if function_to_apply not in ["softmax", "sigmoid", "none"]:
                 raise ValueError(
```
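Taken together, the two hunks reduce to a simple rule for the effective `top_k`. A standalone sketch of that rule (`resolve_top_k` is a hypothetical helper for illustration, not part of the pipeline):

```python
def resolve_top_k(top_k, num_labels):
    # Hypothetical helper mirroring the _sanitize_parameters change above.
    if top_k is None:
        return num_labels          # None -> score every label
    return min(top_k, num_labels)  # oversized requests are capped

assert resolve_top_k(None, 12) == 12   # all labels
assert resolve_top_k(100, 4) == 4      # capped at num_labels
assert resolve_top_k(3, 12) == 3       # honored as-is
```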
tests/test_audio_classification_top_k.py (new file, 60 lines)
@@ -0,0 +1,60 @@
```python
import unittest

import numpy as np

from transformers import pipeline
from transformers.testing_utils import require_torch


@require_torch
class AudioClassificationTopKTest(unittest.TestCase):
    def test_top_k_none_returns_all_labels(self):
        model_name = "superb/wav2vec2-base-superb-ks"  # model with more than 5 labels
        classification_pipeline = pipeline(
            "audio-classification",
            model=model_name,
            top_k=None,
        )

        # Create dummy input
        sampling_rate = 16000
        signal = np.zeros((sampling_rate,), dtype=np.float32)

        result = classification_pipeline(signal)
        num_labels = classification_pipeline.model.config.num_labels

        self.assertEqual(len(result), num_labels, "Should return all labels when top_k is None")

    def test_top_k_none_with_few_labels(self):
        model_name = "superb/hubert-base-superb-er"  # model with fewer labels
        classification_pipeline = pipeline(
            "audio-classification",
            model=model_name,
            top_k=None,
        )

        # Create dummy input
        sampling_rate = 16000
        signal = np.zeros((sampling_rate,), dtype=np.float32)

        result = classification_pipeline(signal)
        num_labels = classification_pipeline.model.config.num_labels

        self.assertEqual(len(result), num_labels, "Should handle models with fewer labels correctly")

    def test_top_k_greater_than_labels(self):
        model_name = "superb/hubert-base-superb-er"
        classification_pipeline = pipeline(
            "audio-classification",
            model=model_name,
            top_k=100,  # intentionally large number
        )

        # Create dummy input
        sampling_rate = 16000
        signal = np.zeros((sampling_rate,), dtype=np.float32)

        result = classification_pipeline(signal)
        num_labels = classification_pipeline.model.config.num_labels

        self.assertEqual(len(result), num_labels, "Should cap top_k to number of labels")
```
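The tests live under the standard tests/ layout, so they should be runnable directly with `python -m pytest tests/test_audio_classification_top_k.py -v`; a torch install is required because of the `@require_torch` marker, and the first run will download the two small superb checkpoints.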