mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add tokenizer kwargs to fill mask pipeline. (#26234)
* add tokenizer kwarg inputs
* Adding tokenizer_kwargs to _sanitize_parameters
* Add truncation=True example to tests
* Update test_pipelines_fill_mask.py
* Update test_pipelines_fill_mask.py
* make fix-copies and make style
* Update fill_mask.py Replace single tick with double
* make fix-copies
* Style

---------

Co-authored-by: Lysandre <lysandre@huggingface.co>
This commit is contained in:
parent
df6a855e7b
commit
b5ca8fcd20
@ -129,7 +129,7 @@ class TFBertTokenizer(tf.keras.layers.Layer):
|
||||
pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id
|
||||
|
||||
vocab = tokenizer.get_vocab()
|
||||
vocab = sorted([(wordpiece, idx) for wordpiece, idx in vocab.items()], key=lambda x: x[1])
|
||||
vocab = sorted(vocab.items(), key=lambda x: x[1])
|
||||
vocab_list = [entry[0] for entry in vocab]
|
||||
return cls(
|
||||
vocab_list=vocab_list,
|
||||
|
@ -61,7 +61,28 @@ class FillMaskPipeline(Pipeline):
|
||||
masks. The returned values are raw model output, and correspond to disjoint probabilities where one might expect
|
||||
joint probabilities (See [discussion](https://github.com/huggingface/transformers/pull/10222)).
|
||||
|
||||
</Tip>"""
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
This pipeline now supports tokenizer_kwargs. For example try:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> fill_masker = pipeline(model="bert-base-uncased")
|
||||
>>> tokenizer_kwargs = {"truncation": True}
|
||||
>>> fill_masker(
|
||||
... "This is a simple [MASK]. " + "...with a large amount of repeated text appended. " * 100,
|
||||
... tokenizer_kwargs=tokenizer_kwargs,
|
||||
... )
|
||||
```
|
||||
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
|
||||
if self.framework == "tf":
|
||||
@ -90,10 +111,15 @@ class FillMaskPipeline(Pipeline):
|
||||
for input_ids in model_inputs["input_ids"]:
|
||||
self._ensure_exactly_one_mask_token(input_ids)
|
||||
|
||||
def preprocess(self, inputs, return_tensors=None, **preprocess_parameters) -> Dict[str, GenericTensor]:
|
||||
def preprocess(
|
||||
self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters
|
||||
) -> Dict[str, GenericTensor]:
|
||||
if return_tensors is None:
|
||||
return_tensors = self.framework
|
||||
model_inputs = self.tokenizer(inputs, return_tensors=return_tensors)
|
||||
if tokenizer_kwargs is None:
|
||||
tokenizer_kwargs = {}
|
||||
|
||||
model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
|
||||
self.ensure_exactly_one_mask_token(model_inputs)
|
||||
return model_inputs
|
||||
|
||||
@ -198,7 +224,12 @@ class FillMaskPipeline(Pipeline):
|
||||
target_ids = np.array(target_ids)
|
||||
return target_ids
|
||||
|
||||
def _sanitize_parameters(self, top_k=None, targets=None):
|
||||
def _sanitize_parameters(self, top_k=None, targets=None, tokenizer_kwargs=None):
|
||||
preprocess_params = {}
|
||||
|
||||
if tokenizer_kwargs is not None:
|
||||
preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
|
||||
|
||||
postprocess_params = {}
|
||||
|
||||
if targets is not None:
|
||||
@ -212,7 +243,7 @@ class FillMaskPipeline(Pipeline):
|
||||
raise PipelineException(
|
||||
"fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
|
||||
)
|
||||
return {}, {}, postprocess_params
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def __call__(self, inputs, *args, **kwargs):
|
||||
"""
|
||||
|
File diff suppressed because one or more lines are too long
@ -211,6 +211,18 @@ class FillMaskPipelineTests(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
outputs = unmasker(
|
||||
"My name is <mask>" + "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100,
|
||||
tokenizer_kwargs={"truncation": True},
|
||||
)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs, decimals=6),
|
||||
[
|
||||
{"sequence": "My name is grouped", "score": 2.2e-05, "token": 38015, "token_str": " grouped"},
|
||||
{"sequence": "My name is accuser", "score": 2.1e-05, "token": 25506, "token_str": " accuser"},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_model_no_pad_pt(self):
|
||||
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="pt")
|
||||
|
Loading…
Reference in New Issue
Block a user