diff --git a/tests/pipelines/test_pipelines_fill_mask.py b/tests/pipelines/test_pipelines_fill_mask.py index c85797fbb6e..571b320d617 100644 --- a/tests/pipelines/test_pipelines_fill_mask.py +++ b/tests/pipelines/test_pipelines_fill_mask.py @@ -216,15 +216,24 @@ class FillMaskPipelineTests(unittest.TestCase): ], ) + dummy_str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100 outputs = unmasker( - "My name is " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100, + "My name is " + dummy_str, tokenizer_kwargs={"truncation": True}, ) + simplified = nested_simplify(outputs, decimals=4) self.assertEqual( - nested_simplify(outputs, decimals=6), + [{"sequence": x["sequence"][:100]} for x in simplified], [ - {"sequence": "My name is grouped", "score": 2.2e-05, "token": 38015, "token_str": " grouped"}, - {"sequence": "My name is accuser", "score": 2.1e-05, "token": 25506, "token_str": " accuser"}, + {"sequence": f"My name is,{dummy_str}"[:100]}, + {"sequence": f"My name is:,{dummy_str}"[:100]}, + ], + ) + self.assertEqual( + [{k: x[k] for k in x if k != "sequence"} for x in simplified], + [ + {"score": 0.2819, "token": 6, "token_str": ","}, + {"score": 0.0954, "token": 46686, "token_str": ":,"}, ], )