mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Fix convert_token_type_ids_from_sequences for fast tokenizers (#4503)
This commit is contained in:
parent
f7677e1623
commit
35df911485
@ -672,3 +672,33 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
|
||||
output += token_ids_1 + [self.sep_token_id]
|
||||
|
||||
return output
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
||||
A BERT sequence pair mask has the following format:
|
||||
|
||||
::
|
||||
|
||||
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence |
|
||||
|
||||
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
||||
sequence(s).
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
@ -343,3 +343,27 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
|
||||
return output
|
||||
|
||||
return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
||||
RoBERTa does not make use of token type ids, therefore a list of zeros is returned.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of zeros.
|
||||
|
||||
"""
|
||||
sep = [self.sep_token_id]
|
||||
cls = [self.cls_token_id]
|
||||
|
||||
if token_ids_1 is None:
|
||||
return len(cls + token_ids_0 + sep) * [0]
|
||||
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
||||
|
@ -75,6 +75,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_special_tokens_map_equal(tokenizer_r, tokenizer_p)
|
||||
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
|
||||
# TODO: enable for v3.0.0
|
||||
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
|
||||
|
||||
@ -308,6 +309,20 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assertEqual(len(tokens[key].shape), 2)
|
||||
self.assertEqual(tokens[key].shape[-1], 6)
|
||||
|
||||
def assert_create_token_type_ids(self, tokenizer_r, tokenizer_p):
|
||||
input_simple = [1, 2, 3]
|
||||
input_pair = [1, 2, 3]
|
||||
|
||||
# Generate output
|
||||
output_r = tokenizer_r.create_token_type_ids_from_sequences(input_simple)
|
||||
output_p = tokenizer_p.create_token_type_ids_from_sequences(input_simple)
|
||||
self.assertEqual(output_p, output_r)
|
||||
|
||||
# Generate pair output
|
||||
output_r = tokenizer_r.create_token_type_ids_from_sequences(input_simple, input_pair)
|
||||
output_p = tokenizer_p.create_token_type_ids_from_sequences(input_simple, input_pair)
|
||||
self.assertEqual(output_p, output_r)
|
||||
|
||||
def assert_build_inputs_with_special_tokens(self, tokenizer_r, tokenizer_p):
|
||||
# Input string
|
||||
input_simple = tokenizer_p.tokenize("This is a sample input")
|
||||
|
Loading…
Reference in New Issue
Block a user