fix t5 token type ids (#8437)

This commit is contained in:
Patrick von Platen 2020-11-10 20:21:54 +01:00 committed by GitHub
parent 9fd1f56236
commit 70708cca1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 0 deletions

View File

@ -187,6 +187,28 @@ class T5Tokenizer(PreTrainedTokenizer):
else:
return token_ids + [self.eos_token_id]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of zeros.
"""
eos = [self.eos_token_id]
if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:

View File

@ -191,6 +191,28 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
token_ids_1 = token_ids_1 + [self.eos_token_id]
return self.prefix_tokens + token_ids_0 + token_ids_1
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
Returns:
:obj:`List[int]`: List of zeros.
"""
eos = [self.eos_token_id]
if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
def prepare_seq2seq_batch(
self,

View File

@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(expected_src_tokens, src_ids)
self.assertEqual(expected_tgt_tokens, tgt_ids)
def test_token_type_ids(self):
src_text_1 = ["A first paragraph for summarization."]
src_text_2 = ["A second paragraph for summarization."]
fast_token_type_ids = self.t5_base_tokenizer_fast(
src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
).token_type_ids
slow_token_type_ids = self.t5_base_tokenizer(
src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
).token_type_ids
self.assertEqual(slow_token_type_ids, fast_token_type_ids)
self.assertEqual(len(slow_token_type_ids[0]), 18)
def test_fast_and_slow_same_result(self):
src_text = "<pad> Today is <unk> nice day </s>"
tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]