From dfe9a3197364c7f0e2169d7c16c357c9c1311cb9 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 8 Mar 2023 13:56:47 +0100 Subject: [PATCH] Update `AudioClassificationPipelineTests::test_small_model_pt` for PT 2.0.0 (#22023) fix Co-authored-by: ydshieh --- .../test_pipelines_audio_classification.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/pipelines/test_pipelines_audio_classification.py b/tests/pipelines/test_pipelines_audio_classification.py index b0ff5517973..208690396c4 100644 --- a/tests/pipelines/test_pipelines_audio_classification.py +++ b/tests/pipelines/test_pipelines_audio_classification.py @@ -88,15 +88,20 @@ class AudioClassificationPipelineTests(unittest.TestCase): audio = np.ones((8000,)) output = audio_classifier(audio, top_k=4) - self.assertEqual( - nested_simplify(output, decimals=4), - [ - {"score": 0.0842, "label": "no"}, - {"score": 0.0838, "label": "up"}, - {"score": 0.0837, "label": "go"}, - {"score": 0.0834, "label": "right"}, - ], - ) + + EXPECTED_OUTPUT = [ + {"score": 0.0842, "label": "no"}, + {"score": 0.0838, "label": "up"}, + {"score": 0.0837, "label": "go"}, + {"score": 0.0834, "label": "right"}, + ] + EXPECTED_OUTPUT_PT_2 = [ + {"score": 0.0845, "label": "stop"}, + {"score": 0.0844, "label": "on"}, + {"score": 0.0841, "label": "right"}, + {"score": 0.0834, "label": "left"}, + ] + self.assertIn(nested_simplify(output, decimals=4), [EXPECTED_OUTPUT, EXPECTED_OUTPUT_PT_2]) @require_torch @slow