From 8b63a01d9511e68fca16b60d496a617a270d6c43 Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Fri, 17 Apr 2020 11:28:55 -0400
Subject: [PATCH] XLM tokenizer should encode with bos token (#3791)

* XLM tokenizer should encode with bos token

* Update tests
---
 src/transformers/tokenization_xlm.py | 9 +++++----
 tests/test_tokenization_xlm.py       | 4 ++--
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/transformers/tokenization_xlm.py b/src/transformers/tokenization_xlm.py
index 5afe1d29f0b..a9b79cec828 100644
--- a/src/transformers/tokenization_xlm.py
+++ b/src/transformers/tokenization_xlm.py
@@ -873,11 +873,12 @@ class XLMTokenizer(PreTrainedTokenizer):
             :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
 
         """
-        if token_ids_1 is None:
-            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        bos = [self.bos_token_id]
         sep = [self.sep_token_id]
-        cls = [self.cls_token_id]
-        return cls + token_ids_0 + sep + token_ids_1 + sep
+
+        if token_ids_1 is None:
+            return bos + token_ids_0 + sep
+        return bos + token_ids_0 + sep + token_ids_1 + sep
 
     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
diff --git a/tests/test_tokenization_xlm.py b/tests/test_tokenization_xlm.py
index 5fd7379388b..43123554273 100644
--- a/tests/test_tokenization_xlm.py
+++ b/tests/test_tokenization_xlm.py
@@ -96,5 +96,5 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
         encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
 
-        assert encoded_sentence == [1] + text + [1]
-        assert encoded_pair == [1] + text + [1] + text_2 + [1]
+        assert encoded_sentence == [0] + text + [1]
+        assert encoded_pair == [0] + text + [1] + text_2 + [1]