Feature/fix slow test in mluke (#14749)
* make MLukeTokenizerTest fast
* make LukeTokenizerTest fast
* add entry to _toctree.yaml
parent c94c1b8967
commit 824fd44fc3
_toctree.yaml
@@ -204,6 +204,8 @@
   title: MegatronBERT
 - local: model_doc/megatron_gpt2
   title: MegatronGPT2
+- local: model_doc/mluke
+  title: MLUKE
 - local: model_doc/mobilebert
   title: MobileBERT
tokenization_luke.py (LukeTokenizer)
@@ -198,6 +198,10 @@ class LukeTokenizer(RobertaTokenizer):
         max_mention_length=30,
         entity_token_1="<ent>",
         entity_token_2="<ent2>",
+        entity_unk_token="[UNK]",
+        entity_pad_token="[PAD]",
+        entity_mask_token="[MASK]",
+        entity_mask2_token="[MASK2]",
         **kwargs
     ):
         # we add 2 special tokens for downstream tasks
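The new keyword arguments make the entity special tokens configurable instead of hard-coded strings. As a rough sketch of how a caller could use them (the three file paths below are placeholders, not part of this commit):

# Sketch only: the paths are placeholders for real vocab/merges/entity-vocab files.
from transformers import LukeTokenizer

tokenizer = LukeTokenizer(
    vocab_file="vocab.json",
    merges_file="merges.txt",
    entity_vocab_file="entity_vocab.json",
    entity_unk_token="[UNK]",
    entity_pad_token="[PAD]",
    entity_mask_token="[MASK]",
    entity_mask2_token="[MASK2]",
)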
@@ -223,11 +227,25 @@
             max_mention_length=30,
             entity_token_1="<ent>",
             entity_token_2="<ent2>",
+            entity_unk_token=entity_unk_token,
+            entity_pad_token=entity_pad_token,
+            entity_mask_token=entity_mask_token,
+            entity_mask2_token=entity_mask2_token,
             **kwargs,
         )

         with open(entity_vocab_file, encoding="utf-8") as entity_vocab_handle:
             self.entity_vocab = json.load(entity_vocab_handle)
+        for entity_special_token in [entity_unk_token, entity_pad_token, entity_mask_token, entity_mask2_token]:
+            if entity_special_token not in self.entity_vocab:
+                raise ValueError(
+                    f"Specified entity special token ``{entity_special_token}`` is not found in entity_vocab. "
+                    f"Probably an incorrect entity vocab file is loaded: {entity_vocab_file}."
+                )
+        self.entity_unk_token_id = self.entity_vocab[entity_unk_token]
+        self.entity_pad_token_id = self.entity_vocab[entity_pad_token]
+        self.entity_mask_token_id = self.entity_vocab[entity_mask_token]
+        self.entity_mask2_token_id = self.entity_vocab[entity_mask2_token]

         self.task = task
         if task is None or task == "entity_span_classification":
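The check added above is plain dictionary bookkeeping: every entity special token must already be a key of the loaded entity vocabulary, and its id is cached on the tokenizer. A minimal standalone sketch of the same logic, using the fixture vocabulary this commit adds under tests/fixtures:

# The dict below is the content of tests/fixtures/test_entity_vocab.json (added by this commit).
entity_vocab = {"[MASK]": 0, "[UNK]": 1, "[PAD]": 2, "DUMMY": 3, "DUMMY2": 4, "[MASK2]": 5}

for entity_special_token in ["[UNK]", "[PAD]", "[MASK]", "[MASK2]"]:
    if entity_special_token not in entity_vocab:
        raise ValueError(f"Specified entity special token ``{entity_special_token}`` is not found in entity_vocab.")

entity_unk_token_id = entity_vocab["[UNK]"]      # 1
entity_pad_token_id = entity_vocab["[PAD]"]      # 2
entity_mask_token_id = entity_vocab["[MASK]"]    # 0
entity_mask2_token_id = entity_vocab["[MASK2]"]  # 5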
@@ -646,8 +664,6 @@
         first_entity_token_spans, second_entity_token_spans = None, None

         if self.task is None:
-            unk_entity_id = self.entity_vocab["[UNK]"]
-            mask_entity_id = self.entity_vocab["[MASK]"]

             if entity_spans is None:
                 first_ids = get_input_ids(text)
@@ -656,9 +672,9 @@

                 first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
                 if entities is None:
-                    first_entity_ids = [mask_entity_id] * len(entity_spans)
+                    first_entity_ids = [self.entity_mask_token_id] * len(entity_spans)
                 else:
-                    first_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities]
+                    first_entity_ids = [self.entity_vocab.get(entity, self.entity_unk_token_id) for entity in entities]

             if text_pair is not None:
                 if entity_spans_pair is None:
@@ -670,9 +686,11 @@
                         text_pair, entity_spans_pair
                     )
                     if entities_pair is None:
-                        second_entity_ids = [mask_entity_id] * len(entity_spans_pair)
+                        second_entity_ids = [self.entity_mask_token_id] * len(entity_spans_pair)
                     else:
-                        second_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities_pair]
+                        second_entity_ids = [
+                            self.entity_vocab.get(entity, self.entity_unk_token_id) for entity in entities_pair
+                        ]

         elif self.task == "entity_classification":
             if not (isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)):
@@ -680,7 +698,7 @@
                     "Entity spans should be a list containing a single tuple "
                     "containing the start and end character indices of an entity"
                 )
-            first_entity_ids = [self.entity_vocab["[MASK]"]]
+            first_entity_ids = [self.entity_mask_token_id]
             first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)

             # add special tokens to input ids
@@ -708,7 +726,7 @@
                 )

             head_span, tail_span = entity_spans
-            first_entity_ids = [self.entity_vocab["[MASK]"], self.entity_vocab["[MASK2]"]]
+            first_entity_ids = [self.entity_mask_token_id, self.entity_mask2_token_id]
             first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)

             head_token_span, tail_token_span = first_entity_token_spans
@@ -729,7 +747,6 @@
                 first_ids = first_ids[:entity_token_start] + [special_token_id] + first_ids[entity_token_start:]

         elif self.task == "entity_span_classification":
-            mask_entity_id = self.entity_vocab["[MASK]"]

             if not (isinstance(entity_spans, list) and len(entity_spans) > 0 and isinstance(entity_spans[0], tuple)):
                 raise ValueError(
@@ -738,7 +755,7 @@
                 )

             first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
-            first_entity_ids = [mask_entity_id] * len(entity_spans)
+            first_entity_ids = [self.entity_mask_token_id] * len(entity_spans)

         else:
             raise ValueError(f"Task {self.task} not supported")
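Taken together, these hunks route every entity-id lookup through the ids cached in the constructor instead of re-reading self.entity_vocab with hard-coded keys. A standalone sketch of the lookup behaviour, using the fixture vocabulary again:

entity_vocab = {"[MASK]": 0, "[UNK]": 1, "[PAD]": 2, "DUMMY": 3, "DUMMY2": 4, "[MASK2]": 5}
entity_unk_token_id = entity_vocab["[UNK]"]
entity_mask_token_id = entity_vocab["[MASK]"]

entity_spans = [(0, 9), (17, 28)]       # character spans of two mentions
entities = ["DUMMY", "unknown entity"]  # names for the mentions; the second is not in the vocab

if entities is None:
    # no names given: every mention becomes the entity mask token
    entity_ids = [entity_mask_token_id] * len(entity_spans)
else:
    # known names map to their id, unknown names fall back to the UNK id
    entity_ids = [entity_vocab.get(entity, entity_unk_token_id) for entity in entities]

print(entity_ids)  # [3, 1]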
@@ -1311,7 +1328,7 @@
                 encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
                 if entities_provided:
                     encoded_inputs["entity_ids"] = (
-                        encoded_inputs["entity_ids"] + [self.entity_vocab["[PAD]"]] * entity_difference
+                        encoded_inputs["entity_ids"] + [self.entity_pad_token_id] * entity_difference
                     )
                     encoded_inputs["entity_position_ids"] = (
                         encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference
@@ -1341,7 +1358,7 @@
                 encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
                 encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
                 if entities_provided:
-                    encoded_inputs["entity_ids"] = [self.entity_vocab["[PAD]"]] * entity_difference + encoded_inputs[
+                    encoded_inputs["entity_ids"] = [self.entity_pad_token_id] * entity_difference + encoded_inputs[
                         "entity_ids"
                     ]
                     encoded_inputs["entity_position_ids"] = [
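The padding hunks only swap the hard-coded self.entity_vocab["[PAD]"] lookup for the cached self.entity_pad_token_id; the padding arithmetic is unchanged. A standalone sketch of the right-padding case, with made-up sizes:

entity_pad_token_id = 2   # id of "[PAD]" in the fixture entity vocab
max_mention_length = 30
entity_difference = 2     # number of entity slots left to pad (made up for the example)

entity_ids = [3, 0]
entity_position_ids = [[5, 6, 7] + [-1] * (max_mention_length - 3)] * len(entity_ids)

# right-pad entity ids with the pad id, and position ids with all -1 rows
entity_ids = entity_ids + [entity_pad_token_id] * entity_difference
entity_position_ids = entity_position_ids + [[-1] * max_mention_length] * entity_difference

assert entity_ids == [3, 0, 2, 2]
assert all(len(row) == max_mention_length for row in entity_position_ids)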
tokenization_mluke.py (MLukeTokenizer)
@@ -253,6 +253,10 @@ class MLukeTokenizer(PreTrainedTokenizer):
         max_mention_length=30,
         entity_token_1="<ent>",
         entity_token_2="<ent2>",
+        entity_unk_token="[UNK]",
+        entity_pad_token="[PAD]",
+        entity_mask_token="[MASK]",
+        entity_mask2_token="[MASK2]",
         sp_model_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
@@ -290,6 +294,10 @@
             max_mention_length=max_mention_length,
             entity_token_1=entity_token_1,
             entity_token_2=entity_token_2,
+            entity_unk_token=entity_unk_token,
+            entity_pad_token=entity_pad_token,
+            entity_mask_token=entity_mask_token,
+            entity_mask2_token=entity_mask2_token,
             **kwargs,
         )

@@ -314,6 +322,16 @@

         with open(entity_vocab_file, encoding="utf-8") as entity_vocab_handle:
             self.entity_vocab = json.load(entity_vocab_handle)
+        for entity_special_token in [entity_unk_token, entity_pad_token, entity_mask_token, entity_mask2_token]:
+            if entity_special_token not in self.entity_vocab:
+                raise ValueError(
+                    f"Specified entity special token ``{entity_special_token}`` is not found in entity_vocab. "
+                    f"Probably an incorrect entity vocab file is loaded: {entity_vocab_file}."
+                )
+        self.entity_unk_token_id = self.entity_vocab[entity_unk_token]
+        self.entity_pad_token_id = self.entity_vocab[entity_pad_token]
+        self.entity_mask_token_id = self.entity_vocab[entity_mask_token]
+        self.entity_mask2_token_id = self.entity_vocab[entity_mask2_token]

         self.task = task
         if task is None or task == "entity_span_classification":
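MLukeTokenizer gets the same constructor treatment; since it is SentencePiece-based, there is no merges file. A rough sketch of building one directly, as the updated test file below does (both paths are placeholders, not part of the diff):

# Sketch only: the paths are placeholders for a SentencePiece model and an entity vocab JSON file.
from transformers import MLukeTokenizer

tokenizer = MLukeTokenizer(
    vocab_file="test_sentencepiece.model",
    entity_vocab_file="test_entity_vocab.json",
)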
@@ -753,8 +771,6 @@
         first_entity_token_spans, second_entity_token_spans = None, None

         if self.task is None:
-            unk_entity_id = self.entity_vocab["[UNK]"]
-            mask_entity_id = self.entity_vocab["[MASK]"]

             if entity_spans is None:
                 first_ids = get_input_ids(text)
@@ -763,9 +779,9 @@

                 first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
                 if entities is None:
-                    first_entity_ids = [mask_entity_id] * len(entity_spans)
+                    first_entity_ids = [self.entity_mask_token_id] * len(entity_spans)
                 else:
-                    first_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities]
+                    first_entity_ids = [self.entity_vocab.get(entity, self.entity_unk_token_id) for entity in entities]

             if text_pair is not None:
                 if entity_spans_pair is None:
@@ -777,9 +793,11 @@
                         text_pair, entity_spans_pair
                     )
                     if entities_pair is None:
-                        second_entity_ids = [mask_entity_id] * len(entity_spans_pair)
+                        second_entity_ids = [self.entity_mask_token_id] * len(entity_spans_pair)
                     else:
-                        second_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities_pair]
+                        second_entity_ids = [
+                            self.entity_vocab.get(entity, self.entity_unk_token_id) for entity in entities_pair
+                        ]

         elif self.task == "entity_classification":
             if not (isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)):
@@ -787,7 +805,7 @@
                     "Entity spans should be a list containing a single tuple "
                     "containing the start and end character indices of an entity"
                 )
-            first_entity_ids = [self.entity_vocab["[MASK]"]]
+            first_entity_ids = [self.entity_mask_token_id]
             first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)

             # add special tokens to input ids
@@ -815,7 +833,7 @@
                 )

             head_span, tail_span = entity_spans
-            first_entity_ids = [self.entity_vocab["[MASK]"], self.entity_vocab["[MASK2]"]]
+            first_entity_ids = [self.entity_mask_token_id, self.entity_mask2_token_id]
             first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)

             head_token_span, tail_token_span = first_entity_token_spans
@@ -836,7 +854,6 @@
                 first_ids = first_ids[:entity_token_start] + [special_token_id] + first_ids[entity_token_start:]

         elif self.task == "entity_span_classification":
-            mask_entity_id = self.entity_vocab["[MASK]"]

             if not (isinstance(entity_spans, list) and len(entity_spans) > 0 and isinstance(entity_spans[0], tuple)):
                 raise ValueError(
@@ -845,7 +862,7 @@
                 )

             first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
-            first_entity_ids = [mask_entity_id] * len(entity_spans)
+            first_entity_ids = [self.entity_mask_token_id] * len(entity_spans)

         else:
             raise ValueError(f"Task {self.task} not supported")
@@ -1422,7 +1439,7 @@
                 encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
                 if entities_provided:
                     encoded_inputs["entity_ids"] = (
-                        encoded_inputs["entity_ids"] + [self.entity_vocab["[PAD]"]] * entity_difference
+                        encoded_inputs["entity_ids"] + [self.entity_pad_token_id] * entity_difference
                     )
                     encoded_inputs["entity_position_ids"] = (
                         encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference
@@ -1452,7 +1469,7 @@
                 encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
                 encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
                 if entities_provided:
-                    encoded_inputs["entity_ids"] = [self.entity_vocab["[PAD]"]] * entity_difference + encoded_inputs[
+                    encoded_inputs["entity_ids"] = [self.entity_pad_token_id] * entity_difference + encoded_inputs[
                         "entity_ids"
                     ]
                     encoded_inputs["entity_position_ids"] = [
tests/fixtures/test_entity_vocab.json (new file)
@@ -0,0 +1 @@
+{"[MASK]": 0, "[UNK]": 1, "[PAD]": 2, "DUMMY": 3, "DUMMY2": 4, "[MASK2]": 5}
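The fixture is deliberately tiny: just the four entity special tokens plus two dummy entries, which is all the constructor check and the unit tests below need. Loading it from the repository root, for reference:

import json

with open("tests/fixtures/test_entity_vocab.json", encoding="utf-8") as f:
    entity_vocab = json.load(f)

assert entity_vocab["[MASK]"] == 0
assert entity_vocab["[UNK]"] == 1
assert entity_vocab["[PAD]"] == 2
assert entity_vocab["[MASK2]"] == 5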
test_tokenization_luke.py (LukeTokenizerTest)
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.


+import os
 import unittest
 from typing import Tuple

@@ -23,6 +23,11 @@ from transformers.testing_utils import require_torch, slow
 from .test_tokenization_common import TokenizerTesterMixin


+SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
+SAMPLE_MERGE_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/merges.txt")
+SAMPLE_ENTITY_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_entity_vocab.json")
+
+
 class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
     tokenizer_class = LukeTokenizer
     test_rust_tokenizer = False
@@ -35,7 +40,15 @@

     def get_tokenizer(self, task=None, **kwargs):
         kwargs.update(self.special_tokens_map)
-        return self.tokenizer_class.from_pretrained("studio-ousia/luke-base", task=task, **kwargs)
+        tokenizer = LukeTokenizer(
+            vocab_file=SAMPLE_VOCAB,
+            merges_file=SAMPLE_MERGE_FILE,
+            entity_vocab_file=SAMPLE_ENTITY_VOCAB,
+            task=task,
+            **kwargs,
+        )
+        tokenizer.sanitize_special_tokens()
+        return tokenizer

     def get_input_output_texts(self, tokenizer):
         input_text = "lower newer"
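With this change the shared test helper builds the tokenizer entirely from the small files in tests/fixtures, so the default test run needs no checkpoint download. A rough sketch of a test that could now run offline inside LukeTokenizerTest (the method below is hypothetical, not part of the diff):

    # Hypothetical test method, shown only to illustrate the offline helper.
    def test_encodes_entities_offline(self):
        tokenizer = self.get_tokenizer()
        encoding = tokenizer("lower newer", entities=["DUMMY"], entity_spans=[(0, 5)])
        self.assertIn("entity_ids", encoding)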
@@ -43,25 +56,16 @@
         return input_text, output_text

     def test_full_tokenizer(self):
-        tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-base")
+        tokenizer = self.get_tokenizer()
         text = "lower newer"
-        bpe_tokens = ["lower", "\u0120newer"]
+        bpe_tokens = ["l", "o", "w", "er", "Ġ", "n", "e", "w", "er"]
         tokens = tokenizer.tokenize(text)  # , add_prefix_space=True)
         self.assertListEqual(tokens, bpe_tokens)

         input_tokens = tokens + [tokenizer.unk_token]
-        input_bpe_tokens = [29668, 13964, 3]
+        input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

-    def luke_dict_integration_testing(self):
-        tokenizer = self.get_tokenizer()
-
-        self.assertListEqual(tokenizer.encode("Hello world!", add_special_tokens=False), [0, 31414, 232, 328, 2])
-        self.assertListEqual(
-            tokenizer.encode("Hello world! cécé herlolip 418", add_special_tokens=False),
-            [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
-        )
-
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-large")
@@ -235,6 +239,7 @@
             tokenizer(sentence, entity_spans=[0, 0, 0])


+@slow
 @require_torch
 class LukeTokenizerIntegrationTests(unittest.TestCase):
     tokenizer_class = LukeTokenizer
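Net effect on test selection: LukeTokenizerTest now runs in the default (fast) suite against the fixture files, while the checkpoint-based integration class is additionally gated behind @slow, which transformers skips unless slow tests are enabled (the RUN_SLOW environment variable). A minimal sketch of that gating pattern (the class below is hypothetical):

import unittest

from transformers.testing_utils import require_torch, slow


@slow           # skipped unless slow tests are enabled, e.g. RUN_SLOW=1
@require_torch  # skipped when PyTorch is not installed
class HypotheticalIntegrationTests(unittest.TestCase):
    def test_needs_pretrained_checkpoint(self):
        self.assertTrue(True)  # placeholder body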
test_tokenization_mluke.py (MLukeTokenizerTest)
@@ -14,6 +14,7 @@
 # limitations under the License.


+import os
 import unittest
 from typing import Tuple

@@ -23,7 +24,10 @@
 from .test_tokenization_common import TokenizerTesterMixin


-@slow
+SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
+SAMPLE_ENTITY_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_entity_vocab.json")
+
+
 class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
     tokenizer_class = MLukeTokenizer
     test_rust_tokenizer = False
@@ -37,7 +41,9 @@
     def get_tokenizer(self, task=None, **kwargs):
         kwargs.update(self.special_tokens_map)
         kwargs.update({"task": task})
-        return self.tokenizer_class.from_pretrained("studio-ousia/mluke-base", **kwargs)
+        tokenizer = MLukeTokenizer(vocab_file=SAMPLE_VOCAB, entity_vocab_file=SAMPLE_ENTITY_VOCAB, **kwargs)
+        tokenizer.sanitize_special_tokens()
+        return tokenizer

     def get_input_output_texts(self, tokenizer):
         input_text = "lower newer"
@@ -45,14 +51,14 @@
         return input_text, output_text

     def test_full_tokenizer(self):
-        tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base")
+        tokenizer = self.get_tokenizer()
         text = "lower newer"
-        spm_tokens = ["▁lower", "▁new", "er"]
+        spm_tokens = ["▁l", "ow", "er", "▁new", "er"]
         tokens = tokenizer.tokenize(text)
         self.assertListEqual(tokens, spm_tokens)

         input_tokens = tokens + [tokenizer.unk_token]
-        input_spm_tokens = [92319, 3525, 56, 3]
+        input_spm_tokens = [149, 116, 40, 410, 40] + [3]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_spm_tokens)

     def mluke_dict_integration_testing(self):
@@ -140,7 +146,7 @@
         tokenizer = self.get_tokenizer()

         sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and Afghanistan."
-        entities = ["en:ISO 639-3"]
+        entities = ["DUMMY"]
         spans = [(0, 9)]

         with self.assertRaises(ValueError):