Merge pull request #1205 from maru0kun/patch-2

Fix typo
This commit is contained in:
Thomas Wolf 2019-09-05 21:44:16 +02:00 committed by GitHub
commit 5ac8b62265
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 25 deletions

View File

@ -55,6 +55,22 @@ class CommonTestCases:
def get_input_output_texts(self):
raise NotImplementedError
def test_tokenizers_common_properties(self):
tokenizer = self.get_tokenizer()
attributes_list = ["bos_token", "eos_token", "unk_token", "sep_token",
"pad_token", "cls_token", "mask_token"]
for attr in attributes_list:
self.assertTrue(hasattr(tokenizer, attr))
self.assertTrue(hasattr(tokenizer, attr + "_id"))
self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
self.assertTrue(hasattr(tokenizer, 'additional_special_tokens_ids'))
attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder",
"added_tokens_decoder"]
for attr in attributes_list:
self.assertTrue(hasattr(tokenizer, attr))
def test_save_and_load_tokenizer(self):
# safety check on max_len default value so we are sure the test works
tokenizer = self.get_tokenizer()

View File

@ -162,58 +162,42 @@ class PreTrainedTokenizer(object):
@property
def bos_token_id(self):
""" Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
if self._bos_token is None:
logger.error("Using bos_token, but it is not set yet.")
return self.convert_tokens_to_ids(self._bos_token)
return self.convert_tokens_to_ids(self.bos_token)
@property
def eos_token_id(self):
""" Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
if self._eos_token is None:
logger.error("Using eos_token, but it is not set yet.")
return self.convert_tokens_to_ids(self._eos_token)
return self.convert_tokens_to_ids(self.eos_token)
@property
def unk_token_is(self):
def unk_token_id(self):
""" Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
if self._unk_token is None:
logger.error("Using unk_token, but it is not set yet.")
return self.convert_tokens_to_ids(self._unk_token)
return self.convert_tokens_to_ids(self.unk_token)
@property
def sep_token_id(self):
""" Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
if self._sep_token is None:
logger.error("Using sep_token, but it is not set yet.")
return self.convert_tokens_to_ids(self._sep_token)
return self.convert_tokens_to_ids(self.sep_token)
@property
def pad_token_id(self):
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
if self._pad_token is None:
logger.error("Using pad_token, but it is not set yet.")
return self.convert_tokens_to_ids(self._pad_token)
return self.convert_tokens_to_ids(self.pad_token)
@property
def cls_token_id(self):
""" Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
if self._cls_token is None:
logger.error("Using cls_token, but it is not set yet.")
return self.convert_tokens_to_ids(self._cls_token)
return self.convert_tokens_to_ids(self.cls_token)
@property
def mask_token_id(self):
""" Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
if self._mask_token is None:
logger.error("Using mask_token, but it is not set yet.")
return self.convert_tokens_to_ids(self._mask_token)
return self.convert_tokens_to_ids(self.mask_token)
@property
def additional_special_tokens_ids(self):
""" Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
if self._additional_special_tokens is None:
logger.error("Using additional_special_tokens, but it is not set yet.")
return self.convert_tokens_to_ids(self._additional_special_tokens)
return self.convert_tokens_to_ids(self.additional_special_tokens)
def __init__(self, max_len=None, **kwargs):
self._bos_token = None
@ -653,6 +637,9 @@ class PreTrainedTokenizer(object):
""" Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
(resp. a sequence of ids), using the vocabulary.
"""
if tokens is None:
return None
if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
return self._convert_token_to_id_with_added_voc(tokens)
@ -666,6 +653,9 @@ class PreTrainedTokenizer(object):
return ids
def _convert_token_to_id_with_added_voc(self, token):
if token is None:
return None
if token in self.added_tokens_encoder:
return self.added_tokens_encoder[token]
return self._convert_token_to_id(token)