mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Update AudioClassificationPipelineTests::test_small_model_pt
for PT 2.0.0 (#22023)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
bbd949970d
commit
dfe9a31973
@ -88,15 +88,20 @@ class AudioClassificationPipelineTests(unittest.TestCase):
|
||||
|
||||
audio = np.ones((8000,))
|
||||
output = audio_classifier(audio, top_k=4)
|
||||
self.assertEqual(
|
||||
nested_simplify(output, decimals=4),
|
||||
[
|
||||
{"score": 0.0842, "label": "no"},
|
||||
{"score": 0.0838, "label": "up"},
|
||||
{"score": 0.0837, "label": "go"},
|
||||
{"score": 0.0834, "label": "right"},
|
||||
],
|
||||
)
|
||||
|
||||
EXPECTED_OUTPUT = [
|
||||
{"score": 0.0842, "label": "no"},
|
||||
{"score": 0.0838, "label": "up"},
|
||||
{"score": 0.0837, "label": "go"},
|
||||
{"score": 0.0834, "label": "right"},
|
||||
]
|
||||
EXPECTED_OUTPUT_PT_2 = [
|
||||
{"score": 0.0845, "label": "stop"},
|
||||
{"score": 0.0844, "label": "on"},
|
||||
{"score": 0.0841, "label": "right"},
|
||||
{"score": 0.0834, "label": "left"},
|
||||
]
|
||||
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
|
Loading…
Reference in New Issue
Block a user