Fixing the pipeline optimization by reindexing targets (V2) (#12330)

* Fixing the pipeline optimization by rescaling the logits first.

* Add test for target equivalence

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Nicolas Patry 2021-07-08 16:58:15 +02:00 committed by GitHub
parent 2aa3cd935d
commit 4da568c152
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 139 additions and 39 deletions

View File

@ -98,9 +98,9 @@ class FillMaskPipeline(Pipeline):
args (:obj:`str` or :obj:`List[str]`):
One or several texts (or one list of prompts) with masked tokens.
targets (:obj:`str` or :obj:`List[str]`, `optional`):
When passed, the model will return the scores for the passed token or tokens rather than the top k
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be
tokenized and the first resulting token will be used (with a warning).
When passed, the model will restrict the scores to the passed targets instead of searching the whole
vocab. If a provided target is not in the model vocab, it will be tokenized and the first resulting
token will be used instead (a warning is emitted, and this fallback path may be slower).
top_k (:obj:`int`, `optional`):
When passed, overrides the number of predictions to return.
@ -115,25 +115,56 @@ class FillMaskPipeline(Pipeline):
inputs = self._parse_and_tokenize(*args, **kwargs)
outputs = self._forward(inputs, return_tensors=True)
# top_k must be defined
if top_k is None:
top_k = self.top_k
results = []
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)
if targets is not None:
if len(targets) == 0 or len(targets[0]) == 0:
raise ValueError("At least one target must be provided when passed.")
if isinstance(targets, str):
targets = [targets]
targets_proc = []
try:
vocab = self.tokenizer.get_vocab()
except Exception:
vocab = {}
target_ids = []
for target in targets:
target_enc = self.tokenizer.tokenize(target)
if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
id_ = vocab.get(target, None)
if id_ is None:
input_ids = self.tokenizer(
target,
add_special_tokens=False,
return_attention_mask=False,
return_token_type_ids=False,
max_length=1,
truncation=True,
)["input_ids"]
if len(input_ids) == 0:
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"We cannot replace it with anything meaningful, ignoring it"
)
continue
id_ = input_ids[0]
# XXX: If users hit this fallback path, the pipeline
# becomes noticeably slower, so make sure the
# warning below enables them to fix their input and
# get faster performance.
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"Replacing with `{target_enc[0]}`."
f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
)
targets_proc.append(target_enc[0])
target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))
target_ids.append(id_)
target_ids = list(set(target_ids))
if len(target_ids) == 0:
raise ValueError("At least one target must be provided when passed.")
target_ids = np.array(target_ids)
# Cap top_k if there are targets
if top_k > target_ids.shape[0]:
top_k = target_ids.shape[0]
for i in range(batch_size):
input_ids = inputs["input_ids"][i]
@ -147,14 +178,11 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :]
probs = tf.nn.softmax(logits)
if targets is None:
topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
sort_inds = tf.reverse(tf.argsort(values), [0])
values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
predictions = target_inds[sort_inds.numpy()]
if targets is not None:
probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
topk = tf.math.top_k(probs, k=top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
@ -163,16 +191,15 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :]
probs = logits.softmax(dim=0)
if targets is None:
values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
else:
values = probs[..., target_inds]
sort_inds = list(reversed(values.argsort(dim=-1)))
values = values[..., sort_inds]
predictions = target_inds[sort_inds]
if targets is not None:
probs = probs[..., target_ids]
values, predictions = probs.topk(top_k)
for v, p in zip(values.tolist(), predictions.tolist()):
tokens = input_ids.numpy()
if targets is not None:
p = target_ids[p].tolist()
tokens[masked_index] = p
# Filter padding out:
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]

View File

@ -15,7 +15,7 @@
import unittest
from transformers import pipeline
from transformers.testing_utils import require_tf, require_torch, slow
from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow
from .test_pipelines_common import MonoInputPipelineCommonMixin
@ -78,7 +78,8 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
@require_torch
def test_torch_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"]
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
# ' Sam' will yield a warning but work
valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
invalid_targets = [[], [""], ""]
for model_name in self.small_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
@ -89,19 +90,77 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
for targets in invalid_targets:
self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)
@require_torch
@slow
def test_torch_fill_mask_targets_equivalence(self):
    """Check that passing the predicted tokens back as targets yields the same scores (PyTorch)."""
    checkpoint = self.large_models[0]
    fill_mask = pipeline(task="fill-mask", model=checkpoint, tokenizer=checkpoint, framework="pt")

    # Baseline: predictions made without any target restriction.
    baseline = fill_mask(self.valid_inputs[0])
    baseline_tokens = []
    baseline_scores = []
    for prediction in baseline:
        baseline_tokens.append(prediction["token_str"])
        baseline_scores.append(prediction["score"])

    # Second run: the same prompt, restricted to the tokens predicted above.
    restricted = fill_mask(self.valid_inputs[0], targets=baseline_tokens)
    restricted_scores = [prediction["score"] for prediction in restricted]

    self.assertEqual(baseline_scores, restricted_scores)
@require_torch
def test_torch_fill_mask_with_targets_and_topk(self):
    """Check the simplified output when both `targets` and `top_k` are passed (PyTorch)."""
    checkpoint = self.small_models[0]
    fill_mask = pipeline(task="fill-mask", model=checkpoint, tokenizer=checkpoint, framework="pt")

    # Three targets are supplied but only two results are requested.
    predictions = fill_mask(
        "My name is <mask>",
        targets=[" Teven", "ĠPatrick", "ĠClara"],
        top_k=2,
    )

    expected = [
        {"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
        {"sequence": "My name is Te", "score": 0.0, "token": 2941, "token_str": " Te"},
    ]
    self.assertEqual(nested_simplify(predictions), expected)
@require_torch
def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
    """Check that duplicated targets produce only three results even when `top_k` is larger (PyTorch)."""
    checkpoint = self.small_models[0]
    fill_mask = pipeline(task="fill-mask", model=checkpoint, tokenizer=checkpoint, framework="pt")

    # Mixes exact string duplicates ("ĠClara" twice) with a different
    # spelling (" Clara") intended to map to an already-listed token id.
    duplicated_targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
    predictions = fill_mask("My name is <mask>", targets=duplicated_targets, top_k=10)

    # Because the target list contains duplicates, the pipeline cannot
    # return more results than there are distinct targets.
    self.assertEqual(len(predictions), 3)
@require_tf
def test_tf_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"]
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
# ' Teven' will yield a warning but work as " Te"
invalid_targets = [[], [""], ""]
for model_name in self.small_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
for targets in valid_targets:
outputs = unmasker(valid_inputs, targets=targets)
self.assertIsInstance(outputs, list)
self.assertEqual(len(outputs), len(targets))
for targets in invalid_targets:
self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)
unmasker = pipeline(
task="fill-mask", model=self.small_models[0], tokenizer=self.small_models[0], framework="tf"
)
outputs = unmasker(valid_inputs, targets=[" Teven", "ĠPatrick", "ĠClara"])
self.assertEqual(
nested_simplify(outputs),
[
{"sequence": "My name is Clara", "score": 0.0, "token": 13606, "token_str": " Clara"},
{"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
{"sequence": "My name is Te", "score": 0.0, "token": 2941, "token_str": " Te"},
],
)
# topk
outputs = unmasker(valid_inputs, targets=[" Teven", "ĠPatrick", "ĠClara"], top_k=2)
self.assertEqual(
nested_simplify(outputs),
[
{"sequence": "My name is Clara", "score": 0.0, "token": 13606, "token_str": " Clara"},
{"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
],
)
for targets in invalid_targets:
with self.assertRaises(ValueError):
unmasker(valid_inputs, targets=targets)
@require_torch
@slow
@ -111,7 +170,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
"My name is <mask>",
"The largest city in France is <mask>",
]
valid_targets = [" Patrick", " Clara"]
valid_targets = ["ĠPatrick", "ĠClara"]
for model_name in self.large_models:
unmasker = pipeline(
task="fill-mask",
@ -184,7 +243,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
"My name is <mask>",
"The largest city in France is <mask>",
]
valid_targets = [" Patrick", " Clara"]
valid_targets = ["ĠPatrick", "ĠClara"]
for model_name in self.large_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
@ -242,3 +301,17 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
self.assertIn(key, result)
self.assertRaises(Exception, unmasker, [None])
@require_tf
@slow
def test_tf_fill_mask_targets_equivalence(self):
    """Check that passing the predicted tokens back as targets yields the same scores (TensorFlow)."""
    checkpoint = self.large_models[0]
    fill_mask = pipeline(task="fill-mask", model=checkpoint, tokenizer=checkpoint, framework="tf")

    # Baseline: predictions made without any target restriction.
    baseline = fill_mask(self.valid_inputs[0])
    baseline_tokens = []
    baseline_scores = []
    for prediction in baseline:
        baseline_tokens.append(prediction["token_str"])
        baseline_scores.append(prediction["score"])

    # Second run: the same prompt, restricted to the tokens predicted above.
    restricted = fill_mask(self.valid_inputs[0], targets=baseline_tokens)
    restricted_scores = [prediction["score"] for prediction in restricted]

    self.assertEqual(baseline_scores, restricted_scores)