mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
FillMaskPipeline: support passing top_k on __call__ (#7971)
* FillMaskPipeline: support passing top_k on __call__ Also move from topk to top_k * migrate to new param name in tests * Review from @sgugger
This commit is contained in:
parent
2e5052d4f1
commit
ff65beafa3
@ -1183,7 +1183,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
@add_end_docstrings(
|
||||
PIPELINE_INIT_ARGS,
|
||||
r"""
|
||||
topk (:obj:`int`, defaults to 5): The number of predictions to return.
|
||||
top_k (:obj:`int`, defaults to 5): The number of predictions to return.
|
||||
""",
|
||||
)
|
||||
class FillMaskPipeline(Pipeline):
|
||||
@ -1212,8 +1212,9 @@ class FillMaskPipeline(Pipeline):
|
||||
framework: Optional[str] = None,
|
||||
args_parser: ArgumentHandler = None,
|
||||
device: int = -1,
|
||||
topk=5,
|
||||
top_k=5,
|
||||
task: str = "",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
@ -1228,7 +1229,14 @@ class FillMaskPipeline(Pipeline):
|
||||
|
||||
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
|
||||
|
||||
self.topk = topk
|
||||
if "topk" in kwargs:
|
||||
warnings.warn(
|
||||
"The `topk` argument is deprecated and will be removed in a future version, use `top_k` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.top_k = kwargs.pop("topk")
|
||||
else:
|
||||
self.top_k = top_k
|
||||
|
||||
def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
|
||||
numel = np.prod(masked_index.shape)
|
||||
@ -1245,7 +1253,7 @@ class FillMaskPipeline(Pipeline):
|
||||
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
||||
)
|
||||
|
||||
def __call__(self, *args, targets=None, **kwargs):
|
||||
def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
|
||||
"""
|
||||
Fill the masked token in the text(s) given as inputs.
|
||||
|
||||
@ -1256,6 +1264,8 @@ class FillMaskPipeline(Pipeline):
|
||||
When passed, the model will return the scores for the passed token or tokens rather than the top k
|
||||
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
|
||||
be tokenized and the first resulting token will be used (with a warning).
|
||||
top_k (:obj:`int`, `optional`):
|
||||
When passed, overrides the number of predictions to return.
|
||||
|
||||
Return:
|
||||
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
|
||||
@ -1303,7 +1313,7 @@ class FillMaskPipeline(Pipeline):
|
||||
logits = outputs[i, masked_index.item(), :]
|
||||
probs = tf.nn.softmax(logits)
|
||||
if targets is None:
|
||||
topk = tf.math.top_k(probs, k=self.topk)
|
||||
topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
|
||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||
else:
|
||||
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
|
||||
@ -1319,7 +1329,7 @@ class FillMaskPipeline(Pipeline):
|
||||
logits = outputs[i, masked_index.item(), :]
|
||||
probs = logits.softmax(dim=0)
|
||||
if targets is None:
|
||||
values, predictions = probs.topk(self.topk)
|
||||
values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
|
||||
else:
|
||||
values = probs[..., target_inds]
|
||||
sort_inds = list(reversed(values.argsort(dim=-1)))
|
||||
|
@ -226,7 +226,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||
@ -249,7 +249,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||
@ -298,7 +298,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
topk=2,
|
||||
top_k=2,
|
||||
)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp,
|
||||
@ -326,7 +326,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
||||
]
|
||||
valid_targets = [" Patrick", " Clara"]
|
||||
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
|
||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2)
|
||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
|
||||
self._test_mono_column_pipeline(
|
||||
nlp,
|
||||
valid_inputs,
|
||||
|
Loading…
Reference in New Issue
Block a user