more options on special tokens

This commit is contained in:
thomwolf 2019-02-04 17:26:25 +01:00
parent 05f961840b
commit 01a3966bc6

View File

@ -131,6 +131,10 @@ class OpenAIGPTTokenizer(object):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
@ -210,18 +214,19 @@ class OpenAIGPTTokenizer(object):
)
return ids
def convert_ids_to_tokens(self, ids):
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
if i in self.special_tokens_decoder:
tokens.append(self.special_tokens_decoder[i])
if not skip_special_tokens:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def decode(self, ids):
def decode(self, ids, skip_special_tokens=False):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(ids)
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
out_string = ''.join(tokens).replace('</w>', ' ')
return out_string