diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index 16767b342c8..c3c474be8db 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ b/tests/pipelines/test_pipelines_token_classification.py @@ -328,8 +328,10 @@ class TokenClassificationPipelineTests(unittest.TestCase): self.assertEqual( nested_simplify(output), [ - {"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11}, - {"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29}, + [ + {"entity_group": "PER", "score": ANY(float), "word": "Sarah", "start": 6, "end": 11}, + {"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29}, + ] ], ) @@ -349,8 +351,8 @@ class TokenClassificationPipelineTests(unittest.TestCase): {"entity_group": "LOC", "score": ANY(float), "word": "New York", "start": 21, "end": 29}, ], [ - {"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 12, "end": 20}, - {"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 36, "end": 42}, + {"entity_group": "PER", "score": ANY(float), "word": "Wolfgang", "start": 11, "end": 19}, + {"entity_group": "LOC", "score": ANY(float), "word": "Berlin", "start": 34, "end": 40}, ], ], )