From b3c6ee0ac1cd95bcd0a54a36a29daf599f389f93 Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Mon, 15 Apr 2019 14:24:52 +0200
Subject: [PATCH] tokenization updates

---
Review note (not part of the commit message): both save_vocabulary
methods now fall back to using vocab_path itself as the output file when
it is not an existing directory. Without that branch vocab_file was never
assigned for a plain file path and the call failed with
UnboundLocalError, contradicting the "directory or file" docstring.
The blob ids on the "index" lines were not regenerated for this revision.

 pytorch_pretrained_bert/tokenization.py            | 8 ++++++--
 pytorch_pretrained_bert/tokenization_transfo_xl.py | 7 ++++++-
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py
index 8fd65f55f08..3937d6e0118 100644
--- a/pytorch_pretrained_bert/tokenization.py
+++ b/pytorch_pretrained_bert/tokenization.py
@@ -135,9 +135,13 @@ class BertTokenizer(object):
         return tokens
 
     def save_vocabulary(self, vocab_path):
-        """Save the tokenizer vocabulary to a path."""
+        """Save the tokenizer vocabulary to a directory or file."""
         index = 0
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        if os.path.isdir(vocab_path):
+            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        else:
+            # Not an existing directory: treat vocab_path as the target file.
+            vocab_file = vocab_path
         with open(vocab_file, "w", encoding="utf-8") as writer:
             for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                 if index != token_index:
diff --git a/pytorch_pretrained_bert/tokenization_transfo_xl.py b/pytorch_pretrained_bert/tokenization_transfo_xl.py
index f704a035dbc..ddebc57c106 100644
--- a/pytorch_pretrained_bert/tokenization_transfo_xl.py
+++ b/pytorch_pretrained_bert/tokenization_transfo_xl.py
@@ -145,8 +145,13 @@ class TransfoXLTokenizer(object):
             raise ValueError('No token in vocabulary')
 
     def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a directory or file."""
         index = 0
-        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        if os.path.isdir(vocab_path):
+            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        else:
+            # Not an existing directory: treat vocab_path as the target file.
+            vocab_file = vocab_path
         torch.save(self.__dict__, vocab_file)
         return vocab_file
 