fix t5 token type ids (#8437)
commit 70708cca1a
parent 9fd1f56236
@@ -187,6 +187,28 @@ class T5Tokenizer(PreTrainedTokenizer):
         else:
             return token_ids + [self.eos_token_id]
 
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
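A quick sanity check of the new slow-tokenizer method (not part of the commit): a minimal sketch, assuming the public t5-small checkpoint is reachable. The mask should be all zeros and match the length of input_ids, including the appended EOS tokens.

# Sketch: T5's slow tokenizer should return one token type id per input
# token (including the two appended </s> tokens), all of them zero.
from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")
enc = tok("first sentence", "second sentence", return_token_type_ids=True)

assert len(enc["token_type_ids"]) == len(enc["input_ids"])
assert enc["token_type_ids"] == [0] * len(enc["input_ids"])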
@@ -191,6 +191,28 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
             token_ids_1 = token_ids_1 + [self.eos_token_id]
             return self.prefix_tokens + token_ids_0 + token_ids_1
 
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+        use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: List of zeros.
+        """
+        eos = [self.eos_token_id]
+
+        if token_ids_1 is None:
+            return len(token_ids_0 + eos) * [0]
+        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
+
     @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
     def prepare_seq2seq_batch(
         self,
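The fast tokenizer gains an identical override. A minimal sketch (again assuming t5-small downloads) calling the new method directly on raw id lists, which is the path exercised when a pair of sequences is encoded:

# Sketch: the mask covers len(ids_0) + 1 EOS + len(ids_1) + 1 EOS positions,
# every entry zero, since T5 does not use token type ids.
from transformers import T5TokenizerFast

tok = T5TokenizerFast.from_pretrained("t5-small")
ids_0 = tok.convert_tokens_to_ids(tok.tokenize("hello world"))
ids_1 = tok.convert_tokens_to_ids(tok.tokenize("goodbye"))

mask = tok.create_token_type_ids_from_sequences(ids_0, ids_1)
assert mask == [0] * (len(ids_0) + len(ids_1) + 2)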
@@ -223,6 +223,20 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(expected_src_tokens, src_ids)
         self.assertEqual(expected_tgt_tokens, tgt_ids)
 
+    def test_token_type_ids(self):
+        src_text_1 = ["A first paragraph for summarization."]
+        src_text_2 = ["A second paragraph for summarization."]
+
+        fast_token_type_ids = self.t5_base_tokenizer_fast(
+            src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
+        ).token_type_ids
+        slow_token_type_ids = self.t5_base_tokenizer(
+            src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
+        ).token_type_ids
+
+        self.assertEqual(slow_token_type_ids, fast_token_type_ids)
+        self.assertEqual(len(slow_token_type_ids[0]), 18)
+
     def test_fast_and_slow_same_result(self):
         src_text = "<pad> Today is <unk> nice day </s>"
         tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
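The test uses tokenizer fixtures (self.t5_base_tokenizer and self.t5_base_tokenizer_fast) defined elsewhere in the test class. A standalone sketch of the same parity check, using the public t5-base checkpoint instead of the fixtures:

# Sketch: slow and fast T5 tokenizers should produce identical, all-zero
# token type ids for the same sequence pair.
from transformers import T5Tokenizer, T5TokenizerFast

slow = T5Tokenizer.from_pretrained("t5-base")
fast = T5TokenizerFast.from_pretrained("t5-base")

texts = (["A first paragraph for summarization."],
         ["A second paragraph for summarization."])

enc_slow = slow(*texts, return_token_type_ids=True)
enc_fast = fast(*texts, return_token_type_ids=True)

assert enc_slow["token_type_ids"] == enc_fast["token_type_ids"]
assert all(i == 0 for i in enc_slow["token_type_ids"][0])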