[cleanup] test_tokenization_common.py (#4390)
This commit is contained in:
parent 8f1d047148
commit 07dd7c2fd8
@@ -198,11 +198,12 @@ Follow these steps to start contributing:
    are useful to avoid duplicated work, and to differentiate it from PRs ready
    to be merged;
 4. Make sure existing tests pass;
-5. Add high-coverage tests. No quality test, no merge.
+5. Add high-coverage tests. No quality testing = no merge.
    - If you are adding a new model, make sure that you use `ModelTester.all_model_classes = (MyModel, MyModelWithLMHead,...)`, which triggers the common tests.
    - If you are adding new `@slow` tests, make sure they pass using `RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
+   - If you are adding a new tokenizer, write tests, and make sure `RUN_SLOW=1 python -m pytest tests/test_tokenization_{your_model_name}.py` passes.
      CircleCI does not run them.
-6. All public methods must have informative docstrings;
+6. All public methods must have informative docstrings that work nicely with sphinx. See `modeling_ctrl.py` for an example.
 
 ### Tests
 
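For orientation, the tokenizer-test guideline above is what the rest of this commit builds on: per-model test classes inherit the shared `TokenizerTesterMixin`, declare `tokenizer_class`, and only override what actually differs. A minimal sketch of such a test file, assuming a hypothetical `MyModelTokenizer` and vocab fixture (neither is part of this commit):

# Hypothetical tests/test_tokenization_my_model.py; MyModelTokenizer and
# SAMPLE_VOCAB are placeholders used only to illustrate the pattern.
import unittest

from transformers.tokenization_my_model import MyModelTokenizer  # assumed module

from .test_tokenization_common import TokenizerTesterMixin

SAMPLE_VOCAB = "fixtures/sample_my_model.model"  # assumed fixture path


class MyModelTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = MyModelTokenizer  # the mixin's get_tokenizer() builds from this

    def setUp(self):
        super().setUp()
        # Save a small tokenizer into the mixin's temp dir so it can be reloaded.
        tokenizer = MyModelTokenizer(SAMPLE_VOCAB)
        tokenizer.save_pretrained(self.tmpdirname)

    def get_input_output_texts(self):
        # Override only if the default "This is a test" round trip does not suit the model.
        return "lower newer", "lower newer"

With `RUN_SLOW=1 python -m pytest tests/test_tokenization_my_model.py` the slow checks from the mixin run as well, matching the command in the guideline.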
@@ -199,7 +199,7 @@ class RobertaTokenizer(GPT2Tokenizer):
         if token_ids_1 is not None:
             raise ValueError(
                 "You should not supply a second sequence if the provided sequence of "
-                "ids is already formated with special tokens for the model."
+                "ids is already formatted with special tokens for the model."
             )
         return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
 
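The typo fix above sits in `get_special_tokens_mask`, in the branch used when the ids already contain special tokens. A small usage sketch, assuming the `roberta-base` checkpoint is available:

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

ids = tokenizer.encode("Hello world", add_special_tokens=True)
# 1 marks the <s>/</s> positions, 0 marks ordinary tokens.
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)

# Supplying token_ids_1 together with already_has_special_tokens=True
# raises the ValueError whose message is corrected in this hunk.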
@@ -771,26 +771,26 @@ class PreTrainedTokenizer(SpecialTokensMixin):
         raise NotImplementedError
 
     @property
-    def is_fast(self):
+    def is_fast(self) -> bool:
         return False
 
     @property
-    def max_len(self):
+    def max_len(self) -> int:
         """ Kept here for backward compatibility.
             Now renamed to `model_max_length` to avoid ambiguity.
         """
         return self.model_max_length
 
     @property
-    def max_len_single_sentence(self):
+    def max_len_single_sentence(self) -> int:
         return self.model_max_length - self.num_special_tokens_to_add(pair=False)
 
     @property
-    def max_len_sentences_pair(self):
+    def max_len_sentences_pair(self) -> int:
         return self.model_max_length - self.num_special_tokens_to_add(pair=True)
 
     @max_len_single_sentence.setter
-    def max_len_single_sentence(self, value):
+    def max_len_single_sentence(self, value) -> int:
         """ For backward compatibility, allow to try to setup 'max_len_single_sentence' """
         if value == self.model_max_length - self.num_special_tokens_to_add(pair=False):
             logger.warning(
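The annotated properties above are thin wrappers: `max_len` is kept only as an alias of `model_max_length`, and the two derived limits subtract the special tokens a single sequence or a pair needs. A quick sketch of the relationship, assuming the `bert-base-uncased` checkpoint:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

assert tokenizer.max_len == tokenizer.model_max_length  # deprecated alias
# Space left for content once [CLS]/[SEP] are accounted for:
assert tokenizer.max_len_single_sentence == (
    tokenizer.model_max_length - tokenizer.num_special_tokens_to_add(pair=False)
)
assert tokenizer.max_len_sentences_pair == (
    tokenizer.model_max_length - tokenizer.num_special_tokens_to_add(pair=True)
)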
@@ -802,7 +802,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
             )
 
     @max_len_sentences_pair.setter
-    def max_len_sentences_pair(self, value):
+    def max_len_sentences_pair(self, value) -> int:
         """ For backward compatibility, allow to try to setup 'max_len_sentences_pair' """
         if value == self.model_max_length - self.num_special_tokens_to_add(pair=True):
             logger.warning(
@@ -1118,7 +1118,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
 
         return vocab_files + (special_tokens_map_file, added_tokens_file)
 
-    def save_vocabulary(self, save_directory):
+    def save_vocabulary(self, save_directory) -> Tuple[str]:
         """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
             and special token mappings.
 
@@ -1128,7 +1128,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
         """
         raise NotImplementedError
 
-    def add_tokens(self, new_tokens):
+    def add_tokens(self, new_tokens: Union[str, List[str]]) -> int:
         """
         Add a list of new tokens to the tokenizer class. If the new tokens are not in the
         vocabulary, they are added to it with indices starting from length of the current vocabulary.
@@ -1156,7 +1156,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
         if not isinstance(new_tokens, list):
             new_tokens = [new_tokens]
 
-        to_add_tokens = []
+        tokens_to_add = []
        for token in new_tokens:
            assert isinstance(token, str)
            if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
@@ -1164,18 +1164,18 @@ class PreTrainedTokenizer(SpecialTokensMixin):
             if (
                 token != self.unk_token
                 and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
-                and token not in to_add_tokens
+                and token not in tokens_to_add
             ):
-                to_add_tokens.append(token)
+                tokens_to_add.append(token)
                 logger.info("Adding %s to the vocabulary", token)
 
-        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
+        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
         added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
         self.added_tokens_encoder.update(added_tok_encoder)
         self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
         self.added_tokens_decoder.update(added_tok_decoder)
 
-        return len(to_add_tokens)
+        return len(tokens_to_add)
 
     def num_special_tokens_to_add(self, pair=False):
         """
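Per the updated signature, `add_tokens` accepts a single string or a list of strings and returns how many tokens were actually added; new ids are appended after the existing vocabulary. A short usage sketch, again assuming `bert-base-uncased`:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

vocab_size_before = len(tokenizer)
num_added = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])

# Tokens already in the vocabulary are skipped, so num_added can be less than 2.
assert len(tokenizer) == vocab_size_before + num_added

# When pairing with a model, remember to grow its embedding matrix afterwards,
# e.g. model.resize_token_embeddings(len(tokenizer)).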
@@ -2080,10 +2080,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
     def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
         """
         Build model inputs from a sequence or a pair of sequence for sequence classification tasks
-        by concatenating and adding special tokens.
-        A RoBERTa sequence has the following format:
-            single sequence: <s> X </s>
-            pair of sequences: <s> A </s></s> B </s>
+        by concatenating and adding special tokens. This implementation does not add special tokens.
         """
         if token_ids_1 is None:
             return token_ids_0
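The trimmed docstring reflects that the base-class `build_inputs_with_special_tokens` only concatenates the id lists; the RoBERTa-specific `<s> A </s></s> B </s>` layout described in the removed lines lives in the RoBERTa tokenizer's override. A sketch of the contrast, assuming `roberta-base`:

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

ids_a = tokenizer.encode("Hello", add_special_tokens=False)
ids_b = tokenizer.encode("world", add_special_tokens=False)

# RoBERTa's override wraps the pair as <s> A </s></s> B </s>;
# the base implementation shown above would simply return ids_a + ids_b.
pair_ids = tokenizer.build_inputs_with_special_tokens(ids_a, ids_b)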
@@ -36,9 +36,6 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
         tokenizer.save_pretrained(self.tmpdirname)
 
-    def get_tokenizer(self, **kwargs):
-        return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_input_output_texts(self):
         input_text = "this is a test"
         output_text = "this is a test"
@@ -59,9 +59,6 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
 
-    def get_tokenizer(self, **kwargs):
-        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_rust_tokenizer(self, **kwargs):
         return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
 
@@ -60,9 +60,6 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
 
-    def get_tokenizer(self, **kwargs):
-        return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_input_output_texts(self):
         input_text = "こんにちは、世界。 \nこんばんは、世界。"
         output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
@@ -22,12 +22,12 @@ from collections import OrderedDict
 from typing import TYPE_CHECKING, Dict, Tuple, Union
 
 from tests.utils import require_tf, require_torch
+from transformers import PreTrainedTokenizer
 
 
 if TYPE_CHECKING:
     from transformers import (
         PretrainedConfig,
-        PreTrainedTokenizer,
         PreTrainedTokenizerFast,
         PreTrainedModel,
         TFPreTrainedModel,
@@ -67,19 +67,24 @@ class TokenizerTesterMixin:
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)
 
-    def get_tokenizer(self, **kwargs):
-        raise NotImplementedError
+    def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
+        return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
 
     def get_rust_tokenizer(self, **kwargs):
         raise NotImplementedError
 
-    def get_input_output_texts(self):
-        raise NotImplementedError
+    def get_input_output_texts(self) -> Tuple[str, str]:
+        """Feel free to overwrite"""
+        # TODO: @property
+        return (
+            "This is a test",
+            "This is a test",
+        )
 
     @staticmethod
     def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
         # Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...}
-        # to the concatenated encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
+        # to the list of examples/ encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
         return [
             {value: batch_encode_plus_sequences[value][i] for value in batch_encode_plus_sequences.keys()}
             for i in range(len(batch_encode_plus_sequences["input_ids"]))
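The static helper above reshapes the dict-of-lists returned by `batch_encode_plus` into one dict per example. A standalone illustration with made-up toy ids:

# Toy batch_encode_plus-style output: one key per field, one row per example.
batch = {
    "input_ids": [[101, 7592, 102], [101, 2088, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1]],
}

# Same reshaping as convert_batch_encode_plus_format_to_encode_plus:
examples = [
    {key: batch[key][i] for key in batch.keys()}
    for i in range(len(batch["input_ids"]))
]
assert examples[0] == {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1]}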
@@ -114,13 +119,13 @@ class TokenizerTesterMixin:
 
         # Now let's start the test
         tokenizer = self.get_tokenizer(max_len=42)
 
-        before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
+        sample_text = "He is very happy, UNwant\u00E9d,running"
+        before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
 
         tokenizer.save_pretrained(self.tmpdirname)
         tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
 
-        after_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
+        after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
         self.assertListEqual(before_tokens, after_tokens)
 
         self.assertEqual(tokenizer.max_len, 42)
@@ -128,6 +133,7 @@ class TokenizerTesterMixin:
         self.assertEqual(tokenizer.max_len, 43)
 
     def test_pickle_tokenizer(self):
+        """Google pickle __getstate__ __setstate__ if you are struggling with this."""
         tokenizer = self.get_tokenizer()
         self.assertIsNotNone(tokenizer)
 
@@ -253,7 +259,7 @@ class TokenizerTesterMixin:
         decoded = tokenizer.decode(encoded, skip_special_tokens=True)
         assert special_token not in decoded
 
-    def test_required_methods_tokenizer(self):
+    def test_internal_consistency(self):
         tokenizer = self.get_tokenizer()
         input_text, output_text = self.get_input_output_texts()
 
@@ -263,13 +269,12 @@ class TokenizerTesterMixin:
         self.assertListEqual(ids, ids_2)
 
         tokens_2 = tokenizer.convert_ids_to_tokens(ids)
-        self.assertNotEqual(len(tokens_2), 0)
         text_2 = tokenizer.decode(ids)
-        self.assertIsInstance(text_2, str)
 
         self.assertEqual(text_2, output_text)
 
+        self.assertNotEqual(len(tokens_2), 0)
+        self.assertIsInstance(text_2, str)
+
     def test_encode_decode_with_spaces(self):
         tokenizer = self.get_tokenizer()
 
@@ -429,10 +434,7 @@ class TokenizerTesterMixin:
 
     def test_special_tokens_mask(self):
         tokenizer = self.get_tokenizer()
 
         sequence_0 = "Encode this."
-        sequence_1 = "This one too please."
-
-        # Testing single inputs
         encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
         encoded_sequence_dict = tokenizer.encode_plus(
@@ -442,13 +444,13 @@ class TokenizerTesterMixin:
         special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
         self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
 
-        filtered_sequence = [
-            (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
-        ]
-        filtered_sequence = [x for x in filtered_sequence if x is not None]
+        filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]]
         self.assertEqual(encoded_sequence, filtered_sequence)
 
-        # Testing inputs pairs
+    def test_special_tokens_mask_input_pairs(self):
+        tokenizer = self.get_tokenizer()
+        sequence_0 = "Encode this."
+        sequence_1 = "This one too please."
         encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
         encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
         encoded_sequence_dict = tokenizer.encode_plus(
@@ -464,7 +466,9 @@ class TokenizerTesterMixin:
         filtered_sequence = [x for x in filtered_sequence if x is not None]
         self.assertEqual(encoded_sequence, filtered_sequence)
 
-        # Testing with already existing special tokens
+    def test_special_tokens_mask_already_has_special_tokens(self):
+        tokenizer = self.get_tokenizer()
+        sequence_0 = "Encode this."
         if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
             tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
         encoded_sequence_dict = tokenizer.encode_plus(
@@ -514,13 +518,12 @@ class TokenizerTesterMixin:
         tokenizer.padding_side = "right"
         padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
         padded_sequence_right_length = len(padded_sequence_right)
+        assert sequence_length == padded_sequence_right_length
+        assert encoded_sequence == padded_sequence_right
 
         tokenizer.padding_side = "left"
         padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
         padded_sequence_left_length = len(padded_sequence_left)
-
-        assert sequence_length == padded_sequence_right_length
-        assert encoded_sequence == padded_sequence_right
         assert sequence_length == padded_sequence_left_length
         assert encoded_sequence == padded_sequence_left
 
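The reordered asserts keep checking the same invariant for both sides: padding a sequence that is already at maximum length must change nothing. For a case where padding does kick in, `padding_side` decides where the pad ids go; a sketch assuming `bert-base-uncased` and the `max_length`/`pad_to_max_length` arguments this version of the library accepts:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

tokenizer.padding_side = "right"
right = tokenizer.encode("Hello world", max_length=8, pad_to_max_length=True)
# pad ids sit at the end of `right`

tokenizer.padding_side = "left"
left = tokenizer.encode("Hello world", max_length=8, pad_to_max_length=True)
# the same pad ids now sit at the front of `left`

assert len(right) == len(left) == 8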
@@ -617,6 +620,9 @@ class TokenizerTesterMixin:
         self.assertIsInstance(vocab, dict)
         self.assertEqual(len(vocab), len(tokenizer))
 
+    def test_conversion_reversible(self):
+        tokenizer = self.get_tokenizer()
+        vocab = tokenizer.get_vocab()
         for word, ind in vocab.items():
             self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
             self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
@@ -746,6 +752,7 @@ class TokenizerTesterMixin:
 
     @require_torch
     def test_torch_encode_plus_sent_to_model(self):
+        import torch
         from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
 
         MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
@@ -773,8 +780,10 @@ class TokenizerTesterMixin:
         encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
         batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
         # This should not fail
-        model(**encoded_sequence)
-        model(**batch_encoded_sequence)
+
+        with torch.no_grad():  # saves some time
+            model(**encoded_sequence)
+            model(**batch_encoded_sequence)
 
         if self.test_rust_tokenizer:
             fast_tokenizer = self.get_rust_tokenizer()
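The `torch.no_grad()` wrapper introduced above (the "saves some time" comment) disables autograd bookkeeping, which is all a forward-only smoke test needs. A minimal standalone sketch of the pattern, with a stand-in module instead of a transformer model:

import torch

model = torch.nn.Linear(4, 2)   # stand-in for the model under test
inputs = torch.randn(3, 4)

with torch.no_grad():           # no graph is recorded: less memory, slightly faster
    outputs = model(inputs)

assert not outputs.requires_grad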
@@ -24,9 +24,6 @@ class DistilBertTokenizationTest(BertTokenizationTest):
 
     tokenizer_class = DistilBertTokenizer
 
-    def get_tokenizer(self, **kwargs):
-        return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_rust_tokenizer(self, **kwargs):
         return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
 
@@ -64,13 +64,8 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))
 
-    def get_tokenizer(self, **kwargs):
-        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_input_output_texts(self):
-        input_text = "lower newer"
-        output_text = "lower newer"
-        return input_text, output_text
+        return "lower newer", "lower newer"
 
     def test_full_tokenizer(self):
         tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)
@@ -37,14 +37,6 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = T5Tokenizer(SAMPLE_VOCAB)
         tokenizer.save_pretrained(self.tmpdirname)
 
-    def get_tokenizer(self, **kwargs):
-        return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
-    def get_input_output_texts(self):
-        input_text = "This is a test"
-        output_text = "This is a test"
-        return input_text, output_text
-
     def test_full_tokenizer(self):
         tokenizer = T5Tokenizer(SAMPLE_VOCAB)
 
@@ -65,9 +65,6 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         with open(self.merges_file, "w") as fp:
             fp.write("\n".join(merges))
 
-    def get_tokenizer(self, **kwargs):
-        return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
     def get_input_output_texts(self):
         input_text = "lower newer"
         output_text = "lower newer"
@@ -17,6 +17,7 @@
 import os
 import unittest
 
+from transformers.file_utils import cached_property
 from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer
 
 from .test_tokenization_common import TokenizerTesterMixin
@@ -37,14 +38,6 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
         tokenizer.save_pretrained(self.tmpdirname)
 
-    def get_tokenizer(self, **kwargs):
-        return XLMRobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
-    def get_input_output_texts(self):
-        input_text = "This is a test"
-        output_text = "This is a test"
-        return input_text, output_text
-
     def test_full_tokenizer(self):
         tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
 
@@ -121,22 +114,22 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             ],
         )
 
+    @cached_property
+    def big_tokenizer(self):
+        return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
+
     @slow
     def test_tokenization_base_easy_symbols(self):
-        tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
-
         symbols = "Hello World!"
         original_tokenizer_encodings = [0, 35378, 6661, 38, 2]
         # xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base') # xlmr.large has same tokenizer
         # xlmr.eval()
         # xlmr.encode(symbols)
 
-        self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
+        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
 
     @slow
     def test_tokenization_base_hard_symbols(self):
-        tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
-
         symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
         original_tokenizer_encodings = [
             0,
@@ -209,4 +202,4 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         # xlmr.eval()
         # xlmr.encode(symbols)
 
-        self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
+        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
@@ -37,14 +37,6 @@ class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
         tokenizer.save_pretrained(self.tmpdirname)
 
-    def get_tokenizer(self, **kwargs):
-        return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
-
-    def get_input_output_texts(self):
-        input_text = "This is a test"
-        output_text = "This is a test"
-        return input_text, output_text
-
     def test_full_tokenizer(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
 