mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
more options on special tokens
This commit is contained in:
parent
05f961840b
commit
01a3966bc6
@ -131,6 +131,10 @@ class OpenAIGPTTokenizer(object):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
""" Add a list of additional tokens to the encoder.
|
||||
The additional tokens are indexed starting from the last index of the
|
||||
current vocabulary in the order of the `special_tokens` list.
|
||||
"""
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
self.special_tokens_decoder = {}
|
||||
@ -210,18 +214,19 @@ class OpenAIGPTTokenizer(object):
|
||||
)
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
if i in self.special_tokens_decoder:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
if not skip_special_tokens:
|
||||
tokens.append(self.special_tokens_decoder[i])
|
||||
else:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
|
||||
def decode(self, ids):
|
||||
def decode(self, ids, skip_special_tokens=False):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(ids)
|
||||
tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
|
||||
out_string = ''.join(tokens).replace('</w>', ' ')
|
||||
return out_string
|
||||
|
Loading…
Reference in New Issue
Block a user