transformers/tests/test_audio_classification_top_k.py
Sambhav Dixit 0de15c988b
Fix Audio Classification Pipeline top_k Documentation Mismatch and Bug #35736 (#35771)
* added condition for top_k Doc mismatch fix

* initialization of test file for top_k changes

* added test for returning all labels

* added test for few labels

* tests/test_audio_classification_top_k.py

* final fix

* ruff fix

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
2025-02-05 16:25:08 +00:00

61 lines
2.0 KiB
Python

import unittest
import numpy as np
from transformers import pipeline
from transformers.testing_utils import require_torch
@require_torch
class AudioClassificationTopKTest(unittest.TestCase):
    """Regression tests for ``top_k`` handling in the audio-classification pipeline.

    ``top_k=None`` must return every label of the model, and a ``top_k`` larger
    than the model's label count must be capped at that count.
    """

    # NOTE(review): every test downloads a full Hub checkpoint; transformers
    # usually marks such tests with ``@slow`` -- confirm whether running them
    # in the default suite is intended.

    # Length of the dummy input, in samples. It is passed as a raw array, so the
    # pipeline does not receive this rate; presumably it matches the checkpoints'
    # expected input rate -- confirm if the checkpoints change.
    SAMPLING_RATE = 16000

    def _classify_silence(self, model_name, top_k):
        """Build an audio-classification pipeline and run it on an all-zero signal.

        Returns a ``(pipeline, predictions)`` tuple so callers can compare the
        predictions against the model's config.
        """
        classifier = pipeline(
            "audio-classification",
            model=model_name,
            top_k=top_k,
        )
        # Silence is sufficient: these tests check which labels come back and
        # how many, not which label scores highest.
        signal = np.zeros((self.SAMPLING_RATE,), dtype=np.float32)
        return classifier, classifier(signal)

    def _assert_returns_every_label(self, classifier, result, msg):
        """Assert ``result`` holds each of the model's labels exactly once, best first.

        Args:
            classifier: the pipeline that produced ``result``.
            result: list of ``{"score", "label"}`` predictions.
            msg: failure message for the label-count assertion.
        """
        config = classifier.model.config
        self.assertEqual(len(result), config.num_labels, msg)
        # A length check alone would accept duplicated labels; compare identities too.
        self.assertEqual(
            {prediction["label"] for prediction in result},
            set(config.id2label.values()),
            msg,
        )
        scores = [prediction["score"] for prediction in result]
        self.assertEqual(
            scores,
            sorted(scores, reverse=True),
            "Predictions should be ordered by descending score",
        )

    def test_top_k_none_returns_all_labels(self):
        # model with more than 5 labels
        classifier, result = self._classify_silence("superb/wav2vec2-base-superb-ks", top_k=None)
        self._assert_returns_every_label(classifier, result, "Should return all labels when top_k is None")

    def test_top_k_none_with_few_labels(self):
        # model with fewer labels
        classifier, result = self._classify_silence("superb/hubert-base-superb-er", top_k=None)
        self._assert_returns_every_label(classifier, result, "Should handle models with fewer labels correctly")

    def test_top_k_greater_than_labels(self):
        # top_k intentionally far above the label count; the pipeline must cap it.
        classifier, result = self._classify_silence("superb/hubert-base-superb-er", top_k=100)
        self._assert_returns_every_label(classifier, result, "Should cap top_k to number of labels")