tokenization updates

This commit is contained in:
thomwolf 2019-04-15 14:24:52 +02:00
parent 20577d8a7c
commit b3c6ee0ac1
2 changed files with 6 additions and 3 deletions

View File

@ -135,9 +135,10 @@ class BertTokenizer(object):
return tokens
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a path."""
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:

View File

@ -145,8 +145,10 @@ class TransfoXLTokenizer(object):
raise ValueError('No <unkown> token in vocabulary')
def save_vocabulary(self, vocab_path):
    """Save the tokenizer vocabulary to a directory or file.

    Args:
        vocab_path: either a directory — the vocabulary is written to
            ``VOCAB_NAME`` inside it — or an explicit file path to save to.

    Returns:
        The path of the file the vocabulary was actually saved to.
    """
    if os.path.isdir(vocab_path):
        # A directory was given: store under the conventional file name.
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
    else:
        # An explicit file path was given: save there directly.
        # (The flattened original always joined with VOCAB_NAME, which
        # produces a broken path when vocab_path is a file.)
        vocab_file = vocab_path
    # The whole tokenizer state is serialized, not just the vocab dict.
    torch.save(self.__dict__, vocab_file)
    return vocab_file