transformers/tests/test_pipelines_fill_mask.py
Nicolas Patry d4be498441
Optimizing away the fill-mask pipeline. (#12113)
* Optimizing away the `fill-mask` pipeline.

- Don't send anything to the tokenizer unless needed. Vocab check is
much faster
- Keep BC by sending data to the tokenizer when needed. Users handling warning messages will see performance benefits again
- Make `targets` and `top_k` work together better: `top_k` cannot be
higher than `len(targets)` but can be smaller still.
- Actually simplify the `target_ids` in case of duplicate (it can happen
because we're parsing raw strings)
- Removed useless code to fail on empty strings. It works only if empty
string is in first position, moved to ignoring them instead.
- Changed the related tests as only the tests would fail correctly
(having incorrect value in first position)

* Make tests compatible for 2 different vocabs... (at the price of a
warning).

Co-authored-by: @EtaoinWu

* ValueError working globally

* Update src/transformers/pipelines/fill_mask.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* `tokenizer.vocab` -> `tokenizer.get_vocab()` for more compatibility +
fallback.

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
2021-06-23 10:38:04 +02:00

270 lines
11 KiB
Python

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import pipeline
from transformers.testing_utils import require_tf, require_torch, slow
from .test_pipelines_common import MonoInputPipelineCommonMixin
# Reference predictions that the slow result tests compare pipeline output against.
# There is one inner list per input sentence, and each inner list holds the expected
# predictions for that sentence in the order the pipeline is expected to return them.
EXPECTED_FILL_MASK_RESULT = [
    # Input: "My name is <mask>"
    [
        {
            "sequence": "My name is John",
            "score": 0.00782308354973793,
            "token": 610,
            "token_str": " John",
        },
        {
            "sequence": "My name is Chris",
            "score": 0.007475061342120171,
            "token": 1573,
            "token_str": " Chris",
        },
    ],
    # Input: "The largest city in France is <mask>"
    [
        {
            "sequence": "The largest city in France is Paris",
            "score": 0.2510891854763031,
            "token": 2201,
            "token_str": " Paris",
        },
        {
            "sequence": "The largest city in France is Lyon",
            "score": 0.21418564021587372,
            "token": 12790,
            "token_str": " Lyon",
        },
    ],
]

# Reference predictions for the single-sentence pass: only the first entry above,
# wrapped in a list so it can be zipped against a one-element result list.
EXPECTED_FILL_MASK_TARGET_RESULT = EXPECTED_FILL_MASK_RESULT[:1]
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
    """Tests for the `fill-mask` pipeline.

    Fast tests run against the checkpoints listed in `small_models`; tests marked
    `@slow` run against `large_models` and compare predictions with the module-level
    `EXPECTED_FILL_MASK_RESULT` / `EXPECTED_FILL_MASK_TARGET_RESULT` references.

    The class attributes below are presumably consumed by
    `MonoInputPipelineCommonMixin` (defined in another module) -- verify there.
    """

    pipeline_task = "fill-mask"
    pipeline_loading_kwargs = {"top_k": 2}
    small_models = ["sshleifer/tiny-distilroberta-base"]  # Models tested without the @slow decorator
    large_models = ["distilroberta-base"]  # Models tested with the @slow decorator
    mandatory_keys = {"sequence", "score", "token"}
    valid_inputs = [
        "My name is <mask>",
        "The largest city in France is <mask>",
    ]
    # The comma between the two entries is required: adjacent string literals are
    # concatenated by Python, which would silently produce the single input
    # "This is <mask> <mask>This is" and leave the no-mask case untested.
    invalid_inputs = [
        "This is <mask> <mask>",  # More than 1 mask_token in the input is not supported
        "This is",  # No mask_token is not supported
    ]
    expected_check_keys = ["sequence"]

    @require_torch
    def test_torch_fill_mask(self):
        """Check the call signature: no kwargs, with `targets`, and with an unknown kwarg."""
        valid_inputs = "My name is <mask>"
        unmasker = pipeline(task="fill-mask", model=self.small_models[0])

        outputs = unmasker(valid_inputs)
        self.assertIsInstance(outputs, list)

        # This passes
        outputs = unmasker(valid_inputs, targets=[" Patrick", " Clara"])
        self.assertIsInstance(outputs, list)

        # This used to fail with `cannot mix args and kwargs`
        outputs = unmasker(valid_inputs, something=False)
        self.assertIsInstance(outputs, list)

    def _check_fill_mask_with_targets(self, framework):
        """Check valid and invalid `targets` values on every small model for `framework`.

        Args:
            framework: Framework identifier forwarded to `pipeline` ("pt" or "tf").
        """
        valid_inputs = ["My name is <mask>"]
        # ' Sam' will yield a warning but work
        valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
        invalid_targets = [[], [""], ""]
        for model_name in self.small_models:
            unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework=framework)
            for targets in valid_targets:
                outputs = unmasker(valid_inputs, targets=targets)
                self.assertIsInstance(outputs, list)
                # One prediction is expected per (distinct) requested target.
                self.assertEqual(len(outputs), len(targets))
            for targets in invalid_targets:
                self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)

    @require_torch
    def test_torch_fill_mask_with_targets(self):
        """Run the `targets` checks with the PyTorch backend."""
        self._check_fill_mask_with_targets("pt")

    @require_torch
    def test_torch_fill_mask_with_targets_and_topk(self):
        """`top_k` smaller than the number of targets limits the number of outputs."""
        model_name = self.small_models[0]
        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
        targets = [" Teven", "ĠPatrick", "ĠClara"]
        top_k = 2
        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)

        self.assertEqual(len(outputs), 2)

    @require_torch
    def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
        """Duplicate targets are collapsed, so a large `top_k` cannot yield more outputs."""
        model_name = self.small_models[0]
        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
        # String duplicates + id duplicates
        targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
        top_k = 10
        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)

        # The target list contains duplicates, so we can't output more
        # than them
        self.assertEqual(len(outputs), 3)

    @require_tf
    def test_tf_fill_mask_with_targets(self):
        """Run the `targets` checks with the TensorFlow backend."""
        self._check_fill_mask_with_targets("tf")

    def _assert_fill_mask_results(self, unmasker, inputs, targets, expected_results):
        """Run `unmasker` on `inputs` and compare its predictions with `expected_results`.

        Args:
            unmasker: A fill-mask pipeline instance.
            inputs: Non-empty list of masked sentences.
            targets: Target tokens used for the single-sentence call on `inputs[0]`.
            expected_results: Reference predictions, one list per entry of `inputs`.
        """
        # Single input restricted to `targets`: only the output structure is checked.
        mono_result = unmasker(inputs[0], targets=targets)
        self.assertIsInstance(mono_result, list)
        self.assertIsInstance(mono_result[0], dict)
        for mandatory_key in self.mandatory_keys:
            self.assertIn(mandatory_key, mono_result[0])

        # One unrestricted call per input: predictions are compared with the references.
        multi_result = [unmasker(valid_input) for valid_input in inputs]
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], (dict, list))

        for result, expected in zip(multi_result, expected_results):
            for r, e in zip(result, expected):
                self.assertEqual(r["sequence"], e["sequence"])
                self.assertEqual(r["token_str"], e["token_str"])
                self.assertEqual(r["token"], e["token"])
                self.assertAlmostEqual(r["score"], e["score"], places=3)

        if isinstance(multi_result[0], list):
            multi_result = multi_result[0]

        for result in multi_result:
            for key in self.mandatory_keys:
                self.assertIn(key, result)

        self.assertRaises(Exception, unmasker, [None])

    def _check_fill_mask_results(self, framework):
        """Compare predictions of every large model with the reference results.

        Args:
            framework: Framework identifier forwarded to `pipeline` ("pt" or "tf").
        """
        valid_inputs = [
            "My name is <mask>",
            "The largest city in France is <mask>",
        ]
        valid_targets = ["ĠPatrick", "ĠClara"]
        for model_name in self.large_models:
            unmasker = pipeline(
                task="fill-mask",
                model=model_name,
                tokenizer=model_name,
                framework=framework,
                top_k=2,
            )
            self._assert_fill_mask_results(unmasker, valid_inputs, valid_targets, EXPECTED_FILL_MASK_RESULT)
            # Second pass on the first sentence only. A slice is passed instead of
            # rebinding `valid_inputs`, so a later model still sees every input.
            self._assert_fill_mask_results(
                unmasker, valid_inputs[:1], valid_targets, EXPECTED_FILL_MASK_TARGET_RESULT
            )

    @require_torch
    @slow
    def test_torch_fill_mask_results(self):
        """Check large-model predictions with the PyTorch backend."""
        self._check_fill_mask_results("pt")

    @require_tf
    @slow
    def test_tf_fill_mask_results(self):
        """Check large-model predictions with the TensorFlow backend."""
        self._check_fill_mask_results("tf")