cleaning up tokenizer tests structure (at last) - last remaining ppb refs
This commit is contained in:
parent 00132b7a7a
commit 328afb7097
README.md
@@ -345,8 +345,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')

### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules

-The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer.
-The new optimizer `AdamW` matches the PyTorch `Adam` optimizer API.
+The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer, which has a few differences:
+
+- it only implements weight decay correction,
+- schedules are now external (see below),
+- gradient clipping is now also external (see below).
+
+The new optimizer `AdamW` matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping.

The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.

@@ -355,6 +360,7 @@ Here is a conversion example from `BertAdam` with a linear warmup and decay sch

```python
# Parameters:
lr = 1e-3
max_grad_norm = 1.0
num_total_steps = 1000
num_warmup_steps = 100
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

@@ -374,6 +380,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot

for batch in train_data:
    loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
    scheduler.step()
    optimizer.step()
```
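For orientation, here is a minimal sketch of the new-style setup that the conversion example above assumes. It is not part of the diff: `model`, `lr`, `num_warmup_steps` and `num_total_steps` are the placeholders already defined in the example, and the per-parameter weight-decay grouping done in the full README is omitted.

```python
# Minimal sketch (not part of the diff): build the decoupled optimizer and schedule.
from pytorch_transformers import AdamW, WarmupLinearSchedule

optimizer = AdamW(model.parameters(), lr=lr)  # weight decay correction only; no schedule or clipping inside
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)
```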
@@ -68,8 +68,13 @@ tokenizer = BertTokenizer.from_pretrained('./my_saved_model_directory/')

### Optimizers: BertAdam & OpenAIAdam are now AdamW, schedules are standard PyTorch schedules

-The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer.
-The new optimizer `AdamW` matches the PyTorch `Adam` optimizer API.
+The two optimizers previously included, `BertAdam` and `OpenAIAdam`, have been replaced by a single `AdamW` optimizer, which has a few differences:
+
+- it only implements weight decay correction,
+- schedules are now external (see below),
+- gradient clipping is now also external (see below).
+
+The new optimizer `AdamW` matches the PyTorch `Adam` optimizer API and lets you use standard PyTorch or apex methods for the schedule and clipping.

The schedules are now standard [PyTorch learning rate schedulers](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) and not part of the optimizer anymore.

@@ -78,6 +83,7 @@ Here is a conversion example from `BertAdam` with a linear warmup and decay sch

```python
# Parameters:
lr = 1e-3
max_grad_norm = 1.0
num_total_steps = 1000
num_warmup_steps = 100
warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1

@@ -97,6 +103,7 @@ scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_tot

for batch in train_data:
    loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
    scheduler.step()
    optimizer.step()
```
@@ -122,7 +122,7 @@ Here is the recommended way of saving the model, configuration and vocabulary to

.. code-block:: python

-    from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
+    from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME

    output_dir = "./models/"
@@ -74,7 +74,7 @@ according to a ``BertConfig`` class and then saved to disk under the filename ``

.. code-block:: python

-    from pytorch_pretrained_bert import BertModel, BertTokenizer, BertConfig
+    from pytorch_transformers import BertModel, BertTokenizer, BertConfig
    import torch

    enc = BertTokenizer.from_pretrained("bert-base-uncased")

@@ -105,6 +105,9 @@ according to a ``BertConfig`` class and then saved to disk under the filename ``

    # The model needs to be in evaluation mode
    model.eval()

+    # If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
+    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
+
    # Creating the trace
    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
    torch.jit.save(traced_model, "traced_bert.pt")
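As a follow-up to the traced example above, a short sketch (not part of the diff) of loading the saved trace back and running it; `tokens_tensor` and `segments_tensors` are the same example inputs used for tracing.

```python
import torch

# Reload the trace written by torch.jit.save above and run it on the same example inputs.
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

with torch.no_grad():
    outputs = loaded_model(tokens_tensor, segments_tensors)
```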
@@ -39,4 +39,4 @@ from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
                           WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)

-from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
+from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
@@ -20,7 +20,7 @@ import argparse
import torch
import numpy as np
import tensorflow as tf
-from pytorch_pretrained_bert.modeling import BertModel
+from pytorch_transformers.modeling import BertModel


def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
@@ -38,10 +38,13 @@ except ImportError:

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(
-        os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))
+        os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
except (AttributeError, ImportError):
-    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
-                                              default_cache_path)
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
+                                              os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                                        default_cache_path))
+
+PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE  # Kept for backward compatibility

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

@@ -70,7 +73,7 @@ def filename_to_url(filename, cache_dir=None):
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

@@ -98,7 +101,7 @@ def cached_path(url_or_filename, cache_dir=None):
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):

@@ -187,7 +190,7 @@ def get_from_cache(url, cache_dir=None):
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
-        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
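A short illustration of how the two cache environment variables now interact, assuming the fallback order shown above; the path is hypothetical, and the constants are importable from the package root (see the `__init__.py` hunk earlier).

```python
import os

# Set before the first import, since the constant is resolved at import time.
os.environ["PYTORCH_TRANSFORMERS_CACHE"] = "/tmp/my_transformers_cache"  # hypothetical path

from pytorch_transformers import PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE

assert str(PYTORCH_TRANSFORMERS_CACHE) == "/tmp/my_transformers_cache"  # the new variable wins
assert PYTORCH_TRANSFORMERS_CACHE is PYTORCH_PRETRAINED_BERT_CACHE      # legacy name kept as an alias
```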
@@ -24,30 +24,37 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                     _is_control, _is_punctuation,
                                                     _is_whitespace, VOCAB_FILES_NAMES)

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class TokenizationTest(unittest.TestCase):
+class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = BertTokenizer
+
+    def setUp(self):
+        super(BertTokenizationTest, self).setUp()

-    def test_full_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing", ",", "low", "lowest",
        ]
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-            input_text = u"UNwant\u00E9d,running"
-            output_text = u"unwanted, running"
+    def get_tokenizer(self):
+        return BertTokenizer.from_pretrained(self.tmpdirname)

-            create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
+    def get_input_output_texts(self):
+        input_text = u"UNwant\u00E9d,running"
+        output_text = u"unwanted, running"
+        return input_text, output_text

-            tokenizer = BertTokenizer(vocab_file)
+    def test_full_tokenizer(self):
+        tokenizer = BertTokenizer(self.vocab_file)

-            tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-            self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
-            self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    def test_chinese(self):
        tokenizer = BasicTokenizer()
@@ -20,42 +20,49 @@ import json

from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class GPT2TokenizationTest(unittest.TestCase):
+class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    tokenizer_class = GPT2Tokenizer
+
+    def setUp(self):
+        super(GPT2TokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "lo", "low", "er",
                 "low", "lowest", "newer", "wider", "<unk>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
-        special_tokens_map = {"unk_token": "<unk>"}
+        self.special_tokens_map = {"unk_token": "<unk>"}

-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))

-            input_text = u"lower newer"
-            output_text = u"lower<unk>newer"
+    def get_tokenizer(self):
+        return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)

-            create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
+    def get_input_output_texts(self):
+        input_text = u"lower newer"
+        output_text = u"lower<unk>newer"
+        return input_text, output_text

-            tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
-            text = "lower"
-            bpe_tokens = ["low", "er"]
-            tokens = tokenizer.tokenize(text)
-            self.assertListEqual(tokens, bpe_tokens)
+    def test_full_tokenizer(self):
+        tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
+        text = "lower"
+        bpe_tokens = ["low", "er"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)

-            input_tokens = tokens + [tokenizer.unk_token]
-            input_bpe_tokens = [13, 12, 17]
-            self.assertListEqual(
-                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+        input_tokens = tokens + [tokenizer.unk_token]
+        input_bpe_tokens = [13, 12, 17]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
@@ -20,13 +20,17 @@ import json

from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases


-class OpenAIGPTTokenizationTest(unittest.TestCase):
+class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    tokenizer_class = OpenAIGPTTokenizer
+
+    def setUp(self):
+        super(OpenAIGPTTokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",

@@ -34,30 +38,34 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]

-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))

-            input_text = u"lower newer"
-            output_text = u"lower newer"
+    def get_tokenizer(self):
+        return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)

-            create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
+    def get_input_output_texts(self):
+        input_text = u"lower newer"
+        output_text = u"lower newer"
+        return input_text, output_text

-            tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
-
-            text = "lower"
-            bpe_tokens = ["low", "er</w>"]
-            tokens = tokenizer.tokenize(text)
-            self.assertListEqual(tokens, bpe_tokens)
+    def test_full_tokenizer(self):
+        tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)

-            input_tokens = tokens + ["<unk>"]
-            input_bpe_tokens = [14, 15, 20]
-            self.assertListEqual(
-                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+        text = "lower"
+        bpe_tokens = ["low", "er</w>"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)
+
+        input_tokens = tokens + ["<unk>"]
+        input_bpe_tokens = [14, 15, 20]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
@@ -19,6 +19,7 @@ import sys
from io import open
import tempfile
import shutil
+import unittest

if sys.version_info[0] == 2:
    import cPickle as pickle

@@ -36,113 +37,124 @@ else:
    unicode = str


-def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
+class CommonTestCases:

-    before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
+    class CommonTokenizerTester(unittest.TestCase):

-    with TemporaryDirectory() as tmpdirname:
-        tokenizer.save_pretrained(tmpdirname)
-        tokenizer = tokenizer.from_pretrained(tmpdirname)
+        tokenizer_class = None

-    after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
-    tester.assertListEqual(before_tokens, after_tokens)
+        def setUp(self):
+            self.tmpdirname = tempfile.mkdtemp()

-def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
-    tester.assertIsNotNone(tokenizer)
+        def tearDown(self):
+            shutil.rmtree(self.tmpdirname)

-    text = u"Munich and Berlin are nice cities"
-    subwords = tokenizer.tokenize(text)
+        def get_tokenizer(self):
+            raise NotImplementedError

-    with TemporaryDirectory() as tmpdirname:
+        def get_input_output_texts(self):
+            raise NotImplementedError

-        filename = os.path.join(tmpdirname, u"tokenizer.bin")
-        pickle.dump(tokenizer, open(filename, "wb"))
+        def test_save_and_load_tokenizer(self):
+            tokenizer = self.get_tokenizer()

-        tokenizer_new = pickle.load(open(filename, "rb"))
+            before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

-        subwords_loaded = tokenizer_new.tokenize(text)
+            with TemporaryDirectory() as tmpdirname:
+                tokenizer.save_pretrained(tmpdirname)
+                tokenizer = tokenizer.from_pretrained(tmpdirname)

-        tester.assertListEqual(subwords, subwords_loaded)
+            after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
+            self.assertListEqual(before_tokens, after_tokens)

+        def test_pickle_tokenizer(self):
+            tokenizer = self.get_tokenizer()
+            self.assertIsNotNone(tokenizer)
+
+            text = u"Munich and Berlin are nice cities"
+            subwords = tokenizer.tokenize(text)
+
+            with TemporaryDirectory() as tmpdirname:
+
+                filename = os.path.join(tmpdirname, u"tokenizer.bin")
+                pickle.dump(tokenizer, open(filename, "wb"))
+
+                tokenizer_new = pickle.load(open(filename, "rb"))
+
+                subwords_loaded = tokenizer_new.tokenize(text)
+
+                self.assertListEqual(subwords, subwords_loaded)

-def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
+        def test_add_tokens_tokenizer(self):
+            tokenizer = self.get_tokenizer()

-    vocab_size = tokenizer.vocab_size
-    all_size = len(tokenizer)
+            vocab_size = tokenizer.vocab_size
+            all_size = len(tokenizer)

-    tester.assertNotEqual(vocab_size, 0)
-    tester.assertEqual(vocab_size, all_size)
+            self.assertNotEqual(vocab_size, 0)
+            self.assertEqual(vocab_size, all_size)

-    new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
-    added_toks = tokenizer.add_tokens(new_toks)
-    vocab_size_2 = tokenizer.vocab_size
-    all_size_2 = len(tokenizer)
+            new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
+            added_toks = tokenizer.add_tokens(new_toks)
+            vocab_size_2 = tokenizer.vocab_size
+            all_size_2 = len(tokenizer)

-    tester.assertNotEqual(vocab_size_2, 0)
-    tester.assertEqual(vocab_size, vocab_size_2)
-    tester.assertEqual(added_toks, len(new_toks))
-    tester.assertEqual(all_size_2, all_size + len(new_toks))
+            self.assertNotEqual(vocab_size_2, 0)
+            self.assertEqual(vocab_size, vocab_size_2)
+            self.assertEqual(added_toks, len(new_toks))
+            self.assertEqual(all_size_2, all_size + len(new_toks))

-    tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
-    tester.assertGreaterEqual(len(tokens), 4)
-    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+            tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
+            self.assertGreaterEqual(len(tokens), 4)
+            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

-    new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
-                  'pad_token': "<<<<<|||>|>>>>|>"}
-    added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
-    vocab_size_3 = tokenizer.vocab_size
-    all_size_3 = len(tokenizer)
+            new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
+                          'pad_token': "<<<<<|||>|>>>>|>"}
+            added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
+            vocab_size_3 = tokenizer.vocab_size
+            all_size_3 = len(tokenizer)

-    tester.assertNotEqual(vocab_size_3, 0)
-    tester.assertEqual(vocab_size, vocab_size_3)
-    tester.assertEqual(added_toks_2, len(new_toks_2))
-    tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
+            self.assertNotEqual(vocab_size_3, 0)
+            self.assertEqual(vocab_size, vocab_size_3)
+            self.assertEqual(added_toks_2, len(new_toks_2))
+            self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

-    tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
+            tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")

-    tester.assertGreaterEqual(len(tokens), 6)
-    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[0], tokens[1])
-    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
-    tester.assertGreater(tokens[-2], tokens[-3])
-    tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
-    tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
+            self.assertGreaterEqual(len(tokens), 6)
+            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[0], tokens[1])
+            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
+            self.assertGreater(tokens[-2], tokens[-3])
+            self.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
+            self.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))

-def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
+        def test_required_methods_tokenizer(self):
+            tokenizer = self.get_tokenizer()
+            input_text, output_text = self.get_input_output_texts()

-    tokens = tokenizer.tokenize(input_text)
-    ids = tokenizer.convert_tokens_to_ids(tokens)
-    ids_2 = tokenizer.encode(input_text)
-    tester.assertListEqual(ids, ids_2)
+            tokens = tokenizer.tokenize(input_text)
+            ids = tokenizer.convert_tokens_to_ids(tokens)
+            ids_2 = tokenizer.encode(input_text)
+            self.assertListEqual(ids, ids_2)

-    tokens_2 = tokenizer.convert_ids_to_tokens(ids)
-    text_2 = tokenizer.decode(ids)
+            tokens_2 = tokenizer.convert_ids_to_tokens(ids)
+            text_2 = tokenizer.decode(ids)

-    tester.assertEqual(text_2, output_text)
+            self.assertEqual(text_2, output_text)

-    tester.assertNotEqual(len(tokens_2), 0)
-    tester.assertIsInstance(text_2, (str, unicode))
+            self.assertNotEqual(len(tokens_2), 0)
+            self.assertIsInstance(text_2, (str, unicode))

-def create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    weights_list = list(tokenizer_class.max_model_input_sizes.keys())
-    weights_lists_2 = []
-    for file_id, map_list in tokenizer_class.pretrained_vocab_files_map.items():
-        weights_lists_2.append(list(map_list.keys()))
+        def test_pretrained_model_lists(self):
+            weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
+            weights_lists_2 = []
+            for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items():
+                weights_lists_2.append(list(map_list.keys()))

-    for weights_list_2 in weights_lists_2:
-        tester.assertListEqual(weights_list, weights_list_2)
-
-
-def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
-    create_and_check_pretrained_model_lists(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
-    create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
-    create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
-    create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
-    create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
+            for weights_list_2 in weights_lists_2:
+                self.assertListEqual(weights_list, weights_list_2)
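Condensed from the per-tokenizer diffs above, the shape every tokenizer test now follows is roughly this; the class below is only a sketch reusing the `BertTokenizer` case already shown, not an additional file in the commit.

```python
import os

from pytorch_transformers.tokenization_bert import BertTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases


class MinimalBertTokenizationTest(CommonTestCases.CommonTokenizerTester):

    tokenizer_class = BertTokenizer  # checked by test_pretrained_model_lists

    def setUp(self):
        super(MinimalBertTokenizationTest, self).setUp()  # creates self.tmpdirname
        vocab = ["[UNK]", "[CLS]", "[SEP]", "un", "##want", "##ed", ",", "runn", "##ing", "low"]
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(token + "\n" for token in vocab))

    def get_tokenizer(self):
        # every common test obtains its tokenizer through this factory
        return BertTokenizer.from_pretrained(self.tmpdirname)

    def get_input_output_texts(self):
        return u"UNwant\u00E9d,running", u"unwanted, running"
```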
@@ -20,32 +20,39 @@ from io import open

from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class TransfoXLTokenizationTest(unittest.TestCase):
+class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = TransfoXLTokenizer
+
+    def setUp(self):
+        super(TransfoXLTokenizationTest, self).setUp()

-    def test_full_tokenizer(self):
        vocab_tokens = [
            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
            "running", ",", "low", "l",
        ]
-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-            input_text = u"<unk> UNwanted , running"
-            output_text = u"<unk> unwanted, running"
+    def get_tokenizer(self):
+        return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)

-            create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
+    def get_input_output_texts(self):
+        input_text = u"<unk> UNwanted , running"
+        output_text = u"<unk> unwanted, running"
+        return input_text, output_text

-            tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
+    def test_full_tokenizer(self):
+        tokenizer = TransfoXLTokenizer(vocab_file=self.vocab_file, lower_case=True)

-            tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
-            self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
+        self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

-            self.assertListEqual(
-                tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)
@@ -20,12 +20,16 @@ import json

from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

-class XLMTokenizationTest(unittest.TestCase):
+class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):

-    def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+    tokenizer_class = XLMTokenizer
+
+    def setUp(self):
+        super(XLMTokenizationTest, self).setUp()
+
+        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",

@@ -33,30 +37,34 @@ class XLMTokenizationTest(unittest.TestCase):
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]

-        with TemporaryDirectory() as tmpdirname:
-            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
-            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-            with open(vocab_file, "w") as fp:
-                fp.write(json.dumps(vocab_tokens))
-            with open(merges_file, "w") as fp:
-                fp.write("\n".join(merges))
+        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+        with open(self.vocab_file, "w") as fp:
+            fp.write(json.dumps(vocab_tokens))
+        with open(self.merges_file, "w") as fp:
+            fp.write("\n".join(merges))

-            input_text = u"lower newer"
-            output_text = u"lower newer"
+    def get_tokenizer(self):
+        return XLMTokenizer.from_pretrained(self.tmpdirname)

-            create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
+    def get_input_output_texts(self):
+        input_text = u"lower newer"
+        output_text = u"lower newer"
+        return input_text, output_text

-            tokenizer = XLMTokenizer(vocab_file, merges_file)
+    def test_full_tokenizer(self):
+        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+        tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)

-            text = "lower"
-            bpe_tokens = ["low", "er</w>"]
-            tokens = tokenizer.tokenize(text)
-            self.assertListEqual(tokens, bpe_tokens)
+        text = "lower"
+        bpe_tokens = ["low", "er</w>"]
+        tokens = tokenizer.tokenize(text)
+        self.assertListEqual(tokens, bpe_tokens)

-            input_tokens = tokens + ["<unk>"]
-            input_bpe_tokens = [14, 15, 20]
-            self.assertListEqual(
-                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+        input_tokens = tokens + ["<unk>"]
+        input_bpe_tokens = [14, 15, 20]
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
@@ -19,48 +19,58 @@ import unittest

from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)

-from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
+from .tokenization_tests_commons import CommonTestCases

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'fixtures/test_sentencepiece.model')

-class XLNetTokenizationTest(unittest.TestCase):
+class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
+
+    tokenizer_class = XLNetTokenizer
+
+    def setUp(self):
+        super(XLNetTokenizationTest, self).setUp()
+
+        # We have a SentencePiece fixture for testing
+        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
+        tokenizer.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self):
+        return XLNetTokenizer.from_pretrained(self.tmpdirname)
+
+    def get_input_output_texts(self):
+        input_text = u"This is a test"
+        output_text = u"This is a test"
+        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

-        with TemporaryDirectory() as tmpdirname:
-            tokenizer.save_pretrained(tmpdirname)
-
-            input_text = u"This is a test"
-            output_text = u"This is a test"
-
-            create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
-
-            tokens = tokenizer.tokenize(u'This is a test')
-            self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
-
-            self.assertListEqual(
-                tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
-
-            tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
-            self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
-                                          u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
-                                          u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
-                                          SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
-            ids = tokenizer.convert_tokens_to_ids(tokens)
-            self.assertListEqual(
-                ids, [8, 21, 84, 55, 24, 19, 7, 0,
-                      602, 347, 347, 347, 3, 12, 66,
-                      46, 72, 80, 6, 0, 4])
-
-            back_tokens = tokenizer.convert_ids_to_tokens(ids)
-            self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
-                                               u'or', u'n', SPIECE_UNDERLINE + u'in',
-                                               SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
-                                               SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
-                                               SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
-                                               u'<unk>', u'.'])
+        tokens = tokenizer.tokenize(u'This is a test')
+        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
+
+        self.assertListEqual(
+            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
+
+        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
+        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
+                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
+                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
+        ids = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(
+            ids, [8, 21, 84, 55, 24, 19, 7, 0,
+                  602, 347, 347, 347, 3, 12, 66,
+                  46, 72, 80, 6, 0, 4])
+
+        back_tokens = tokenizer.convert_ids_to_tokens(ids)
+        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
+                                           u'or', u'n', SPIECE_UNDERLINE + u'in',
+                                           SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
+                                           SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                           SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
+                                           u'<unk>', u'.'])

    def test_tokenizer_lower(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
@@ -86,7 +86,7 @@ def whitespace_tokenize(text):
class BertTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a BertTokenizer.
-    :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
+    :class:`~pytorch_transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
@@ -125,42 +125,34 @@ class PreTrainedTokenizer(object):

    @bos_token.setter
    def bos_token(self, value):
-        self.add_tokens([value])
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
-        self.add_tokens([value])
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
-        self.add_tokens([value])
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
-        self.add_tokens([value])
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
-        self.add_tokens([value])
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
-        self.add_tokens([value])
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
-        self.add_tokens([value])
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
-        self.add_tokens(value)
        self._additional_special_tokens = value

    def __init__(self, max_len=None, **kwargs):
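The practical effect of the setter change above is that assigning a special-token attribute now only records the string; it no longer adds the token to the vocabulary. A short sketch (not part of the diff, reusing the names from the `add_special_tokens` docstring example in the next hunk):

```python
# Sketch only: the attribute setter no longer calls add_tokens(), so the token has
# no id until it is added explicitly.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

tokenizer.cls_token = '<CLS>'  # records the attribute only; the vocabulary is unchanged

num_added_toks = tokenizer.add_special_tokens({'cls_token': '<CLS>'})  # adds the token and sets the attribute
model.resize_token_embeddings(len(tokenizer))  # grow the embeddings to the new vocabulary size
```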
@@ -179,6 +171,10 @@ class PreTrainedTokenizer(object):

        for key, value in kwargs.items():
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
+                if key == 'additional_special_tokens':
+                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
+                else:
+                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                setattr(self, key, value)


@@ -415,15 +411,39 @@

        Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).

        Returns:
            Number of tokens added to the vocabulary.

+        Examples::
+
+            # Let's see how to add a new classification token to GPT-2
+            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+            model = GPT2Model.from_pretrained('gpt2')
+
+            special_tokens_dict = {'cls_token': '<CLS>'}
+
+            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+            print('We have added', num_added_toks, 'tokens')
+            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+
+            assert tokenizer.cls_token == '<CLS>'
        """
        if not special_tokens_dict:
            return 0

+        added_tokens = 0
        for key, value in special_tokens_dict.items():
            assert key in self.SPECIAL_TOKENS_ATTRIBUTES
+            if key == 'additional_special_tokens':
+                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
+                added_tokens += self.add_tokens(value)
+            else:
+                assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
+                added_tokens += self.add_tokens([value])
            logger.info("Assigning %s to the %s key of the tokenizer", value, key)
            setattr(self, key, value)

+        return added_tokens
+
    def tokenize(self, text, **kwargs):
        """ Converts a string into a sequence of tokens (string), using the tokenizer.