more options on special tokens
Commit 01a3966bc6 (parent 05f961840b)
@@ -131,6 +131,10 @@ class OpenAIGPTTokenizer(object):
         return len(self.encoder) + len(self.special_tokens)
 
     def set_special_tokens(self, special_tokens):
+        """ Add a list of additional tokens to the encoder.
+            The additional tokens are indexed starting from the last index of the
+            current vocabulary in the order of the `special_tokens` list.
+        """
         if not special_tokens:
             self.special_tokens = {}
             self.special_tokens_decoder = {}
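The docstring added above describes the indexing scheme: special tokens are appended after the existing vocabulary, in the order of the `special_tokens` list. A minimal standalone sketch of that scheme (the vocabulary and token names are illustrative, not taken from this commit; the real tokenizer stores these on self.encoder / self.special_tokens / self.special_tokens_decoder):

    # Illustrative base vocabulary and special tokens.
    encoder = {'hello</w>': 0, 'world</w>': 1}
    special_tokens = ['<bos>', '<eos>']

    # Each special token gets the next free index after the current vocabulary,
    # in the order of the special_tokens list.
    special = {tok: len(encoder) + i for i, tok in enumerate(special_tokens)}
    special_decoder = {i: tok for tok, i in special.items()}

    print(special)          # {'<bos>': 2, '<eos>': 3}
    print(special_decoder)  # {2: '<bos>', 3: '<eos>'}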
@@ -210,18 +214,19 @@ class OpenAIGPTTokenizer(object):
             )
         return ids
 
-    def convert_ids_to_tokens(self, ids):
+    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
         """Converts a sequence of ids in BPE tokens using the vocab."""
         tokens = []
         for i in ids:
             if i in self.special_tokens_decoder:
-                tokens.append(self.special_tokens_decoder[i])
+                if not skip_special_tokens:
+                    tokens.append(self.special_tokens_decoder[i])
             else:
                 tokens.append(self.decoder[i])
         return tokens
 
-    def decode(self, ids):
+    def decode(self, ids, skip_special_tokens=False):
         """Converts a sequence of ids in a string."""
-        tokens = self.convert_ids_to_tokens(ids)
+        tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)
         out_string = ''.join(tokens).replace('</w>', ' ')
         return out_string
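Taken together, these changes let callers drop registered special tokens at decode time. A hedged usage sketch follows; the import path, checkpoint name, and special-token strings are assumptions rather than part of this diff, and only the decode() / convert_ids_to_tokens() calls with skip_special_tokens mirror the change:

    from pytorch_pretrained_bert import OpenAIGPTTokenizer  # assumed import path for this era of the repo

    tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')  # assumed checkpoint name
    tokenizer.set_special_tokens(['<bos>', '<eos>'])              # indexed after the base vocabulary

    tokens = ['<bos>'] + tokenizer.tokenize("hello world") + ['<eos>']
    ids = tokenizer.convert_tokens_to_ids(tokens)

    print(tokenizer.decode(ids))                            # special tokens kept in the output string
    print(tokenizer.decode(ids, skip_special_tokens=True))  # special tokens dropped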