Fix Audio Classification Pipeline top_k Documentation Mismatch and Bug #35736 (#35771)

* added condition for top_k Doc mismatch fix

* initialization of test file for top_k changes

* added test for returning all labels

* added test for few labels

* tests/test_audio_classification_top_k.py

* final fix

* ruff fix

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
This commit is contained in:
Sambhav Dixit 2025-02-05 21:55:08 +05:30 committed by GitHub
parent 694aaa7fbc
commit 0de15c988b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 71 additions and 4 deletions

View File

@ -91,8 +91,11 @@ class AudioClassificationPipeline(Pipeline):
"""
def __init__(self, *args, **kwargs):
# Default, might be overriden by the model.config.
kwargs["top_k"] = kwargs.get("top_k", 5)
# Only set default top_k if explicitly provided
if "top_k" in kwargs and kwargs["top_k"] is None:
kwargs["top_k"] = None
elif "top_k" not in kwargs:
kwargs["top_k"] = 5
super().__init__(*args, **kwargs)
if self.framework != "pt":
@ -141,12 +144,16 @@ class AudioClassificationPipeline(Pipeline):
return super().__call__(inputs, **kwargs)
def _sanitize_parameters(self, top_k=None, function_to_apply=None, **kwargs):
# No parameters on this pipeline right now
postprocess_params = {}
if top_k is not None:
# If top_k is None, use all labels
if top_k is None:
postprocess_params["top_k"] = self.model.config.num_labels
else:
if top_k > self.model.config.num_labels:
top_k = self.model.config.num_labels
postprocess_params["top_k"] = top_k
if function_to_apply is not None:
if function_to_apply not in ["softmax", "sigmoid", "none"]:
raise ValueError(

View File

@ -0,0 +1,60 @@
import unittest
import numpy as np
from transformers import pipeline
from transformers.testing_utils import require_torch
@require_torch
class AudioClassificationTopKTest(unittest.TestCase):
    """Regression tests for the audio-classification pipeline's ``top_k`` handling.

    Covers issue #35736: ``top_k=None`` must return a score for every label,
    and a ``top_k`` larger than the model's label count must be capped to
    ``num_labels``.
    """

    @staticmethod
    def _silence():
        # One second of 16 kHz silence is enough to drive the pipeline.
        sampling_rate = 16000
        return np.zeros((sampling_rate,), dtype=np.float32)

    def _classify_silence(self, model_name, top_k):
        """Build an audio-classification pipeline and run it on dummy audio.

        Returns the pipeline (so callers can inspect its config) together
        with the classification result for the silent signal.
        """
        audio_pipeline = pipeline(
            "audio-classification",
            model=model_name,
            top_k=top_k,
        )
        return audio_pipeline, audio_pipeline(self._silence())

    def test_top_k_none_returns_all_labels(self):
        """``top_k=None`` on a model with more than 5 labels yields all of them."""
        audio_pipeline, result = self._classify_silence("superb/wav2vec2-base-superb-ks", None)
        num_labels = audio_pipeline.model.config.num_labels
        self.assertEqual(len(result), num_labels, "Should return all labels when top_k is None")

    def test_top_k_none_with_few_labels(self):
        """``top_k=None`` also behaves correctly on a model with few labels."""
        audio_pipeline, result = self._classify_silence("superb/hubert-base-superb-er", None)
        num_labels = audio_pipeline.model.config.num_labels
        self.assertEqual(len(result), num_labels, "Should handle models with fewer labels correctly")

    def test_top_k_greater_than_labels(self):
        """A ``top_k`` larger than the label count is capped to ``num_labels``."""
        audio_pipeline, result = self._classify_silence("superb/hubert-base-superb-er", 100)
        num_labels = audio_pipeline.model.config.num_labels
        self.assertEqual(len(result), num_labels, "Should cap top_k to number of labels")