Add standardized get_vocab method to tokenizers

commit c36416e53c
Joe Davison, 2020-02-22 12:09:01 -05:00, committed by GitHub
12 changed files with 62 additions and 0 deletions
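
For context, a minimal usage sketch of the new method (assuming the transformers library and the public bert-base-uncased checkpoint; every tokenizer touched below gains the same interface):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # get_vocab() returns a plain {token: index} dict over the full vocabulary,
    # so lookups agree with convert_tokens_to_ids for any in-vocab token.
    vocab = tokenizer.get_vocab()
    assert vocab["[CLS]"] == tokenizer.convert_tokens_to_ids("[CLS]")
    print(len(vocab))  # size of the loaded vocabulary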

@@ -114,6 +114,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model)
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

@@ -195,6 +195,9 @@ class BertTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.vocab)
 
+    def get_vocab(self):
+        return dict(self.vocab, **self.added_tokens_encoder)
+
     def _tokenize(self, text):
         split_tokens = []
         if self.do_basic_tokenize:
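
Two implementation patterns recur in this commit: tokenizers that already hold a token-to-id dict (self.vocab here, self.encoder in the BPE tokenizers below) simply merge it with added_tokens_encoder, while the SentencePiece-backed tokenizers (ALBERT above; T5, XLM-R and XLNet below) rebuild the mapping id by id, since the SentencePiece model exposes only piece/id lookups rather than a dict. Roughly what that rebuild amounts to, as a stand-alone illustration (not code from the commit; the model path is a placeholder, and the real implementations go through convert_ids_to_tokens so special-token handling stays in one place):

    import sentencepiece as spm

    sp_model = spm.SentencePieceProcessor()
    sp_model.Load("spiece.model")  # placeholder path to a SentencePiece model file

    # token -> id mapping reconstructed piece by piece
    vocab = {sp_model.IdToPiece(i): i for i in range(sp_model.GetPieceSize())}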

@@ -147,6 +147,9 @@ class CTRLTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)
 
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         if token in self.cache:
             return self.cache[token]

@@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)
 
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         if token in self.cache:
             return self.cache[token]

@@ -125,6 +125,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)
 
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
         if token in self.cache:

@@ -119,6 +119,11 @@ class T5Tokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return self.sp_model.get_piece_size() + self._extra_ids
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

@@ -273,6 +273,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.idx2sym)
 
+    def get_vocab(self):
+        return dict(self.sym2idx, **self.added_tokens_encoder)
+
     def _tokenize(self, line, add_eos=False, add_double_eos=False):
         line = line.strip()
         # convert to lower case

@@ -286,6 +286,10 @@ class PreTrainedTokenizer(object):
         """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
         return self.convert_tokens_to_ids(self.additional_special_tokens)
 
+    def get_vocab(self):
+        """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
+        raise NotImplementedError()
+
     def __init__(self, max_len=None, **kwargs):
         self._bos_token = None
         self._eos_token = None
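
Since the base class only states the contract and raises NotImplementedError, a custom tokenizer subclass is now expected to override get_vocab alongside vocab_size. A minimal hypothetical subclass sketch (names and structure are illustrative, not taken from the commit):

    from transformers import PreTrainedTokenizer

    class MyWordLevelTokenizer(PreTrainedTokenizer):
        def __init__(self, vocab, unk_token="<unk>", **kwargs):
            super().__init__(unk_token=unk_token, **kwargs)
            self.vocab = dict(vocab)  # {token: index}
            self.ids_to_tokens = {i: t for t, i in self.vocab.items()}

        @property
        def vocab_size(self):
            return len(self.vocab)

        def get_vocab(self):
            # same convention as the built-in tokenizers: base vocab plus added tokens
            return dict(self.vocab, **self.added_tokens_encoder)

        def _tokenize(self, text):
            return text.split()

        def _convert_token_to_id(self, token):
            return self.vocab.get(token, self.vocab.get(self.unk_token))

        def _convert_id_to_token(self, index):
            return self.ids_to_tokens.get(index, self.unk_token)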

@@ -662,6 +662,9 @@ class XLMTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)
 
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
         if token in self.cache:

@@ -190,6 +190,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model) + len(self.fairseq_tokens_to_ids)
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def _tokenize(self, text):
         return self.sp_model.EncodeAsPieces(text)

@@ -114,6 +114,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model)
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

@@ -542,3 +542,23 @@ class TokenizerTesterMixin:
         print(new_tokenizer.init_kwargs)
         assert tokenizer.init_kwargs["random_argument"] is True
         assert new_tokenizer.init_kwargs["random_argument"] is False
+
+    def test_get_vocab(self):
+        tokenizer = self.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+
+        self.assertIsInstance(vocab, dict)
+        self.assertEqual(len(vocab), len(tokenizer))
+
+        for word, ind in vocab.items():
+            self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
+            self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
+
+        tokenizer.add_tokens(["asdfasdfasdfasdf"])
+        vocab = tokenizer.get_vocab()
+        self.assertIsInstance(vocab, dict)
+        self.assertEqual(len(vocab), len(tokenizer))
+
+        for word, ind in vocab.items():
+            self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
+            self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
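
Because the new test lives on TokenizerTesterMixin, every per-model test class that mixes it in runs it automatically against its own tokenizer, so each tokenizer changed above is covered. The same round-trip contract can also be reproduced stand-alone, for example (assuming the public gpt2 checkpoint; any of the tokenizers above behaves the same way):

    from transformers import GPT2Tokenizer

    tok = GPT2Tokenizer.from_pretrained("gpt2")
    tok.add_tokens(["asdfasdfasdfasdf"])

    vocab = tok.get_vocab()
    assert len(vocab) == len(tok)  # includes the token added above
    for token, index in vocab.items():
        assert tok.convert_tokens_to_ids(token) == index
        assert tok.convert_ids_to_tokens(index) == token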