mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
fix t5 token type ids (#8437)
This commit is contained in:
parent
9fd1f56236
commit
70708cca1a
@ -187,6 +187,28 @@ class T5Tokenizer(PreTrainedTokenizer):
|
|||||||
else:
|
else:
|
||||||
return token_ids + [self.eos_token_id]
|
return token_ids + [self.eos_token_id]
|
||||||
|
|
||||||
|
def create_token_type_ids_from_sequences(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
|
||||||
|
use of token type ids, therefore a list of zeros is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
List of IDs.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: List of zeros.
|
||||||
|
"""
|
||||||
|
eos = [self.eos_token_id]
|
||||||
|
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return len(token_ids_0 + eos) * [0]
|
||||||
|
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||||
|
|
||||||
def build_inputs_with_special_tokens(
|
def build_inputs_with_special_tokens(
|
||||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
|
@ -191,6 +191,28 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
|||||||
token_ids_1 = token_ids_1 + [self.eos_token_id]
|
token_ids_1 = token_ids_1 + [self.eos_token_id]
|
||||||
return self.prefix_tokens + token_ids_0 + token_ids_1
|
return self.prefix_tokens + token_ids_0 + token_ids_1
|
||||||
|
|
||||||
|
def create_token_type_ids_from_sequences(
|
||||||
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
|
||||||
|
use of token type ids, therefore a list of zeros is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids_0 (:obj:`List[int]`):
|
||||||
|
List of IDs.
|
||||||
|
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||||
|
Optional second list of IDs for sequence pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`List[int]`: List of zeros.
|
||||||
|
"""
|
||||||
|
eos = [self.eos_token_id]
|
||||||
|
|
||||||
|
if token_ids_1 is None:
|
||||||
|
return len(token_ids_0 + eos) * [0]
|
||||||
|
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||||
|
|
||||||
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
|
||||||
def prepare_seq2seq_batch(
|
def prepare_seq2seq_batch(
|
||||||
self,
|
self,
|
||||||
|
@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(expected_src_tokens, src_ids)
|
self.assertEqual(expected_src_tokens, src_ids)
|
||||||
self.assertEqual(expected_tgt_tokens, tgt_ids)
|
self.assertEqual(expected_tgt_tokens, tgt_ids)
|
||||||
|
|
||||||
|
def test_token_type_ids(self):
|
||||||
|
src_text_1 = ["A first paragraph for summarization."]
|
||||||
|
src_text_2 = ["A second paragraph for summarization."]
|
||||||
|
|
||||||
|
fast_token_type_ids = self.t5_base_tokenizer_fast(
|
||||||
|
src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
|
||||||
|
).token_type_ids
|
||||||
|
slow_token_type_ids = self.t5_base_tokenizer(
|
||||||
|
src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
|
||||||
|
).token_type_ids
|
||||||
|
|
||||||
|
self.assertEqual(slow_token_type_ids, fast_token_type_ids)
|
||||||
|
self.assertEqual(len(slow_token_type_ids[0]), 18)
|
||||||
|
|
||||||
def test_fast_and_slow_same_result(self):
|
def test_fast_and_slow_same_result(self):
|
||||||
src_text = "<pad> Today is <unk> nice day </s>"
|
src_text = "<pad> Today is <unk> nice day </s>"
|
||||||
tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
|
tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
|
||||||
|
Loading…
Reference in New Issue
Block a user