Fix GPT2 tokenizer's special_tokens_mask when used with add_bos_token=True (#19036)
commit 0efbb6e93e
parent 0e24548081
@@ -261,6 +261,38 @@ class GPT2Tokenizer(PreTrainedTokenizer):

        return output + bos_token_ids + token_ids_1

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if not self.add_bos_token:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0))
        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))

    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
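For context, here is a minimal usage sketch (not part of the commit) of the mask the new method produces when add_bos_token=True; it assumes the public "gpt2" checkpoint can be loaded with GPT2Tokenizer.from_pretrained:

from transformers import GPT2Tokenizer

# Illustrative sketch only; "gpt2" is the public checkpoint name, assumed available.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_bos_token=True)

ids_0 = tokenizer.encode("Encode this.", add_special_tokens=False)
ids_1 = tokenizer.encode("This one too please.", add_special_tokens=False)

# Single sequence: the prepended BOS token is flagged with 1, ordinary tokens with 0.
print(tokenizer.get_special_tokens_mask(ids_0))         # [1, 0, 0, ...]

# Sequence pair: a BOS token is prepended before each segment.
print(tokenizer.get_special_tokens_mask(ids_0, ids_1))  # [1, 0, ..., 1, 0, ...]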
@@ -250,3 +250,28 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    # tokenizer has no padding token
    def test_padding_different_model_input_name(self):
        pass

    def test_special_tokens_mask_input_pairs_and_bos_token(self):
        # TODO: change to self.get_tokenizers() when the fast version is implemented
        tokenizers = [self.get_tokenizer(do_lower_case=False, add_bos_token=True)]
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                sequence_0 = "Encode this."
                sequence_1 = "This one too please."
                encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
                encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
                encoded_sequence_dict = tokenizer.encode_plus(
                    sequence_0,
                    sequence_1,
                    add_special_tokens=True,
                    return_special_tokens_mask=True,
                )
                encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
                special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
                self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

                filtered_sequence = [
                    (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
                ]
                filtered_sequence = [x for x in filtered_sequence if x is not None]
                self.assertEqual(encoded_sequence, filtered_sequence)
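The check the new test performs can also be reproduced outside the test harness through encode_plus, which is the code path most callers use; this is a sketch under the same assumption that the public "gpt2" checkpoint is available:

from transformers import GPT2Tokenizer

# Sketch of the pattern verified by the new test; checkpoint name is assumed.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_bos_token=True)

enc = tokenizer.encode_plus(
    "Encode this.",
    "This one too please.",
    add_special_tokens=True,
    return_special_tokens_mask=True,
)

# Dropping every position flagged 1 removes exactly the BOS tokens that
# add_bos_token=True prepends to each segment.
no_specials = [t for t, m in zip(enc["input_ids"], enc["special_tokens_mask"]) if m == 0]
expected = tokenizer.encode("Encode this.", add_special_tokens=False) + tokenizer.encode(
    "This one too please.", add_special_tokens=False
)
assert no_specials == expected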