diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py
index bca61d144d2..5392db979b6 100644
--- a/src/transformers/pipelines/fill_mask.py
+++ b/src/transformers/pipelines/fill_mask.py
@@ -219,4 +219,7 @@ class FillMaskPipeline(Pipeline):
             - **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
             - **token** (:obj:`str`) -- The predicted token (to replace the masked one).
         """
-        return super().__call__(inputs, **kwargs)
+        outputs = super().__call__(inputs, **kwargs)
+        if isinstance(inputs, list) and len(inputs) == 1:
+            return outputs[0]
+        return outputs
diff --git a/tests/test_pipelines_fill_mask.py b/tests/test_pipelines_fill_mask.py
index 8f65417f472..fb48fe52cdb 100644
--- a/tests/test_pipelines_fill_mask.py
+++ b/tests/test_pipelines_fill_mask.py
@@ -186,6 +186,18 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             ],
         )
 
+        outputs = fill_masker([f"This is a {tokenizer.mask_token}"])
+        self.assertEqual(
+            outputs,
+            [
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+            ],
+        )
+
         outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."])
         self.assertEqual(
             outputs,
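
A minimal usage sketch of the behavior this patch produces (not part of the diff; the tiny checkpoint name is only an illustrative assumption, any fill-mask model works): a single-element list input now returns the flat list of prediction dicts instead of a list containing one list.

    from transformers import pipeline

    # Illustrative checkpoint choice, not mandated by the patch.
    fill_masker = pipeline("fill-mask", model="sshleifer/tiny-distilroberta-base")
    mask = fill_masker.tokenizer.mask_token

    # Single string input: a flat list of top-k prediction dicts (unchanged).
    single = fill_masker(f"This is a {mask}")
    assert isinstance(single[0], dict)

    # Single-element list: with this patch the extra nesting is stripped, so the
    # result is the same flat list of dicts rather than a list of one list.
    wrapped = fill_masker([f"This is a {mask}"])
    assert isinstance(wrapped[0], dict)

    # Multi-element list: still one list of predictions per input sentence.
    batched = fill_masker([f"This is a {mask}", f"Another {mask} great test."])
    assert len(batched) == 2 and isinstance(batched[0], list)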