From e7ed7ffdcb66c78d3437ed4c3a63c3640f50f436 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 14 Dec 2021 16:46:16 +0100
Subject: [PATCH] Adding support for multiple mask tokens. (#14716)

* Adding support for multiple mask tokens.

- Original implementation: https://github.com/huggingface/transformers/pull/10222

Co-authored-by: njafer

* In order to accommodate optionally multimodal models like Perceiver, we add
  information to the tasks to specify the tasks where we know for sure whether
  we need the tokenizer/feature_extractor or not.

* Adding info in the documentation about multi masks. + marked as experimental.

* Add a copy() to prevent overriding the same tensor over and over.

* Fixup.

* Adding small test for multi mask with real values.

Co-authored-by: njafer
---
 src/transformers/pipelines/__init__.py  | 31 +++++++++++-
 src/transformers/pipelines/fill_mask.py | 67 +++++++++++++------------
 tests/test_pipelines_fill_mask.py       | 54 ++++++++++++++++++--
 3 files changed, 116 insertions(+), 36 deletions(-)

diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index c27bb4aef8e..4671981218a 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -125,18 +125,21 @@ SUPPORTED_TASKS = {
         "tf": (),
         "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
+        "type": "audio",
     },
     "automatic-speech-recognition": {
         "impl": AutomaticSpeechRecognitionPipeline,
         "tf": (),
         "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
+        "type": "multimodal",
     },
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
         "tf": (TFAutoModel,) if is_tf_available() else (),
         "pt": (AutoModel,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
+        "type": "multimodal",
     },
     "text-classification": {
         "impl": TextClassificationPipeline,
@@ -148,6 +151,7 @@ SUPPORTED_TASKS = {
                 "tf": "distilbert-base-uncased-finetuned-sst-2-english",
             },
         },
+        "type": "text",
     },
     "token-classification": {
         "impl": TokenClassificationPipeline,
@@ -159,6 +163,7 @@ SUPPORTED_TASKS = {
                 "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
             },
         },
+        "type": "text",
     },
     "question-answering": {
         "impl": QuestionAnsweringPipeline,
@@ -167,6 +172,7 @@ SUPPORTED_TASKS = {
         "default": {
             "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
         },
+        "type": "text",
     },
     "table-question-answering": {
         "impl": TableQuestionAnsweringPipeline,
@@ -179,18 +185,21 @@ SUPPORTED_TASKS = {
                 "tf": "google/tapas-base-finetuned-wtq",
             },
         },
+        "type": "text",
     },
     "fill-mask": {
         "impl": FillMaskPipeline,
         "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
         "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
+        "type": "text",
     },
     "summarization": {
         "impl": SummarizationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
+        "type": "text",
     },
     # This task is a special case as it's parametrized by SRC, TGT languages.
"translation": { @@ -202,18 +211,21 @@ SUPPORTED_TASKS = { ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}}, ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}}, }, + "type": "text", }, "text2text-generation": { "impl": Text2TextGenerationPipeline, "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), "default": {"model": {"pt": "t5-base", "tf": "t5-base"}}, + "type": "text", }, "text-generation": { "impl": TextGenerationPipeline, "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (), "pt": (AutoModelForCausalLM,) if is_torch_available() else (), "default": {"model": {"pt": "gpt2", "tf": "gpt2"}}, + "type": "text", }, "zero-shot-classification": { "impl": ZeroShotClassificationPipeline, @@ -224,33 +236,48 @@ SUPPORTED_TASKS = { "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"}, "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"}, }, + "type": "text", }, "conversational": { "impl": ConversationalPipeline, "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (), "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (), "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}}, + "type": "text", }, "image-classification": { "impl": ImageClassificationPipeline, "tf": (), "pt": (AutoModelForImageClassification,) if is_torch_available() else (), "default": {"model": {"pt": "google/vit-base-patch16-224"}}, + "type": "image", }, "image-segmentation": { "impl": ImageSegmentationPipeline, "tf": (), "pt": (AutoModelForImageSegmentation,) if is_torch_available() else (), "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}}, + "type": "image", }, "object-detection": { "impl": ObjectDetectionPipeline, "tf": (), "pt": (AutoModelForObjectDetection,) if is_torch_available() else (), "default": {"model": {"pt": "facebook/detr-resnet-50"}}, + "type": "image", }, } +NO_FEATURE_EXTRACTOR_TASKS = set() +NO_TOKENIZER_TASKS = set() +for task, values in SUPPORTED_TASKS.items(): + if values["type"] == "text": + NO_FEATURE_EXTRACTOR_TASKS.add(task) + elif values["type"] in {"audio", "image"}: + NO_TOKENIZER_TASKS.add(task) + elif values["type"] != "multimodal": + raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}") + def get_supported_tasks() -> List[str]: """ @@ -528,12 +555,14 @@ def pipeline( load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None - if task in {"audio-classification", "image-classification"}: + if task in NO_TOKENIZER_TASKS: # These will never require a tokenizer. # the model on the other hand might have a tokenizer, but # the files could be missing from the hub, instead of failing # on such repos, we just force to not load it. load_tokenizer = False + if task in NO_FEATURE_EXTRACTOR_TASKS: + load_feature_extractor = False if load_tokenizer: # Try to infer tokenizer from model or config name (if provided as str) diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py index 94c63d0b8b5..2e7ab0ed905 100644 --- a/src/transformers/pipelines/fill_mask.py +++ b/src/transformers/pipelines/fill_mask.py @@ -44,7 +44,9 @@ class FillMaskPipeline(Pipeline): .. note:: - This pipeline only works for inputs with exactly one token masked. 
+        This pipeline only works for inputs with at least one token masked. Experimental: we added support for multiple
+        masks. The returned values are raw model output, and correspond to disjoint probabilities where one might
+        expect joint probabilities (See `discussion <https://github.com/huggingface/transformers/pull/10222>`__).
     """

     def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
@@ -59,13 +61,7 @@ class FillMaskPipeline(Pipeline):
     def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
         masked_index = self.get_masked_index(input_ids)
         numel = np.prod(masked_index.shape)
-        if numel > 1:
-            raise PipelineException(
-                "fill-mask",
-                self.model.base_model_prefix,
-                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
-            )
-        elif numel < 1:
+        if numel < 1:
             raise PipelineException(
                 "fill-mask",
                 self.model.base_model_prefix,
@@ -98,46 +94,53 @@ class FillMaskPipeline(Pipeline):
             top_k = target_ids.shape[0]
         input_ids = model_outputs["input_ids"][0]
         outputs = model_outputs["logits"]
-        result = []
         if self.framework == "tf":
-            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]

-            # Fill mask pipeline supports only one ${mask_token} per sample
+            outputs = outputs.numpy()

-            logits = outputs[0, masked_index.item(), :]
-            probs = tf.nn.softmax(logits)
+            logits = outputs[0, masked_index, :]
+            probs = tf.nn.softmax(logits, axis=-1)
             if target_ids is not None:
-                probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+                probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
+                probs = tf.expand_dims(probs, 0)

             topk = tf.math.top_k(probs, k=top_k)
             values, predictions = topk.values.numpy(), topk.indices.numpy()
         else:
-            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
             # Fill mask pipeline supports only one ${mask_token} per sample
-            logits = outputs[0, masked_index.item(), :]
-            probs = logits.softmax(dim=0)
+            logits = outputs[0, masked_index, :]
+            probs = logits.softmax(dim=-1)
             if target_ids is not None:
                 probs = probs[..., target_ids]

             values, predictions = probs.topk(top_k)

-        for v, p in zip(values.tolist(), predictions.tolist()):
-            tokens = input_ids.numpy()
-            if target_ids is not None:
-                p = target_ids[p].tolist()
-            tokens[masked_index] = p
-            # Filter padding out:
-            tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
-            result.append(
-                {
-                    "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
-                    "score": v,
-                    "token": p,
-                    "token_str": self.tokenizer.decode(p),
-                }
-            )
+        result = []
+        single_mask = values.shape[0] == 1
+        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
+            row = []
+            for v, p in zip(_values, _predictions):
+                # Copy is important since we're going to modify this array in place
+                tokens = input_ids.numpy().copy()
+                if target_ids is not None:
+                    p = target_ids[p].tolist()
+
+                tokens[masked_index[i]] = p
+                # Filter padding out:
+                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
+                # Originally we skip special tokens to give readable output.
+                # For multi masks though, the other [MASK] would be removed otherwise,
+                # making the output look odd, so we add them back.
+                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
+                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode(p), "sequence": sequence}
+                row.append(proposition)
+            result.append(row)
+        if single_mask:
+            return result[0]
         return result

     def get_target_ids(self, targets, top_k=None):
diff --git a/tests/test_pipelines_fill_mask.py b/tests/test_pipelines_fill_mask.py
index 43801ef0c1c..ed551bf6f49 100644
--- a/tests/test_pipelines_fill_mask.py
+++ b/tests/test_pipelines_fill_mask.py
@@ -104,6 +104,32 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             ],
         )

+        outputs = unmasker("My name is <mask> <mask>", top_k=2)
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=6),
+            [
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is Maul<mask></s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name isELS<mask></s>"},
+                ],
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is<mask> Maul</s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name is<mask>ELS</s>"},
+                ],
+            ],
+        )
+
     @slow
     @require_torch
     def test_large_model_pt(self):
@@ -231,9 +257,6 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         with self.assertRaises(ValueError):
             fill_masker([None])

-        # Multiple masks
-        with self.assertRaises(PipelineException):
-            fill_masker(f"This is {tokenizer.mask_token} {tokenizer.mask_token}")
         # No mask_token is not supported
         with self.assertRaises(PipelineException):
             fill_masker("This is")
@@ -242,6 +265,7 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         self.run_test_targets(model, tokenizer)
         self.run_test_top_k_targets(model, tokenizer)
         self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
+        self.fill_mask_with_multiple_masks(model, tokenizer)

     def run_test_targets(self, model, tokenizer):
         vocab = tokenizer.get_vocab()
@@ -340,3 +364,27 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         # The target list contains duplicates, so we can't output more
         # than them
         self.assertEqual(len(outputs), 3)
+
+    def fill_mask_with_multiple_masks(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+        outputs = fill_masker(
+            f"This is a {tokenizer.mask_token} {tokenizer.mask_token} {tokenizer.mask_token}", top_k=2
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+            ],
+        )
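
Usage sketch (not part of the patch): the snippet below illustrates the behavior the diff introduces, assuming the pipeline's default fill-mask checkpoint distilroberta-base from SUPPORTED_TASKS (mask token "<mask>"). The example sentences are illustrative only. A single masked token keeps the old flat list of proposals, while several masked tokens now return one list of proposals per mask; as the docstring notes, each score is a per-mask (disjoint) probability, not a joint probability.

from transformers import pipeline

# Default fill-mask model is distilroberta-base per SUPPORTED_TASKS above.
unmasker = pipeline("fill-mask")

# One mask: unchanged behavior, a flat list of top_k proposals.
single = unmasker("Paris is the <mask> of France.", top_k=2)
for proposal in single:
    print(proposal["token_str"], proposal["score"], proposal["sequence"])

# Several masks (experimental): a list with one sub-list of proposals per mask.
# Each score is the model's probability for that mask in isolation.
multi = unmasker("Paris is the <mask> of <mask>.", top_k=2)
for mask_position, proposals in enumerate(multi):
    for proposal in proposals:
        print(mask_position, proposal["token_str"], proposal["score"], proposal["sequence"])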