FillMaskPipeline: support passing top_k on __call__ (#7971)

* FillMaskPipeline: support passing top_k on __call__

Also move from topk to top_k

* migrate to new param name in tests

* Review from @sgugger
This commit is contained in:
Julien Chaumond 2020-10-22 18:54:25 +02:00 committed by GitHub
parent 2e5052d4f1
commit ff65beafa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 10 deletions

View File

@ -1183,7 +1183,7 @@ class ZeroShotClassificationPipeline(Pipeline):
@add_end_docstrings(
PIPELINE_INIT_ARGS,
r"""
topk (:obj:`int`, defaults to 5): The number of predictions to return.
top_k (:obj:`int`, defaults to 5): The number of predictions to return.
""",
)
class FillMaskPipeline(Pipeline):
@ -1212,8 +1212,9 @@ class FillMaskPipeline(Pipeline):
framework: Optional[str] = None,
args_parser: ArgumentHandler = None,
device: int = -1,
topk=5,
top_k=5,
task: str = "",
**kwargs
):
super().__init__(
model=model,
@ -1228,7 +1229,14 @@ class FillMaskPipeline(Pipeline):
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
self.topk = topk
if "topk" in kwargs:
warnings.warn(
"The `topk` argument is deprecated and will be removed in a future version, use `top_k` instead.",
FutureWarning,
)
self.top_k = kwargs.pop("topk")
else:
self.top_k = top_k
def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
numel = np.prod(masked_index.shape)
@ -1245,7 +1253,7 @@ class FillMaskPipeline(Pipeline):
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
)
def __call__(self, *args, targets=None, **kwargs):
def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
"""
Fill the masked token in the text(s) given as inputs.
@ -1256,6 +1264,8 @@ class FillMaskPipeline(Pipeline):
When passed, the model will return the scores for the passed token or tokens rather than the top k
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
be tokenized and the first resulting token will be used (with a warning).
top_k (:obj:`int`, `optional`):
When passed, overrides the number of predictions to return.
Return:
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
@ -1303,7 +1313,7 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :]
probs = tf.nn.softmax(logits)
if targets is None:
topk = tf.math.top_k(probs, k=self.topk)
topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
@ -1319,7 +1329,7 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :]
probs = logits.softmax(dim=0)
if targets is None:
values, predictions = probs.topk(self.topk)
values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
else:
values = probs[..., target_inds]
sort_inds = list(reversed(values.argsort(dim=-1)))

View File

@ -226,7 +226,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name,
tokenizer=model_name,
framework="pt",
topk=2,
top_k=2,
)
self._test_mono_column_pipeline(
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
@ -249,7 +249,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name,
tokenizer=model_name,
framework="tf",
topk=2,
top_k=2,
)
self._test_mono_column_pipeline(
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
@ -298,7 +298,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name,
tokenizer=model_name,
framework="pt",
topk=2,
top_k=2,
)
self._test_mono_column_pipeline(
nlp,
@ -326,7 +326,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
]
valid_targets = [" Patrick", " Clara"]
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2)
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
self._test_mono_column_pipeline(
nlp,
valid_inputs,