mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Raises PipelineException on FillMaskPipeline when there are != 1 mask_token in the input (#5389)
* Added PipelineException Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * fill-mask pipeline raises exception when more than one mask_token detected. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Put everything in a function. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added tests on pipeline fill-mask when input has != 1 mask_token Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix numel() computation for TF Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Addressing PR comments. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Remove function typing to avoid import on specific framework. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Quality. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Retry typing with @julien-c tip. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Quality². Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Simplify fill-mask mask_token checking. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Trigger CI
This commit is contained in:
parent
6c55e9fc32
commit
608d5a7c44
@ -87,6 +87,18 @@ def get_framework(model=None):
|
||||
return framework
|
||||
|
||||
|
||||
class PipelineException(Exception):
|
||||
"""
|
||||
Raised by pipelines when handling __call__
|
||||
"""
|
||||
|
||||
def __init__(self, task: str, model: str, reason: str):
|
||||
super().__init__(reason)
|
||||
|
||||
self.task = task
|
||||
self.model = model
|
||||
|
||||
|
||||
class ArgumentHandler(ABC):
|
||||
"""
|
||||
Base interface for handling varargs for each Pipeline
|
||||
@ -808,6 +820,21 @@ class FillMaskPipeline(Pipeline):
|
||||
|
||||
self.topk = topk
|
||||
|
||||
def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
|
||||
numel = np.prod(masked_index.shape)
|
||||
if numel > 1:
|
||||
raise PipelineException(
|
||||
"fill-mask",
|
||||
self.model.base_model_prefix,
|
||||
f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
|
||||
)
|
||||
elif numel < 1:
|
||||
raise PipelineException(
|
||||
"fill-mask",
|
||||
self.model.base_model_prefix,
|
||||
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
inputs = self._parse_and_tokenize(*args, **kwargs)
|
||||
outputs = self._forward(inputs, return_tensors=True)
|
||||
@ -820,15 +847,22 @@ class FillMaskPipeline(Pipeline):
|
||||
result = []
|
||||
|
||||
if self.framework == "tf":
|
||||
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy().item()
|
||||
logits = outputs[i, masked_index, :]
|
||||
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
|
||||
|
||||
# Fill mask pipeline supports only one ${mask_token} per sample
|
||||
self.ensure_exactly_one_mask_token(masked_index)
|
||||
|
||||
logits = outputs[i, masked_index.item(), :]
|
||||
probs = tf.nn.softmax(logits)
|
||||
topk = tf.math.top_k(probs, k=self.topk)
|
||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||
else:
|
||||
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero().item()
|
||||
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()
|
||||
|
||||
logits = outputs[i, masked_index, :]
|
||||
# Fill mask pipeline supports only one ${mask_token} per sample
|
||||
self.ensure_exactly_one_mask_token(masked_index.numpy())
|
||||
|
||||
logits = outputs[i, masked_index.item(), :]
|
||||
probs = logits.softmax(dim=0)
|
||||
values, predictions = probs.topk(self.topk)
|
||||
|
||||
|
@ -217,9 +217,15 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
"My name is <mask>",
|
||||
"The largest city in France is <mask>",
|
||||
]
|
||||
invalid_inputs = [
|
||||
"This is <mask> <mask>" # More than 1 mask_token in the input is not supported
|
||||
"This is" # No mask_token is not supported
|
||||
]
|
||||
for model_name in FILL_MASK_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
|
||||
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys, expected_check_keys=["sequence"])
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_tf_fill_mask(self):
|
||||
@ -228,9 +234,15 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
"My name is <mask>",
|
||||
"The largest city in France is <mask>",
|
||||
]
|
||||
invalid_inputs = [
|
||||
"This is <mask> <mask>" # More than 1 mask_token in the input is not supported
|
||||
"This is" # No mask_token is not supported
|
||||
]
|
||||
for model_name in FILL_MASK_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2,)
|
||||
self._test_mono_column_pipeline(nlp, valid_inputs, mandatory_keys, expected_check_keys=["sequence"])
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
|
Loading…
Reference in New Issue
Block a user