Fix 2 tests in FillMaskPipelineTests (#27889)

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-12-08 14:55:29 +01:00 committed by GitHub
parent 79e7655906
commit e366937587
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -216,15 +216,24 @@ class FillMaskPipelineTests(unittest.TestCase):
],
)
dummy_str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100
outputs = unmasker(
"My name is <mask>" + "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100,
"My name is <mask>" + dummy_str,
tokenizer_kwargs={"truncation": True},
)
simplified = nested_simplify(outputs, decimals=4)
self.assertEqual(
nested_simplify(outputs, decimals=6),
[{"sequence": x["sequence"][:100]} for x in simplified],
[
{"sequence": "My name is grouped", "score": 2.2e-05, "token": 38015, "token_str": " grouped"},
{"sequence": "My name is accuser", "score": 2.1e-05, "token": 25506, "token_str": " accuser"},
{"sequence": f"My name is,{dummy_str}"[:100]},
{"sequence": f"My name is:,{dummy_str}"[:100]},
],
)
self.assertEqual(
[{k: x[k] for k in x if k != "sequence"} for x in simplified],
[
{"score": 0.2819, "token": 6, "token_str": ","},
{"score": 0.0954, "token": 46686, "token_str": ":,"},
],
)