Adding support for multiple mask tokens. (#14716)

* Adding support for multiple mask tokens.

- Original implementation: https://github.com/huggingface/transformers/pull/10222

Co-authored-by: njafer <naveen.jafer@oracle.com>

* In order to accommodate optionally multimodal models like Perceiver,

we add information to the tasks to specify the tasks where we know for sure
whether the tokenizer and/or the feature_extractor is needed (a short sketch
of this gating follows below).
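
A minimal, illustrative sketch of the gating this enables. It mirrors the
SUPPORTED_TASKS / NO_TOKENIZER_TASKS / NO_FEATURE_EXTRACTOR_TASKS changes in the
diff below; the TASK_TYPES dict and loaders_for() helper here are hypothetical
stand-ins, not part of the actual library code.

    # Toy stand-in for SUPPORTED_TASKS: each task carries a "type" field.
    TASK_TYPES = {
        "fill-mask": "text",                           # text tasks never need a feature extractor
        "image-classification": "image",               # image/audio tasks never need a tokenizer
        "automatic-speech-recognition": "multimodal",  # may need both, decided per model
    }

    NO_FEATURE_EXTRACTOR_TASKS = {t for t, kind in TASK_TYPES.items() if kind == "text"}
    NO_TOKENIZER_TASKS = {t for t, kind in TASK_TYPES.items() if kind in {"audio", "image"}}

    def loaders_for(task: str, load_tokenizer: bool = True, load_feature_extractor: bool = True):
        # Force-disable whichever preprocessor the task is known not to need,
        # so missing files on the Hub don't cause spurious failures.
        if task in NO_TOKENIZER_TASKS:
            load_tokenizer = False
        if task in NO_FEATURE_EXTRACTOR_TASKS:
            load_feature_extractor = False
        return load_tokenizer, load_feature_extractor

    print(loaders_for("image-classification"))  # (False, True)
    print(loaders_for("fill-mask"))             # (True, False)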

* Adding information about multiple masks to the documentation.

+ marked as experimental (a usage sketch follows below).
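
A hedged usage sketch of the multi-mask behaviour (the model name is just the
pipeline's default; exact candidates and scores will vary). With several mask
tokens the pipeline returns one list of candidates per mask, and each mask is
scored independently of the others, i.e. disjoint rather than joint probabilities.

    from transformers import pipeline

    # Sketch only: any masked language model works here.
    unmasker = pipeline("fill-mask", model="distilroberta-base")

    # One mask: a flat list of top-k candidate dicts, as before.
    print(unmasker("Paris is the <mask> of France.", top_k=2))

    # Two masks (experimental): a list with one sub-list of candidates per <mask>.
    # Each sub-list is ranked without conditioning on the other mask's choice.
    outputs = unmasker("Paris is the <mask> of <mask>.", top_k=2)
    for mask_rank, candidates in enumerate(outputs):
        for candidate in candidates:
            print(mask_rank, candidate["token_str"], candidate["score"])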

* Add a copy() to prevent overwriting the same array over and over (illustrated below).
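
A small illustration of why the copy() matters, with toy arrays rather than the
pipeline's real tensors (the ids below are placeholders):

    import numpy as np

    input_ids = np.array([0, 50264, 50264, 2])  # toy ids; 50264 standing in for <mask>

    # Without a copy, every candidate writes into the same buffer, so a later
    # prediction clobbers the sequence built for an earlier one.
    shared = input_ids
    shared[1] = 111
    shared[1] = 222            # input_ids[1] is now 222; the 111 candidate is lost

    # With a copy per candidate, each proposed sequence stays intact.
    per_candidate = input_ids.copy()
    per_candidate[1] = 333     # input_ids itself is untouched
    print(input_ids, per_candidate)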

* Fixup.

* Adding a small test for multiple masks with real values.

Co-authored-by: njafer <naveen.jafer@oracle.com>
Nicolas Patry 2021-12-14 16:46:16 +01:00 committed by GitHub
parent 2a606f9974
commit e7ed7ffdcb
3 changed files with 116 additions and 36 deletions


@@ -125,18 +125,21 @@ SUPPORTED_TASKS = {
         "tf": (),
         "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
+        "type": "audio",
     },
     "automatic-speech-recognition": {
         "impl": AutomaticSpeechRecognitionPipeline,
         "tf": (),
         "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
+        "type": "multimodal",
     },
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
         "tf": (TFAutoModel,) if is_tf_available() else (),
         "pt": (AutoModel,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
+        "type": "multimodal",
     },
     "text-classification": {
         "impl": TextClassificationPipeline,
@@ -148,6 +151,7 @@ SUPPORTED_TASKS = {
                 "tf": "distilbert-base-uncased-finetuned-sst-2-english",
             },
         },
+        "type": "text",
     },
     "token-classification": {
         "impl": TokenClassificationPipeline,
@@ -159,6 +163,7 @@ SUPPORTED_TASKS = {
                 "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
             },
         },
+        "type": "text",
     },
     "question-answering": {
         "impl": QuestionAnsweringPipeline,
@@ -167,6 +172,7 @@ SUPPORTED_TASKS = {
         "default": {
             "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
         },
+        "type": "text",
     },
     "table-question-answering": {
         "impl": TableQuestionAnsweringPipeline,
@@ -179,18 +185,21 @@ SUPPORTED_TASKS = {
                 "tf": "google/tapas-base-finetuned-wtq",
             },
         },
+        "type": "text",
     },
     "fill-mask": {
         "impl": FillMaskPipeline,
         "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
         "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
+        "type": "text",
     },
     "summarization": {
         "impl": SummarizationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
+        "type": "text",
     },
     # This task is a special case as it's parametrized by SRC, TGT languages.
     "translation": {
@@ -202,18 +211,21 @@ SUPPORTED_TASKS = {
             ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
             ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
         },
+        "type": "text",
     },
     "text2text-generation": {
         "impl": Text2TextGenerationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
+        "type": "text",
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
         "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
         "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
+        "type": "text",
     },
     "zero-shot-classification": {
         "impl": ZeroShotClassificationPipeline,
@@ -224,33 +236,48 @@ SUPPORTED_TASKS = {
             "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
             "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
         },
+        "type": "text",
     },
     "conversational": {
         "impl": ConversationalPipeline,
         "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
         "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
+        "type": "text",
     },
     "image-classification": {
         "impl": ImageClassificationPipeline,
         "tf": (),
         "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "google/vit-base-patch16-224"}},
+        "type": "image",
     },
     "image-segmentation": {
         "impl": ImageSegmentationPipeline,
         "tf": (),
         "pt": (AutoModelForImageSegmentation,) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
+        "type": "image",
     },
     "object-detection": {
         "impl": ObjectDetectionPipeline,
         "tf": (),
         "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/detr-resnet-50"}},
+        "type": "image",
     },
 }
 
+NO_FEATURE_EXTRACTOR_TASKS = set()
+NO_TOKENIZER_TASKS = set()
+for task, values in SUPPORTED_TASKS.items():
+    if values["type"] == "text":
+        NO_FEATURE_EXTRACTOR_TASKS.add(task)
+    elif values["type"] in {"audio", "image"}:
+        NO_TOKENIZER_TASKS.add(task)
+    elif values["type"] != "multimodal":
+        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
+
 
 def get_supported_tasks() -> List[str]:
     """
@@ -528,12 +555,14 @@ def pipeline(
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
 
-    if task in {"audio-classification", "image-classification"}:
+    if task in NO_TOKENIZER_TASKS:
         # These will never require a tokenizer.
         # the model on the other hand might have a tokenizer, but
         # the files could be missing from the hub, instead of failing
         # on such repos, we just force to not load it.
         load_tokenizer = False
+    if task in NO_FEATURE_EXTRACTOR_TASKS:
+        load_feature_extractor = False
 
     if load_tokenizer:
         # Try to infer tokenizer from model or config name (if provided as str)


@@ -44,7 +44,9 @@ class FillMaskPipeline(Pipeline):
 
     .. note::
 
-        This pipeline only works for inputs with exactly one token masked.
+        This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
+        masks. The returned values are raw model output, and correspond to disjoint probabilities where one might
+        expect joint probabilities (See `discussion <https://github.com/huggingface/transformers/pull/10222>`__).
     """
 
     def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
@@ -59,13 +61,7 @@ class FillMaskPipeline(Pipeline):
     def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
         masked_index = self.get_masked_index(input_ids)
         numel = np.prod(masked_index.shape)
-        if numel > 1:
-            raise PipelineException(
-                "fill-mask",
-                self.model.base_model_prefix,
-                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
-            )
-        elif numel < 1:
+        if numel < 1:
             raise PipelineException(
                 "fill-mask",
                 self.model.base_model_prefix,
@@ -98,46 +94,53 @@ class FillMaskPipeline(Pipeline):
             top_k = target_ids.shape[0]
         input_ids = model_outputs["input_ids"][0]
         outputs = model_outputs["logits"]
-        result = []
 
         if self.framework == "tf":
-            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]
 
-            # Fill mask pipeline supports only one ${mask_token} per sample
-            logits = outputs[0, masked_index.item(), :]
-            probs = tf.nn.softmax(logits)
+            outputs = outputs.numpy()
+
+            logits = outputs[0, masked_index, :]
+            probs = tf.nn.softmax(logits, axis=-1)
             if target_ids is not None:
-                probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+                probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
+                probs = tf.expand_dims(probs, 0)
 
             topk = tf.math.top_k(probs, k=top_k)
             values, predictions = topk.values.numpy(), topk.indices.numpy()
         else:
-            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
             # Fill mask pipeline supports only one ${mask_token} per sample
-            logits = outputs[0, masked_index.item(), :]
-            probs = logits.softmax(dim=0)
+            logits = outputs[0, masked_index, :]
+            probs = logits.softmax(dim=-1)
             if target_ids is not None:
                 probs = probs[..., target_ids]
 
             values, predictions = probs.topk(top_k)
 
-        for v, p in zip(values.tolist(), predictions.tolist()):
-            tokens = input_ids.numpy()
-            if target_ids is not None:
-                p = target_ids[p].tolist()
-            tokens[masked_index] = p
-            # Filter padding out:
-            tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
-            result.append(
-                {
-                    "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
-                    "score": v,
-                    "token": p,
-                    "token_str": self.tokenizer.decode(p),
-                }
-            )
-
+        result = []
+        single_mask = values.shape[0] == 1
+        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
+            row = []
+            for v, p in zip(_values, _predictions):
+                # Copy is important since we're going to modify this array in place
+                tokens = input_ids.numpy().copy()
+                if target_ids is not None:
+                    p = target_ids[p].tolist()
+
+                tokens[masked_index[i]] = p
+                # Filter padding out:
+                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
+                # Originally we skip special tokens to give readable output.
+                # For multi masks though, the other [MASK] would be removed otherwise
+                # making the output look odd, so we add them back
+                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
+                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode(p), "sequence": sequence}
+                row.append(proposition)
+            result.append(row)
+        if single_mask:
+            return result[0]
         return result
 
     def get_target_ids(self, targets, top_k=None):


@@ -104,6 +104,32 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             ],
         )
 
+        outputs = unmasker("My name is <mask> <mask>", top_k=2)
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=6),
+            [
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is Maul<mask></s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name isELS<mask></s>"},
+                ],
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is<mask> Maul</s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name is<mask>ELS</s>"},
+                ],
+            ],
+        )
+
     @slow
     @require_torch
     def test_large_model_pt(self):
@@ -231,9 +257,6 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         with self.assertRaises(ValueError):
             fill_masker([None])
 
-        # Multiple masks
-        with self.assertRaises(PipelineException):
-            fill_masker(f"This is {tokenizer.mask_token} {tokenizer.mask_token}")
         # No mask_token is not supported
         with self.assertRaises(PipelineException):
             fill_masker("This is")
@@ -242,6 +265,7 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         self.run_test_targets(model, tokenizer)
         self.run_test_top_k_targets(model, tokenizer)
         self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
+        self.fill_mask_with_multiple_masks(model, tokenizer)
 
     def run_test_targets(self, model, tokenizer):
         vocab = tokenizer.get_vocab()
@@ -340,3 +364,27 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         # The target list contains duplicates, so we can't output more
         # than them
         self.assertEqual(len(outputs), 3)
+
+    def fill_mask_with_multiple_masks(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+        outputs = fill_masker(
+            f"This is a {tokenizer.mask_token} {tokenizer.mask_token} {tokenizer.mask_token}", top_k=2
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+            ],
+        )