diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index df7db72d6af..f8b742a8952 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -87,6 +87,18 @@ def get_framework(model=None):
     return framework
 
 
+class PipelineException(Exception):
+    """
+    Raised by pipelines when handling __call__
+    """
+
+    def __init__(self, task: str, model: str, reason: str):
+        super().__init__(reason)
+
+        self.task = task
+        self.model = model
+
+
 class ArgumentHandler(ABC):
     """
     Base interface for handling varargs for each Pipeline
@@ -808,6 +820,21 @@ class FillMaskPipeline(Pipeline):
 
         self.topk = topk
 
+    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
+        numel = np.prod(masked_index.shape)
+        if numel > 1:
+            raise PipelineException(
+                "fill-mask",
+                self.model.base_model_prefix,
+                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
+            )
+        elif numel < 1:
+            raise PipelineException(
+                "fill-mask",
+                self.model.base_model_prefix,
+                f"No mask_token ({self.tokenizer.mask_token}) found in the input",
+            )
+
     def __call__(self, *args, **kwargs):
         inputs = self._parse_and_tokenize(*args, **kwargs)
         outputs = self._forward(inputs, return_tensors=True)
@@ -820,15 +847,22 @@ class FillMaskPipeline(Pipeline):
             result = []
 
             if self.framework == "tf":
-                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy().item()
-                logits = outputs[i, masked_index, :]
+                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+
+                # Fill mask pipeline supports only one ${mask_token} per sample
+                self.ensure_exactly_one_mask_token(masked_index)
+
+                logits = outputs[i, masked_index.item(), :]
                 probs = tf.nn.softmax(logits)
                 topk = tf.math.top_k(probs, k=self.topk)
                 values, predictions = topk.values.numpy(), topk.indices.numpy()
             else:
-                masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero().item()
+                masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()
 
-                logits = outputs[i, masked_index, :]
+                # Fill mask pipeline supports only one ${mask_token} per sample
+                self.ensure_exactly_one_mask_token(masked_index.numpy())
+
+                logits = outputs[i, masked_index.item(), :]
                 probs = logits.softmax(dim=0)
                 values, predictions = probs.topk(self.topk)
 
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index cfad071bd29..1b978f5afd9 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -217,9 +217,15 @@ class MonoColumnInputTestCase(unittest.TestCase):
             "My name is <mask>",
             "The largest city in France is <mask>",
         ]
+        invalid_inputs = [
+            "This is <mask> <mask>",  # More than 1 mask_token in the input is not supported
+            "This is",  # No mask_token is not supported
+        ]
         for model_name in FILL_MASK_FINETUNED_MODELS:
             nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
-            self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys, expected_check_keys=["sequence"])
+            self._test_mono_column_pipeline(
+                nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
+            )
 
     @require_tf
     def test_tf_fill_mask(self):
@@ -228,9 +234,15 @@ class MonoColumnInputTestCase(unittest.TestCase):
             "My name is <mask>",
             "The largest city in France is <mask>",
         ]
+        invalid_inputs = [
+            "This is <mask> <mask>",  # More than 1 mask_token in the input is not supported
+            "This is",  # No mask_token is not supported
+        ]
         for model_name in FILL_MASK_FINETUNED_MODELS:
             nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2,)
-            self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys, expected_check_keys=["sequence"])
+            self._test_mono_column_pipeline(
+                nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
+            )
 
     @require_torch
     @slow
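
For context, a minimal sketch of how the new guard surfaces to callers. The checkpoint name "distilroberta-base" is an illustrative assumption (any fill-mask model whose mask token is <mask> would behave the same) and is not part of this diff:

# Minimal usage sketch (not part of the diff); "distilroberta-base" is an
# illustrative fill-mask checkpoint whose mask token is <mask>.
from transformers import pipeline
from transformers.pipelines import PipelineException

nlp = pipeline(task="fill-mask", model="distilroberta-base", topk=2)

print(nlp("The largest city in France is <mask>"))  # exactly one mask token: works

for bad in ("This is", "This is <mask> <mask>"):
    try:
        nlp(bad)  # zero or multiple mask tokens now raise instead of crashing on .item()
    except PipelineException as e:
        # The exception carries the task and model name alongside the reason
        print(e.task, e.model, str(e))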