Merge pull request #2065 from huggingface/fixing-camembert

Fixing camembert tokenization
This commit is contained in:
Thomas Wolf 2019-12-05 13:45:44 +01:00 committed by GitHub
commit af077b15e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -167,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
model = XLNetModel(config)
model.eval()

View File

@ -51,7 +51,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
additional_special_tokens=['<s>NOTUSED', '<s>NOTUSED'], **kwargs):
additional_special_tokens=['<s>NOTUSED', '</s>NOTUSED'], **kwargs):
super(CamembertTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
mask_token=mask_token, additional_special_tokens=additional_special_tokens,
@ -125,7 +125,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
@property
def vocab_size(self):
return self.fairseq_offset + len(self.sp_model)
return len(self.fairseq_tokens_to_ids) + len(self.sp_model)
def _tokenize(self, text):
return self.sp_model.EncodeAsPieces(text)
@ -134,6 +134,9 @@ class CamembertTokenizer(PreTrainedTokenizer):
""" Converts a token (str/unicode) in an id using the vocab. """
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
elif self.sp_model.PieceToId(token) == 0:
# Convert sentence piece unk token to fairseq unk token index
return self.unk_token_id
return self.fairseq_offset + self.sp_model.PieceToId(token)
def _convert_id_to_token(self, index):