diff --git a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py
index 220bf453467..dbbe9ac5ea2 100644
--- a/pytorch_transformers/tests/tokenization_bert_test.py
+++ b/pytorch_transformers/tests/tokenization_bert_test.py
@@ -24,7 +24,7 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                                      _is_control, _is_punctuation,
                                                      _is_whitespace, VOCAB_FILES_NAMES)
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class TokenizationTest(unittest.TestCase):
 
@@ -33,21 +33,18 @@ class TokenizationTest(unittest.TestCase):
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
             "##ing", ",", "low", "lowest",
         ]
-        vocab_directory = "/tmp/"
-        vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
-        with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-            vocab_file = vocab_writer.name
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
 
-        create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
+            create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
 
-        tokenizer = BertTokenizer(vocab_file)
+            tokenizer = BertTokenizer(vocab_file)
 
-        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
-        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
-
-        os.remove(vocab_file)
+            tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+            self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+            self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
     def test_chinese(self):
         tokenizer = BasicTokenizer()
diff --git a/pytorch_transformers/tests/tokenization_gpt2_test.py b/pytorch_transformers/tests/tokenization_gpt2_test.py
index 30959ceed1d..8ae8896187e 100644
--- a/pytorch_transformers/tests/tokenization_gpt2_test.py
+++ b/pytorch_transformers/tests/tokenization_gpt2_test.py
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
-import tempfile
 
 from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class GPT2TokenizationTest(unittest.TestCase):
 
@@ -34,7 +33,7 @@ class GPT2TokenizationTest(unittest.TestCase):
         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
         special_tokens_map = {"unk_token": "<unk>"}
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
diff --git a/pytorch_transformers/tests/tokenization_openai_test.py b/pytorch_transformers/tests/tokenization_openai_test.py
index 22f7d700176..f5c99877d72 100644
--- a/pytorch_transformers/tests/tokenization_openai_test.py
+++ b/pytorch_transformers/tests/tokenization_openai_test.py
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
-import tempfile
 
 from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
 
-from.tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class OpenAIGPTTokenizationTest(unittest.TestCase):
 
@@ -35,7 +34,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py
index 07f962bcab5..4e5fe837069 100644
--- a/pytorch_transformers/tests/tokenization_tests_commons.py
+++ b/pytorch_transformers/tests/tokenization_tests_commons.py
@@ -14,18 +14,25 @@
 # limitations under the License.
 from __future__ import absolute_import, division, print_function, unicode_literals
 
-import os
 import sys
 from io import open
 import tempfile
-
-if sys.version_info[0] == 3:
-    unicode = str
+import shutil
 
 if sys.version_info[0] == 2:
     import cPickle as pickle
+
+    class TemporaryDirectory(object):
+        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
+        def __enter__(self):
+            self.name = tempfile.mkdtemp()
+            return self.name
+        def __exit__(self, exc_type, exc_value, traceback):
+            shutil.rmtree(self.name)
 else:
     import pickle
+    TemporaryDirectory = tempfile.TemporaryDirectory
+    unicode = str
 
 
 def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
@@ -33,7 +40,7 @@ def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, *
 
     before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
 
-    with tempfile.TemporaryDirectory() as tmpdirname:
+    with TemporaryDirectory() as tmpdirname:
         tokenizer.save_pretrained(tmpdirname)
         tokenizer = tokenizer.from_pretrained(tmpdirname)
 
diff --git a/pytorch_transformers/tests/tokenization_transfo_xl_test.py b/pytorch_transformers/tests/tokenization_transfo_xl_test.py
index a4ddd357b9d..135f48b0ef4 100644
--- a/pytorch_transformers/tests/tokenization_transfo_xl_test.py
+++ b/pytorch_transformers/tests/tokenization_transfo_xl_test.py
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 from io import open
-import tempfile
 
 from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
 
-from.tokenization_tests_commons import create_and_check_tokenizer_commons
+from.tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class TransfoXLTokenizationTest(unittest.TestCase):
 
@@ -30,7 +29,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
             "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
             "running", ",", "low", "l",
         ]
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                 vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py
index b543ed23f87..827ec1606e1 100644
--- a/pytorch_transformers/tests/tokenization_xlm_test.py
+++ b/pytorch_transformers/tests/tokenization_xlm_test.py
@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 import os
 import unittest
 import json
-import tempfile
 
 from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class XLMTokenizationTest(unittest.TestCase):
 
@@ -34,7 +33,7 @@ class XLMTokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py
index 8fc98209ba4..e50fe9243d3 100644
--- a/pytorch_transformers/tests/tokenization_xlnet_test.py
+++ b/pytorch_transformers/tests/tokenization_xlnet_test.py
@@ -16,11 +16,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 
 import os
 import unittest
-import tempfile
 
-from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)
+from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
 
-from.tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fixtures/test_sentencepiece.model')
 
@@ -30,7 +29,7 @@ class XLNetTokenizationTest(unittest.TestCase):
     def test_full_tokenizer(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             tokenizer.save_pretrained(tmpdirname)
 
             create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py
index b191dd22e6e..60081893c8f 100644
--- a/pytorch_transformers/tokenization_utils.py
+++ b/pytorch_transformers/tokenization_utils.py
@@ -231,8 +231,7 @@ class PreTrainedTokenizer(object):
 
         # Add supplementary tokens.
         if added_tokens_file is not None:
-            added_tokens = json.load(open(added_tokens_file, encoding="utf-8"))
-            added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens))
+            added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
             added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
             tokenizer.added_tokens_encoder.update(added_tok_encoder)
             tokenizer.added_tokens_decoder.update(added_tok_decoder)
@@ -256,7 +255,11 @@ class PreTrainedTokenizer(object):
             f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
 
         with open(added_tokens_file, 'w', encoding='utf-8') as f:
-            f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False))
+            if self.added_tokens_encoder:
+                out_str = json.dumps(self.added_tokens_decoder, ensure_ascii=False)
+            else:
+                out_str = u"{}"
+            f.write(out_str)
 
         vocab_files = self.save_vocabulary(save_directory)
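
For reference, a minimal standalone sketch of the Python 2/3 TemporaryDirectory shim the tests above now import, mirroring the helper added in tokenization_tests_commons.py; the usage snippet at the end is illustrative only and not part of the patch (the "vocab.txt" name is a placeholder, the real tests use VOCAB_FILES_NAMES['vocab_file']).

import os
import sys
import shutil
import tempfile

if sys.version_info[0] == 2:
    # Python 2 has no tempfile.TemporaryDirectory; emulate it with
    # mkdtemp()/rmtree() so "with TemporaryDirectory() as ..." works there too.
    class TemporaryDirectory(object):
        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
        def __enter__(self):
            self.name = tempfile.mkdtemp()
            return self.name

        def __exit__(self, exc_type, exc_value, traceback):
            shutil.rmtree(self.name)
else:
    # Python 3: reuse the stdlib context manager directly.
    TemporaryDirectory = tempfile.TemporaryDirectory

# Illustrative usage following the pattern of the updated tests:
with TemporaryDirectory() as tmpdirname:
    vocab_file = os.path.join(tmpdirname, "vocab.txt")  # placeholder fixture name
    with open(vocab_file, "w") as vocab_writer:
        vocab_writer.write("".join([tok + "\n" for tok in ["low", "lowest", "er"]]))
assert not os.path.exists(tmpdirname)  # the whole directory is removed on exit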