Fixing the pipeline optimization by reindexing targets (V2) (#12330)
* Fixing the pipeline optimization by rescaling the logits first.

* Add test for target equivalence

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent 2aa3cd935d
commit 4da568c152
@@ -98,9 +98,9 @@ class FillMaskPipeline(Pipeline):
             args (:obj:`str` or :obj:`List[str]`):
                 One or several texts (or one list of prompts) with masked tokens.
             targets (:obj:`str` or :obj:`List[str]`, `optional`):
-                When passed, the model will return the scores for the passed token or tokens rather than the top k
-                predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be
-                tokenized and the first resulting token will be used (with a warning).
+                When passed, the model will limit the scores to the passed targets instead of looking up in the whole
+                vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
+                resulting token will be used (with a warning, and that might be slower).
             top_k (:obj:`int`, `optional`):
                 When passed, overrides the number of predictions to return.

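For readers skimming the diff: a minimal usage sketch of the documented `targets`/`top_k` behavior (not part of this commit; the checkpoint name is an assumption for illustration).

```python
from transformers import pipeline

# Assumed checkpoint; any masked-LM checkpoint with a <mask> token works.
unmasker = pipeline("fill-mask", model="distilroberta-base")

# Without targets: top_k predictions over the entire vocabulary.
print(unmasker("My name is <mask>.", top_k=2))

# With targets: scores are limited to the passed targets.
print(unmasker("My name is <mask>.", targets=[" Patrick", " Clara"], top_k=2))
```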
@@ -115,25 +115,56 @@ class FillMaskPipeline(Pipeline):
         inputs = self._parse_and_tokenize(*args, **kwargs)
         outputs = self._forward(inputs, return_tensors=True)

+        # top_k must be defined
+        if top_k is None:
+            top_k = self.top_k
+
         results = []
         batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)

         if targets is not None:
-            if len(targets) == 0 or len(targets[0]) == 0:
-                raise ValueError("At least one target must be provided when passed.")
             if isinstance(targets, str):
                 targets = [targets]

-            targets_proc = []
+            try:
+                vocab = self.tokenizer.get_vocab()
+            except Exception:
+                vocab = {}
+            target_ids = []
             for target in targets:
-                target_enc = self.tokenizer.tokenize(target)
-                if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
+                id_ = vocab.get(target, None)
+                if id_ is None:
+                    input_ids = self.tokenizer(
+                        target,
+                        add_special_tokens=False,
+                        return_attention_mask=False,
+                        return_token_type_ids=False,
+                        max_length=1,
+                        truncation=True,
+                    )["input_ids"]
+                    if len(input_ids) == 0:
+                        logger.warning(
+                            f"The specified target token `{target}` does not exist in the model vocabulary. "
+                            f"We cannot replace it with anything meaningful, ignoring it"
+                        )
+                        continue
+                    id_ = input_ids[0]
+                    # XXX: If users encounter this pass
+                    # it becomes pretty slow, so let's make sure
+                    # The warning enables them to fix the input to
+                    # get faster performance.
                     logger.warning(
                         f"The specified target token `{target}` does not exist in the model vocabulary. "
-                        f"Replacing with `{target_enc[0]}`."
+                        f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
                     )
-                targets_proc.append(target_enc[0])
-            target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))
+                target_ids.append(id_)
+            target_ids = list(set(target_ids))
+            if len(target_ids) == 0:
+                raise ValueError("At least one target must be provided when passed.")
+            target_ids = np.array(target_ids)
+            # Cap top_k if there are targets
+            if top_k > target_ids.shape[0]:
+                top_k = target_ids.shape[0]

         for i in range(batch_size):
             input_ids = inputs["input_ids"][i]
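The heart of the change is the target-to-id resolution above: a direct vocab lookup (fast path) with a tokenize-and-truncate fallback (slow path, warned about). A standalone sketch of that logic, assuming a byte-level BPE checkpoint; `resolve_target_ids` is a hypothetical helper, not pipeline API.

```python
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")  # assumed checkpoint
vocab = tokenizer.get_vocab()

def resolve_target_ids(targets):
    # Hypothetical helper mirroring the diff's fast path (direct vocab hit)
    # and slow path (tokenize, keep only the first resulting token).
    target_ids = []
    for target in targets:
        id_ = vocab.get(target)
        if id_ is None:
            ids = tokenizer(target, add_special_tokens=False, max_length=1, truncation=True)["input_ids"]
            if not ids:
                continue  # nothing meaningful; the pipeline warns and skips
            id_ = ids[0]
        target_ids.append(id_)
    return np.array(list(set(target_ids)))  # dedupe, as the pipeline now does

# "ĠPatrick" hits the vocab directly; " Patrick" goes through the slow path,
# but both resolve to the same id, so only one survives deduplication.
print(resolve_target_ids(["ĠPatrick", " Patrick"]))
```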
@@ -147,14 +178,11 @@ class FillMaskPipeline(Pipeline):

                 logits = outputs[i, masked_index.item(), :]
                 probs = tf.nn.softmax(logits)
-                if targets is None:
-                    topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
-                    values, predictions = topk.values.numpy(), topk.indices.numpy()
-                else:
-                    values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
-                    sort_inds = tf.reverse(tf.argsort(values), [0])
-                    values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
-                    predictions = target_inds[sort_inds.numpy()]
+                if targets is not None:
+                    probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+
+                topk = tf.math.top_k(probs, k=top_k)
+                values, predictions = topk.values.numpy(), topk.indices.numpy()
             else:
                 masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)

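The TF branch above replaces the manual argsort/reverse bookkeeping with a single gather followed by `top_k`, which sorts for free. A toy sketch of the re-indexing (values here are made up):

```python
import numpy as np
import tensorflow as tf

probs = tf.nn.softmax(tf.constant([2.0, 0.5, 1.0, 3.0]))  # toy 4-token "vocab"
target_ids = np.array([0, 3])                             # restrict to two targets

restricted = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
topk = tf.math.top_k(restricted, k=2)
# topk.indices index into `target_ids`, not the full vocab, so they are
# mapped back before decoding (see the torch branch below).
predictions = target_ids[topk.indices.numpy()]
print(topk.values.numpy(), predictions)
```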
@@ -163,16 +191,15 @@ class FillMaskPipeline(Pipeline):

                 logits = outputs[i, masked_index.item(), :]
                 probs = logits.softmax(dim=0)
-                if targets is None:
-                    values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
-                else:
-                    values = probs[..., target_inds]
-                    sort_inds = list(reversed(values.argsort(dim=-1)))
-                    values = values[..., sort_inds]
-                    predictions = target_inds[sort_inds]
+                if targets is not None:
+                    probs = probs[..., target_ids]
+
+                values, predictions = probs.topk(top_k)

             for v, p in zip(values.tolist(), predictions.tolist()):
                 tokens = input_ids.numpy()
+                if targets is not None:
+                    p = target_ids[p].tolist()
                 tokens[masked_index] = p
                 # Filter padding out:
                 tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
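Same idea on the PyTorch side: advanced indexing restricts the softmaxed distribution, `topk` sorts it, and the indices are translated back through `target_ids` when decoding. A toy sketch with made-up values; the remaining hunks touch the fill-mask pipeline tests.

```python
import numpy as np
import torch

probs = torch.tensor([2.0, 0.5, 1.0, 3.0]).softmax(dim=0)  # toy 4-token "vocab"
target_ids = np.array([0, 3])

values, predictions = probs[..., target_ids].topk(2)
vocab_ids = target_ids[predictions.numpy()]  # back to real vocab ids
print(values.tolist(), vocab_ids.tolist())
```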
@@ -15,7 +15,7 @@
 import unittest

 from transformers import pipeline
-from transformers.testing_utils import require_tf, require_torch, slow
+from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow

 from .test_pipelines_common import MonoInputPipelineCommonMixin

@@ -78,7 +78,8 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     @require_torch
     def test_torch_fill_mask_with_targets(self):
         valid_inputs = ["My name is <mask>"]
-        valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
+        # ' Sam' will yield a warning but work
+        valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
         invalid_targets = [[], [""], ""]
         for model_name in self.small_models:
             unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
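The `Ġ`-prefixed targets in these tests are not typos: byte-level BPE vocabs (e.g. RoBERTa's) encode a leading space as `Ġ`, so `"ĠPatrick"` is the raw vocab entry behind the human-readable `" Patrick"`. A quick check, assuming a RoBERTa-style checkpoint:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilroberta-base")  # assumed checkpoint
print(tok.tokenize(" Patrick"))       # ['ĠPatrick'] -- a single token
print("ĠPatrick" in tok.get_vocab())  # True: direct vocab hit, no warning
```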
@@ -89,19 +90,77 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
             for targets in invalid_targets:
                 self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)

+    @require_torch
+    @slow
+    def test_torch_fill_mask_targets_equivalence(self):
+        model_name = self.large_models[0]
+        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
+        unmasked = unmasker(self.valid_inputs[0])
+        tokens = [top_mask["token_str"] for top_mask in unmasked]
+        scores = [top_mask["score"] for top_mask in unmasked]
+
+        unmasked_targets = unmasker(self.valid_inputs[0], targets=tokens)
+        target_scores = [top_mask["score"] for top_mask in unmasked_targets]
+
+        self.assertEqual(scores, target_scores)
+
+    @require_torch
+    def test_torch_fill_mask_with_targets_and_topk(self):
+        model_name = self.small_models[0]
+        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
+        targets = [" Teven", "ĠPatrick", "ĠClara"]
+        top_k = 2
+        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
+                {"sequence": "My name is Te", "score": 0.0, "token": 2941, "token_str": " Te"},
+            ],
+        )
+
+    @require_torch
+    def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
+        model_name = self.small_models[0]
+        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
+        # String duplicates + id duplicates
+        targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
+        top_k = 10
+        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
+
+        # The target list contains duplicates, so we can't output more
+        # than them
+        self.assertEqual(len(outputs), 3)
+
     @require_tf
     def test_tf_fill_mask_with_targets(self):
         valid_inputs = ["My name is <mask>"]
-        valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
+        # ' Teven' will yield a warning but work as " Te"
         invalid_targets = [[], [""], ""]
-        for model_name in self.small_models:
-            unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
-            for targets in valid_targets:
-                outputs = unmasker(valid_inputs, targets=targets)
-                self.assertIsInstance(outputs, list)
-                self.assertEqual(len(outputs), len(targets))
-            for targets in invalid_targets:
-                self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)
+        unmasker = pipeline(
+            task="fill-mask", model=self.small_models[0], tokenizer=self.small_models[0], framework="tf"
+        )
+        outputs = unmasker(valid_inputs, targets=[" Teven", "ĠPatrick", "ĠClara"])
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"sequence": "My name is Clara", "score": 0.0, "token": 13606, "token_str": " Clara"},
+                {"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
+                {"sequence": "My name is Te", "score": 0.0, "token": 2941, "token_str": " Te"},
+            ],
+        )
+        # topk
+        outputs = unmasker(valid_inputs, targets=[" Teven", "ĠPatrick", "ĠClara"], top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"sequence": "My name is Clara", "score": 0.0, "token": 13606, "token_str": " Clara"},
+                {"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
+            ],
+        )
+        for targets in invalid_targets:
+            with self.assertRaises(ValueError):
+                unmasker(valid_inputs, targets=targets)

     @require_torch
     @slow
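Why the duplicate-targets test expects exactly 3 outputs: distinct target strings can resolve to the same vocab id, and ids are deduplicated with `set()` before `top_k` is capped. An illustrative sketch; the resolved ids below are taken from the expected test outputs (`" Teven"` degrades to `" Te"`):

```python
targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
# Resolved ids for illustration: " Clara" and "ĠClara" share one id.
resolved = [2941, 3499, 13606, 13606, 13606]
unique_ids = set(resolved)
assert len(unique_ids) == 3  # so top_k=10 gets capped to 3 outputs
```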
@@ -111,7 +170,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
             "My name is <mask>",
             "The largest city in France is <mask>",
         ]
-        valid_targets = [" Patrick", " Clara"]
+        valid_targets = ["ĠPatrick", "ĠClara"]
         for model_name in self.large_models:
             unmasker = pipeline(
                 task="fill-mask",
@@ -184,7 +243,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
             "My name is <mask>",
             "The largest city in France is <mask>",
         ]
-        valid_targets = [" Patrick", " Clara"]
+        valid_targets = ["ĠPatrick", "ĠClara"]
         for model_name in self.large_models:
             unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)

@@ -242,3 +301,17 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
                 self.assertIn(key, result)

         self.assertRaises(Exception, unmasker, [None])
+
+    @require_tf
+    @slow
+    def test_tf_fill_mask_targets_equivalence(self):
+        model_name = self.large_models[0]
+        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
+        unmasked = unmasker(self.valid_inputs[0])
+        tokens = [top_mask["token_str"] for top_mask in unmasked]
+        scores = [top_mask["score"] for top_mask in unmasked]
+
+        unmasked_targets = unmasker(self.valid_inputs[0], targets=tokens)
+        target_scores = [top_mask["score"] for top_mask in unmasked_targets]
+
+        self.assertEqual(scores, target_scores)
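The equivalence tests pin down the actual fix: softmax is taken over the full distribution first and only then restricted, so target scores match the unrestricted run instead of being renormalized over the subset. In miniature:

```python
import torch

logits = torch.tensor([2.0, 0.5, 1.0, 3.0])
full_probs = logits.softmax(dim=0)
target_ids = torch.tensor([3, 0])  # the top-2 tokens, passed back as targets

# Gathering after the full softmax reproduces the unrestricted scores exactly.
assert torch.equal(full_probs[target_ids], full_probs.topk(2).values)
```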