Adding support for multiple mask tokens. (#14716)
* Adding support for multiple mask tokens.
  Original implementation: https://github.com/huggingface/transformers/pull/10222
  Co-authored-by: njafer <naveen.jafer@oracle.com>
* In order to accommodate optionally multimodal models like Perceiver, we add information to the tasks to specify, for the tasks where we know for sure, whether the tokenizer/feature_extractor is needed or not.
* Adding info in the documentation about multiple masks; marked as experimental.
* Add a copy() to prevent overriding the same tensor over and over.
* Fixup.
* Adding a small test for multiple masks with real values.

Co-authored-by: njafer <naveen.jafer@oracle.com>
This commit is contained in:
parent: 2a606f9974
commit: e7ed7ffdcb
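
For orientation before the diff: a minimal usage sketch of what this change enables, assuming the pipeline's default distilroberta-base checkpoint (outputs shown in comments are illustrative, not real model values).

    from transformers import pipeline

    unmasker = pipeline("fill-mask", model="distilroberta-base")

    # Single mask: unchanged behavior, a flat list of top_k candidate dicts.
    single = unmasker("Paris is the <mask> of France.", top_k=2)

    # Multiple masks (experimental): one list of candidates per mask position.
    multi = unmasker("The <mask> sat on the <mask>.", top_k=2)
    for mask_candidates in multi:
        # Each candidate dict has "score", "token", "token_str" and "sequence";
        # scores are per-position (disjoint), not joint probabilities.
        print(mask_candidates[0]["token_str"], mask_candidates[0]["score"])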
src/transformers/pipelines/__init__.py

@@ -125,18 +125,21 @@ SUPPORTED_TASKS = {
         "tf": (),
         "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
+        "type": "audio",
     },
     "automatic-speech-recognition": {
         "impl": AutomaticSpeechRecognitionPipeline,
         "tf": (),
         "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
+        "type": "multimodal",
     },
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
         "tf": (TFAutoModel,) if is_tf_available() else (),
         "pt": (AutoModel,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
+        "type": "multimodal",
     },
     "text-classification": {
         "impl": TextClassificationPipeline,
@@ -148,6 +151,7 @@ SUPPORTED_TASKS = {
                 "tf": "distilbert-base-uncased-finetuned-sst-2-english",
             },
         },
+        "type": "text",
     },
     "token-classification": {
         "impl": TokenClassificationPipeline,
@@ -159,6 +163,7 @@ SUPPORTED_TASKS = {
                 "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
             },
         },
+        "type": "text",
     },
     "question-answering": {
         "impl": QuestionAnsweringPipeline,
@@ -167,6 +172,7 @@ SUPPORTED_TASKS = {
         "default": {
             "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
         },
+        "type": "text",
     },
     "table-question-answering": {
         "impl": TableQuestionAnsweringPipeline,
@@ -179,18 +185,21 @@ SUPPORTED_TASKS = {
                 "tf": "google/tapas-base-finetuned-wtq",
             },
         },
+        "type": "text",
     },
     "fill-mask": {
         "impl": FillMaskPipeline,
         "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
         "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
+        "type": "text",
     },
     "summarization": {
         "impl": SummarizationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
+        "type": "text",
     },
     # This task is a special case as it's parametrized by SRC, TGT languages.
     "translation": {
@@ -202,18 +211,21 @@ SUPPORTED_TASKS = {
             ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
             ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
         },
+        "type": "text",
     },
     "text2text-generation": {
         "impl": Text2TextGenerationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
+        "type": "text",
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
         "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
         "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
+        "type": "text",
     },
     "zero-shot-classification": {
         "impl": ZeroShotClassificationPipeline,
@@ -224,33 +236,48 @@ SUPPORTED_TASKS = {
             "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
             "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
         },
+        "type": "text",
     },
     "conversational": {
         "impl": ConversationalPipeline,
         "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
         "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
+        "type": "text",
     },
     "image-classification": {
         "impl": ImageClassificationPipeline,
         "tf": (),
         "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "google/vit-base-patch16-224"}},
+        "type": "image",
     },
     "image-segmentation": {
         "impl": ImageSegmentationPipeline,
         "tf": (),
         "pt": (AutoModelForImageSegmentation,) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
+        "type": "image",
     },
     "object-detection": {
         "impl": ObjectDetectionPipeline,
         "tf": (),
         "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/detr-resnet-50"}},
+        "type": "image",
     },
 }

+NO_FEATURE_EXTRACTOR_TASKS = set()
+NO_TOKENIZER_TASKS = set()
+for task, values in SUPPORTED_TASKS.items():
+    if values["type"] == "text":
+        NO_FEATURE_EXTRACTOR_TASKS.add(task)
+    elif values["type"] in {"audio", "image"}:
+        NO_TOKENIZER_TASKS.add(task)
+    elif values["type"] != "multimodal":
+        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
+

 def get_supported_tasks() -> List[str]:
     """
@@ -528,12 +555,14 @@ def pipeline(
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None

-    if task in {"audio-classification", "image-classification"}:
+    if task in NO_TOKENIZER_TASKS:
         # These will never require a tokenizer.
         # the model on the other hand might have a tokenizer, but
         # the files could be missing from the hub, instead of failing
         # on such repos, we just force to not load it.
         load_tokenizer = False
+    if task in NO_FEATURE_EXTRACTOR_TASKS:
+        load_feature_extractor = False

     if load_tokenizer:
         # Try to infer tokenizer from model or config name (if provided as str)
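
Taken together, the two hunks above mean a task's declared "type" now decides which preprocessors pipeline() even attempts to load. A self-contained sketch of that decision, with small stub sets standing in for the module-level ones built from SUPPORTED_TASKS:

    # Stub sets for illustration; the real ones are derived from SUPPORTED_TASKS.
    NO_TOKENIZER_TASKS = {"image-classification", "audio-classification"}
    NO_FEATURE_EXTRACTOR_TASKS = {"fill-mask", "text-generation"}

    def decide_loading(task: str, load_tokenizer: bool, load_feature_extractor: bool):
        if task in NO_TOKENIZER_TASKS:
            # Even if the model repo ships tokenizer files, the task never uses them.
            load_tokenizer = False
        if task in NO_FEATURE_EXTRACTOR_TASKS:
            load_feature_extractor = False
        return load_tokenizer, load_feature_extractor

    assert decide_loading("image-classification", True, True) == (False, True)
    assert decide_loading("fill-mask", True, True) == (True, False)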
src/transformers/pipelines/fill_mask.py

@@ -44,7 +44,9 @@ class FillMaskPipeline(Pipeline):

     .. note::

-        This pipeline only works for inputs with exactly one token masked.
+        This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
+        masks. The returned values are raw model output, and correspond to disjoint probabilities where one might
+        expect joint probabilities (See `discussion <https://github.com/huggingface/transformers/pull/10222>`__).
     """

     def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
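
To make the docstring's caveat concrete: with several masks, each returned score is a per-position marginal computed while the other positions are still masked, so combining them is only a heuristic. A toy numeric illustration (made-up scores):

    # Made-up numbers, purely illustrative.
    score_a_at_mask1 = 0.60  # P(token_a at mask 1 | mask 2 still masked)
    score_b_at_mask2 = 0.50  # P(token_b at mask 2 | mask 1 still masked)

    # Multiplying assumes independence, which masked LMs do not guarantee;
    # the true joint P(token_a, token_b) can be higher or lower.
    naive_joint = score_a_at_mask1 * score_b_at_mask2  # 0.30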
@@ -59,13 +61,7 @@ class FillMaskPipeline(Pipeline):
     def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
         masked_index = self.get_masked_index(input_ids)
         numel = np.prod(masked_index.shape)
-        if numel > 1:
-            raise PipelineException(
-                "fill-mask",
-                self.model.base_model_prefix,
-                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
-            )
-        elif numel < 1:
+        if numel < 1:
             raise PipelineException(
                 "fill-mask",
                 self.model.base_model_prefix,
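
After this change the validator only rejects inputs containing no mask at all. A standalone sketch of the same check, using an illustrative mask token id rather than a real tokenizer:

    import numpy as np

    mask_token_id = 103  # illustrative id, e.g. a BERT-style [MASK]
    input_ids = np.array([101, 7592, 103, 103, 102])

    masked_index = np.argwhere(input_ids == mask_token_id)  # shape (2, 1)
    numel = np.prod(masked_index.shape)
    assert numel >= 1  # two masks found; previously numel > 1 would have raised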
@@ -98,46 +94,53 @@ class FillMaskPipeline(Pipeline):
         top_k = target_ids.shape[0]
         input_ids = model_outputs["input_ids"][0]
         outputs = model_outputs["logits"]
-        result = []

         if self.framework == "tf":
-            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]

-            # Fill mask pipeline supports only one ${mask_token} per sample
             outputs = outputs.numpy()

-            logits = outputs[0, masked_index.item(), :]
-            probs = tf.nn.softmax(logits)
+            logits = outputs[0, masked_index, :]
+            probs = tf.nn.softmax(logits, axis=-1)
             if target_ids is not None:
-                probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+                probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
+                probs = tf.expand_dims(probs, 0)

             topk = tf.math.top_k(probs, k=top_k)
             values, predictions = topk.values.numpy(), topk.indices.numpy()
         else:
-            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
-            # Fill mask pipeline supports only one ${mask_token} per sample

-            logits = outputs[0, masked_index.item(), :]
-            probs = logits.softmax(dim=0)
+            logits = outputs[0, masked_index, :]
+            probs = logits.softmax(dim=-1)
             if target_ids is not None:
                 probs = probs[..., target_ids]

             values, predictions = probs.topk(top_k)

-        for v, p in zip(values.tolist(), predictions.tolist()):
-            tokens = input_ids.numpy()
-            if target_ids is not None:
-                p = target_ids[p].tolist()
-            tokens[masked_index] = p
-            # Filter padding out:
-            tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
-            result.append(
-                {
-                    "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
-                    "score": v,
-                    "token": p,
-                    "token_str": self.tokenizer.decode(p),
-                }
-            )
+        result = []
+        single_mask = values.shape[0] == 1
+        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
+            row = []
+            for v, p in zip(_values, _predictions):
+                # Copy is important since we're going to modify this array in place
+                tokens = input_ids.numpy().copy()
+                if target_ids is not None:
+                    p = target_ids[p].tolist()
+
+                tokens[masked_index[i]] = p
+                # Filter padding out:
+                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
+                # Originally we skip special tokens to give readable output.
+                # For multi masks though, the other [MASK] would be removed otherwise
+                # making the output look odd, so we add them back
+                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
+                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode(p), "sequence": sequence}
+                row.append(proposition)
+            result.append(row)
+        if single_mask:
+            return result[0]
         return result

     def get_target_ids(self, targets, top_k=None):
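
The heart of the new postprocess is that indexing with all masked positions at once produces one row of logits per mask, so softmax and top-k run independently per position. A minimal PyTorch sketch with dummy tensors (hypothetical ids and shapes):

    import torch

    vocab_size = 10
    mask_token_id = 5                                  # illustrative id
    input_ids = torch.tensor([0, 5, 3, 5, 1])          # two masked positions
    logits = torch.randn(1, 5, vocab_size)             # (batch=1, seq_len, vocab)

    masked_index = torch.nonzero(input_ids == mask_token_id, as_tuple=False).squeeze(-1)
    # masked_index == tensor([1, 3]); the old code called .item(), which
    # requires exactly one element and fails for multiple masks.

    per_mask_logits = logits[0, masked_index, :]       # (num_masks, vocab_size)
    probs = per_mask_logits.softmax(dim=-1)            # one distribution per mask
    values, predictions = probs.topk(2)                # top-2 per mask, shape (2, 2)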
tests/test_pipelines_fill_mask.py

@@ -104,6 +104,32 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             ],
         )

+        outputs = unmasker("My name is <mask> <mask>", top_k=2)
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=6),
+            [
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is Maul<mask></s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name isELS<mask></s>"},
+                ],
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is<mask> Maul</s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name is<mask>ELS</s>"},
+                ],
+            ],
+        )
+
     @slow
     @require_torch
     def test_large_model_pt(self):
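
The expected values above pass through nested_simplify, which rounds floats inside arbitrarily nested structures so deep outputs compare stably; a rough sketch of the idea (simplified relative to the real test helper):

    def nested_simplify(obj, decimals=3):
        # Recursively round floats inside nested lists/dicts for stable comparisons.
        if isinstance(obj, float):
            return round(obj, decimals)
        if isinstance(obj, list):
            return [nested_simplify(x, decimals) for x in obj]
        if isinstance(obj, dict):
            return {k: nested_simplify(v, decimals) for k, v in obj.items()}
        return obj

    assert nested_simplify({"score": 2.1999e-05}, decimals=6) == {"score": 2.2e-05}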
@@ -231,9 +257,6 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):

         with self.assertRaises(ValueError):
             fill_masker([None])
-        # Multiple masks
-        with self.assertRaises(PipelineException):
-            fill_masker(f"This is {tokenizer.mask_token} {tokenizer.mask_token}")
         # No mask_token is not supported
         with self.assertRaises(PipelineException):
             fill_masker("This is")
@@ -242,6 +265,7 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         self.run_test_targets(model, tokenizer)
         self.run_test_top_k_targets(model, tokenizer)
         self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
+        self.fill_mask_with_multiple_masks(model, tokenizer)

     def run_test_targets(self, model, tokenizer):
         vocab = tokenizer.get_vocab()
@@ -340,3 +364,27 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         # The target list contains duplicates, so we can't output more
         # than them
         self.assertEqual(len(outputs), 3)
+
+    def fill_mask_with_multiple_masks(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+        outputs = fill_masker(
+            f"This is a {tokenizer.mask_token} {tokenizer.mask_token} {tokenizer.mask_token}", top_k=2
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+            ],
+        )
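
The ANY(...) placeholders in this test match by type rather than value; a minimal reimplementation sketch (the real helper lives in the transformers test utilities):

    class ANY:
        # Matches any value of the given type(s); handy inside assertEqual.
        def __init__(self, *expected_types):
            self.expected_types = expected_types

        def __eq__(self, other):
            return isinstance(other, self.expected_types)

        def __repr__(self):
            return f"ANY({', '.join(t.__name__ for t in self.expected_types)})"

    assert {"score": 0.17, "token": 42} == {"score": ANY(float), "token": ANY(int)}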