diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py index 86ce54b3e96..a34b67859c9 100644 --- a/src/transformers/pipelines/fill_mask.py +++ b/src/transformers/pipelines/fill_mask.py @@ -98,9 +98,9 @@ class FillMaskPipeline(Pipeline): args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of prompts) with masked tokens. targets (:obj:`str` or :obj:`List[str]`, `optional`): - When passed, the model will return the scores for the passed token or tokens rather than the top k - predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be - tokenized and the first resulting token will be used (with a warning). + When passed, the model will limit the scores to the passed targets instead of looking up in the whole + vocab. If the provided targets are not in the model vocab, they will be tokenized and the first + resulting token will be used (with a warning, and that might be slower). top_k (:obj:`int`, `optional`): When passed, overrides the number of predictions to return. @@ -115,25 +115,56 @@ class FillMaskPipeline(Pipeline): inputs = self._parse_and_tokenize(*args, **kwargs) outputs = self._forward(inputs, return_tensors=True) + # top_k must be defined + if top_k is None: + top_k = self.top_k + results = [] batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0) if targets is not None: - if len(targets) == 0 or len(targets[0]) == 0: - raise ValueError("At least one target must be provided when passed.") if isinstance(targets, str): targets = [targets] - targets_proc = [] + try: + vocab = self.tokenizer.get_vocab() + except Exception: + vocab = {} + target_ids = [] for target in targets: - target_enc = self.tokenizer.tokenize(target) - if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token: + id_ = vocab.get(target, None) + if id_ is None: + input_ids = self.tokenizer( + target, + add_special_tokens=False, + return_attention_mask=False, + return_token_type_ids=False, + max_length=1, + truncation=True, + )["input_ids"] + if len(input_ids) == 0: + logger.warning( + f"The specified target token `{target}` does not exist in the model vocabulary. " + f"We cannot replace it with anything meaningful, ignoring it" + ) + continue + id_ = input_ids[0] + # XXX: If users encounter this pass + # it becomes pretty slow, so let's make sure + # The warning enables them to fix the input to + # get faster performance. logger.warning( f"The specified target token `{target}` does not exist in the model vocabulary. " - f"Replacing with `{target_enc[0]}`." + f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`." ) - targets_proc.append(target_enc[0]) - target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc)) + target_ids.append(id_) + target_ids = list(set(target_ids)) + if len(target_ids) == 0: + raise ValueError("At least one target must be provided when passed.") + target_ids = np.array(target_ids) + # Cap top_k if there are targets + if top_k > target_ids.shape[0]: + top_k = target_ids.shape[0] for i in range(batch_size): input_ids = inputs["input_ids"][i] @@ -147,14 +178,11 @@ class FillMaskPipeline(Pipeline): logits = outputs[i, masked_index.item(), :] probs = tf.nn.softmax(logits) - if targets is None: - topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k) - values, predictions = topk.values.numpy(), topk.indices.numpy() - else: - values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1))) - sort_inds = tf.reverse(tf.argsort(values), [0]) - values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy() - predictions = target_inds[sort_inds.numpy()] + if targets is not None: + probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1))) + + topk = tf.math.top_k(probs, k=top_k) + values, predictions = topk.values.numpy(), topk.indices.numpy() else: masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False) @@ -163,13 +191,11 @@ class FillMaskPipeline(Pipeline): logits = outputs[i, masked_index.item(), :] probs = logits.softmax(dim=0) - if targets is None: - values, predictions = probs.topk(top_k if top_k is not None else self.top_k) - else: - values = probs[..., target_inds] - sort_inds = list(reversed(values.argsort(dim=-1))) - values = values[..., sort_inds] - predictions = target_inds[sort_inds] + + if targets is not None: + probs = probs[..., target_ids] + + values, predictions = probs.topk(top_k) for v, p in zip(values.tolist(), predictions.tolist()): tokens = input_ids.numpy() diff --git a/tests/test_pipelines_fill_mask.py b/tests/test_pipelines_fill_mask.py index 8865bae0c8a..5de8b0b1f96 100644 --- a/tests/test_pipelines_fill_mask.py +++ b/tests/test_pipelines_fill_mask.py @@ -78,7 +78,8 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): @require_torch def test_torch_fill_mask_with_targets(self): valid_inputs = ["My name is "] - valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]] + # ' Sam' will yield a warning but work + valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]] invalid_targets = [[], [""], ""] for model_name in self.small_models: unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt") @@ -89,10 +90,34 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): for targets in invalid_targets: self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets) + @require_torch + def test_torch_fill_mask_with_targets_and_topk(self): + model_name = self.small_models[0] + unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt") + targets = [" Teven", "ĠPatrick", "ĠClara"] + top_k = 2 + outputs = unmasker("My name is ", targets=targets, top_k=top_k) + + self.assertEqual(len(outputs), 2) + + @require_torch + def test_torch_fill_mask_with_duplicate_targets_and_topk(self): + model_name = self.small_models[0] + unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt") + # String duplicates + id duplicates + targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"] + top_k = 10 + outputs = unmasker("My name is ", targets=targets, top_k=top_k) + + # The target list contains duplicates, so we can't output more + # than them + self.assertEqual(len(outputs), 3) + @require_tf def test_tf_fill_mask_with_targets(self): valid_inputs = ["My name is "] - valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]] + # ' Sam' will yield a warning but work + valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]] invalid_targets = [[], [""], ""] for model_name in self.small_models: unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf") @@ -111,7 +136,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): "My name is ", "The largest city in France is ", ] - valid_targets = [" Patrick", " Clara"] + valid_targets = ["ĠPatrick", "ĠClara"] for model_name in self.large_models: unmasker = pipeline( task="fill-mask", @@ -184,7 +209,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): "My name is ", "The largest city in France is ", ] - valid_targets = [" Patrick", " Clara"] + valid_targets = ["ĠPatrick", "ĠClara"] for model_name in self.large_models: unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)