FillMaskPipeline raises PipelineException when the input contains != 1 mask_token (#5389)

* Added PipelineException

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fill-mask pipeline raises an exception when more than one mask_token is detected.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Put everything in a function.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added tests for the fill-mask pipeline when the input has != 1 mask_token

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix numel() computation for TF

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Addressing PR comments.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove function typing to avoid importing a specific framework.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Quality.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Retry typing with @julien-c's tip.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Quality².

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Simplify fill-mask mask_token checking.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Trigger CI
Funtowicz Morgan 2020-07-01 17:27:47 +02:00 committed by GitHub
parent 6c55e9fc32
commit 608d5a7c44
2 changed files with 52 additions and 6 deletions
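For context, a minimal usage sketch (not part of the diff) of how the new check surfaces to callers. The model name distilroberta-base and the import path transformers.pipelines are assumptions based on where the class is added in the hunks below:

    from transformers import pipeline
    from transformers.pipelines import PipelineException

    # Hypothetical example: the model name is an assumption, not part of this commit.
    nlp = pipeline("fill-mask", model="distilroberta-base", topk=2)

    try:
        nlp("Paris is the <mask> of <mask>.")  # two mask tokens -> rejected
    except PipelineException as exc:
        # The exception carries the task, the model prefix, and the reason.
        print(exc.task, exc.model, str(exc))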

View File

@@ -87,6 +87,18 @@ def get_framework(model=None):
     return framework
 
 
+class PipelineException(Exception):
+    """
+    Raised by pipelines when handling __call__
+    """
+
+    def __init__(self, task: str, model: str, reason: str):
+        super().__init__(reason)
+
+        self.task = task
+        self.model = model
+
+
 class ArgumentHandler(ABC):
     """
     Base interface for handling varargs for each Pipeline
@@ -808,6 +820,21 @@ class FillMaskPipeline(Pipeline):
 
         self.topk = topk
 
+    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
+        numel = np.prod(masked_index.shape)
+        if numel > 1:
+            raise PipelineException(
+                "fill-mask",
+                self.model.base_model_prefix,
+                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
+            )
+        elif numel < 1:
+            raise PipelineException(
+                "fill-mask",
+                self.model.base_model_prefix,
+                f"No mask_token ({self.tokenizer.mask_token}) found on the input",
+            )
+
     def __call__(self, *args, **kwargs):
         inputs = self._parse_and_tokenize(*args, **kwargs)
         outputs = self._forward(inputs, return_tensors=True)
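As an aside, a small NumPy sketch (not part of the diff) of why np.prod(masked_index.shape) counts the mask occurrences: for a 1-D per-sample input_ids, both tf.where(...) and torch's nonzero() return indices of shape (n_matches, 1). The token id 103 below is a stand-in for mask_token_id and is an assumption:

    import numpy as np

    # Hypothetical 1-D sample with two mask tokens (id 103 is an assumption).
    input_ids = np.array([101, 2023, 103, 103, 102])

    # Same shape as tf.where(...).numpy() / nonzero().numpy(): (n_matches, 1)
    masked_index = np.argwhere(input_ids == 103)

    numel = np.prod(masked_index.shape)  # 2 * 1 == 2 -> "more than one mask_token" branch
    print(masked_index.shape, numel)     # (2, 1) 2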
@@ -820,15 +847,22 @@ class FillMaskPipeline(Pipeline):
             result = []
 
             if self.framework == "tf":
-                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy().item()
-                logits = outputs[i, masked_index, :]
+                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+
+                # Fill mask pipeline supports only one ${mask_token} per sample
+                self.ensure_exactly_one_mask_token(masked_index)
+
+                logits = outputs[i, masked_index.item(), :]
                 probs = tf.nn.softmax(logits)
                 topk = tf.math.top_k(probs, k=self.topk)
                 values, predictions = topk.values.numpy(), topk.indices.numpy()
             else:
-                masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero().item()
-                logits = outputs[i, masked_index, :]
+                masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()
+
+                # Fill mask pipeline supports only one ${mask_token} per sample
+                self.ensure_exactly_one_mask_token(masked_index.numpy())
+
+                logits = outputs[i, masked_index.item(), :]
                 probs = logits.softmax(dim=0)
                 values, predictions = probs.topk(self.topk)

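A related PyTorch-only sketch (hypothetical token ids, not part of the diff) of why .item() is now called only after the validation: it succeeds only when the index tensor holds a single element, which is exactly what ensure_exactly_one_mask_token guarantees:

    import numpy as np
    import torch

    # Hypothetical 1-D sample with a single mask token at position 2 (id 103 is an assumption).
    input_ids = torch.tensor([101, 2023, 103, 102])

    masked_index = (input_ids == 103).nonzero()      # tensor([[2]]), shape (1, 1)
    assert np.prod(masked_index.numpy().shape) == 1  # passes the exactly-one-mask check
    print(masked_index.item())                       # 2 -- .item() requires exactly one element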
View File

@@ -217,9 +217,15 @@ class MonoColumnInputTestCase(unittest.TestCase):
             "My name is <mask>",
             "The largest city in France is <mask>",
         ]
+        invalid_inputs = [
+            "This is <mask> <mask>"  # More than 1 mask_token in the input is not supported
+            "This is"  # No mask_token is not supported
+        ]
         for model_name in FILL_MASK_FINETUNED_MODELS:
             nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
-            self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys, expected_check_keys=["sequence"])
+            self._test_mono_column_pipeline(
+                nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
+            )
 
     @require_tf
     def test_tf_fill_mask(self):
@@ -228,9 +234,15 @@
             "My name is <mask>",
             "The largest city in France is <mask>",
         ]
+        invalid_inputs = [
+            "This is <mask> <mask>"  # More than 1 mask_token in the input is not supported
+            "This is"  # No mask_token is not supported
+        ]
         for model_name in FILL_MASK_FINETUNED_MODELS:
             nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2,)
-            self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys, expected_check_keys=["sequence"])
+            self._test_mono_column_pipeline(
+                nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
+            )
 
     @require_torch
     @slow
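For reference, a standalone sketch of the behaviour the updated tests are meant to cover, written without the _test_mono_column_pipeline helper (whose internals are not shown in this diff). The model name and the test structure are assumptions:

    import unittest

    from transformers import pipeline
    from transformers.pipelines import PipelineException


    class FillMaskValidationSketch(unittest.TestCase):
        # Hypothetical standalone test; mirrors the invalid_inputs added above.
        def test_invalid_inputs_raise(self):
            nlp = pipeline("fill-mask", model="distilroberta-base", topk=2)
            with self.assertRaises(PipelineException):
                nlp("This is <mask> <mask>")  # more than one mask_token
            with self.assertRaises(PipelineException):
                nlp("This is")  # no mask_token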