Update AudioClassificationPipelineTests::test_small_model_pt for PT 2.0.0 (#22023)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-03-08 13:56:47 +01:00 committed by GitHub
parent bbd949970d
commit dfe9a31973
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -88,15 +88,20 @@ class AudioClassificationPipelineTests(unittest.TestCase):
audio = np.ones((8000,))
output = audio_classifier(audio, top_k=4)
self.assertEqual(
nested_simplify(output, decimals=4),
[
{"score": 0.0842, "label": "no"},
{"score": 0.0838, "label": "up"},
{"score": 0.0837, "label": "go"},
{"score": 0.0834, "label": "right"},
],
)
EXPECTED_OUTPUT = [
{"score": 0.0842, "label": "no"},
{"score": 0.0838, "label": "up"},
{"score": 0.0837, "label": "go"},
{"score": 0.0834, "label": "right"},
]
EXPECTED_OUTPUT_PT_2 = [
{"score": 0.0845, "label": "stop"},
{"score": 0.0844, "label": "on"},
{"score": 0.0841, "label": "right"},
{"score": 0.0834, "label": "left"},
]
self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2])
@require_torch
@slow