Changed vocabulary save function. Variable name was inconsistent, causing an error to be thrown when passing a file name instead of a directory.

This commit is contained in:
dchurchwell 2020-02-06 02:02:48 -07:00 committed by Lysandre Debut
parent 6fc3d34abd
commit 2c12464a20

View File

@ -159,6 +159,8 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
"""Save the tokenizer vocabulary to a directory or file."""
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
else:
vocab_file = vocab_path
torch.save(self.__dict__, vocab_file)
return (vocab_file,)