max_len_single_sentence & max_len_sentences_pair as attributes so they can be modified

thomwolf 2019-08-23 22:07:26 +02:00
parent ab7bd5ef98
commit 3bcbebd440
8 changed files with 26 additions and 40 deletions
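
Since the two limits are now plain instance attributes rather than read-only properties, downstream code can adjust them after loading a tokenizer (for instance after registering extra special tokens). A minimal usage sketch, assuming the pytorch-transformers package of this era; `bert-base-uncased` is just an illustrative checkpoint:

```python
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Defaults set in BertTokenizer.__init__ (see the diff below):
#   [CLS] + sentence + [SEP]        -> max_len - 2 usable tokens
#   [CLS] + A + [SEP] + B + [SEP]   -> max_len - 3 usable tokens
print(tokenizer.max_len_single_sentence)  # 510 when max_len == 512
print(tokenizer.max_len_sentences_pair)   # 509 when max_len == 512

# Now writable: tighten the budget if extra special tokens will be added
# to every encoded sequence (the value here is purely illustrative).
tokenizer.max_len_single_sentence = tokenizer.max_len - 4
```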

pytorch_transformers/tokenization_bert.py

@@ -125,6 +125,9 @@ class BertTokenizer(PreTrainedTokenizer):
         super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                             pad_token=pad_token, cls_token=cls_token,
                                             mask_token=mask_token, **kwargs)
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
@@ -139,14 +142,6 @@ class BertTokenizer(PreTrainedTokenizer):
                                               tokenize_chinese_chars=tokenize_chinese_chars)
         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

-    @property
-    def max_len_single_sentence(self):
-        return self.max_len - 2  # take into account special tokens
-
-    @property
-    def max_len_sentences_pair(self):
-        return self.max_len - 3  # take into account special tokens
-
     @property
     def vocab_size(self):
         return len(self.vocab)

pytorch_transformers/tokenization_gpt2.py

@@ -108,6 +108,8 @@ class GPT2Tokenizer(PreTrainedTokenizer):
     def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>",
                  bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
         super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
+        self.max_len_single_sentence = self.max_len  # no default special tokens - you can update this value if you add special tokens
+        self.max_len_sentences_pair = self.max_len  # no default special tokens - you can update this value if you add special tokens
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}

pytorch_transformers/tokenization_openai.py

@@ -87,6 +87,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
     def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
         super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)
+        self.max_len_single_sentence = self.max_len  # no default special tokens - you can update this value if you add special tokens
+        self.max_len_sentences_pair = self.max_len  # no default special tokens - you can update this value if you add special tokens
         try:
             import ftfy
             from spacy.lang.en import English

pytorch_transformers/tokenization_roberta.py

@@ -77,6 +77,9 @@ class RobertaTokenizer(PreTrainedTokenizer):
                                                sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
                                                mask_token=mask_token, **kwargs)
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 4  # take into account special tokens
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
@@ -160,14 +163,6 @@ class RobertaTokenizer(PreTrainedTokenizer):
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
         return text

-    @property
-    def max_len_single_sentence(self):
-        return self.max_len - 2  # take into account special tokens
-
-    @property
-    def max_len_sentences_pair(self):
-        return self.max_len - 4  # take into account special tokens
-
     def add_special_tokens_single_sentence(self, token_ids):
         """
         Adds special tokens to a sequence for sequence classification tasks.
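
The RoBERTa budgets differ from BERT's because a sentence pair is encoded as `<s> A </s></s> B </s>` (four special tokens) while a single sequence is `<s> A </s>` (two). A quick arithmetic check of those budgets, as a self-contained sketch with an assumed `max_len` of 512:

```python
max_len = 512

# Single sequence: <s> + tokens + </s>
single = 1 + (max_len - 2) + 1

# Pair: <s> + A + </s></s> + B + </s>; the A/B split is arbitrary here
len_a = 100
pair = 1 + len_a + 2 + ((max_len - 4) - len_a) + 1

assert single == max_len
assert pair == max_len
```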

pytorch_transformers/tokenization_transfo_xl.py

@@ -73,6 +73,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
                                                  additional_special_tokens=additional_special_tokens,
                                                  **kwargs)
+        self.max_len_single_sentence = self.max_len  # no default special tokens - you can update this value if you add special tokens
+        self.max_len_sentences_pair = self.max_len  # no default special tokens - you can update this value if you add special tokens
         if never_split is None:
             never_split = self.all_special_tokens
         if special is None:

pytorch_transformers/tokenization_utils.py

@@ -67,14 +67,6 @@ class PreTrainedTokenizer(object):
                                  "pad_token", "cls_token", "mask_token",
                                  "additional_special_tokens"]

-    @property
-    def max_len_single_sentence(self):
-        return self.max_len  # Default to max_len but can be smaller in specific tokenizers to take into account special tokens
-
-    @property
-    def max_len_sentences_pair(self):
-        return self.max_len  # Default to max_len but can be smaller in specific tokenizers to take into account special tokens
-
     @property
     def bos_token(self):
         """ Beginning of sentence token (string). Log an error if used while not having been set. """
@@ -174,6 +166,9 @@ class PreTrainedTokenizer(object):
         self._additional_special_tokens = []

         self.max_len = max_len if max_len is not None else int(1e12)
+        self.max_len_single_sentence = self.max_len
+        self.max_len_sentences_pair = self.max_len
         self.added_tokens_encoder = {}
         self.added_tokens_decoder = {}
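
The base class now only seeds both attributes with `max_len`; each subclass that inserts special tokens overwrites them right after its `super().__init__` call, as in the file diffs above. A sketch of how a hypothetical custom tokenizer would follow the same pattern (the class name, vocabulary handling, and the -2/-3 offsets are illustrative assumptions, not part of this commit):

```python
from pytorch_transformers import PreTrainedTokenizer

class MyTokenizer(PreTrainedTokenizer):
    # Hypothetical, partial subclass for illustration only.
    def __init__(self, vocab_file, **kwargs):
        super(MyTokenizer, self).__init__(**kwargs)
        # Override the base-class defaults (both equal to self.max_len)
        # with this tokenizer's own special-token overhead.
        self.max_len_single_sentence = self.max_len - 2  # e.g. <cls> ... <sep>
        self.max_len_sentences_pair = self.max_len - 3   # e.g. <cls> A <sep> B <sep>
        self.vocab_file = vocab_file
```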

pytorch_transformers/tokenization_xlm.py

@@ -122,6 +122,10 @@ class XLMTokenizer(PreTrainedTokenizer):
                                            cls_token=cls_token, mask_token=mask_token,
                                            additional_special_tokens=additional_special_tokens,
                                            **kwargs)
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
         try:
             import ftfy
             from spacy.lang.en import English
@@ -215,14 +219,6 @@ class XLMTokenizer(PreTrainedTokenizer):
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         return out_string

-    @property
-    def max_len_single_sentence(self):
-        return self.max_len - 2  # take into account special tokens
-
-    @property
-    def max_len_sentences_pair(self):
-        return self.max_len - 3  # take into account special tokens
-
     def add_special_tokens_single_sentence(self, token_ids):
         """
         Adds special tokens to a sequence for sequence classification tasks.

pytorch_transformers/tokenization_xlnet.py

@@ -71,6 +71,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
                                              pad_token=pad_token, cls_token=cls_token,
                                              mask_token=mask_token, additional_special_tokens=
                                              additional_special_tokens, **kwargs)
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
         try:
             import sentencepiece as spm
         except ImportError:
@@ -177,14 +181,6 @@ class XLNetTokenizer(PreTrainedTokenizer):
         out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
         return out_string

-    @property
-    def max_len_single_sentence(self):
-        return self.max_len - 2  # take into account special tokens
-
-    @property
-    def max_len_sentences_pair(self):
-        return self.max_len - 3  # take into account special tokens
-
     def add_special_tokens_single_sentence(self, token_ids):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.