huggingface/transformers (https://github.com/huggingface/transformers.git)
Improve truncation_side (#14947)

* Enabling `truncation_side` for Slow and Fast tokenizers.
* Disable failing tests.
* Layout xlm.
* assert -> assertEqual.

Co-authored-by: Niels Rogge <48327001+NielsRogge@users.noreply.github.com>

This commit is contained in:
parent  8c2618e6aa
commit  d33dc7966a
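
In short: tokenizers gain a `truncation_side` attribute (defaulting to "right") next to the existing `padding_side`, so over-long inputs can be trimmed from the start or from the end. A minimal usage sketch, assuming any pretrained checkpoint (the model name below is only an example):

from transformers import AutoTokenizer

# Example checkpoint; any pretrained tokenizer exposes the same attribute.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Default is "right"; with "left", tokens are dropped from the start instead of the end.
tokenizer.truncation_side = "left"

ids = tokenizer.encode("a fairly long example sentence to truncate", max_length=5, truncation=True)
print(len(ids))  # 5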
@@ -1437,6 +1437,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
     # to make sure `tokenizer.pad(...)` works correctly
     model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"]
     padding_side: str = "right"
+    truncation_side: str = "right"
     slow_tokenizer_class = None
 
     def __init__(self, **kwargs):
@@ -1514,7 +1515,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         return (
             f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}', "
             f"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast}, "
-            f"padding_side='{self.padding_side}', special_tokens={self.special_tokens_map_extended})"
+            f"padding_side='{self.padding_side}', truncation_side='{self.truncation_side}', special_tokens={self.special_tokens_map_extended})"
         )
 
     def get_vocab(self) -> Dict[str, int]:
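
Continuing the sketch above, the tokenizer's repr now surfaces the truncation side alongside the padding side (the values below are illustrative, not captured output):

print(tokenizer)
# PreTrainedTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512,
#     is_fast=True, padding_side='right', truncation_side='left', special_tokens={...})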
@@ -3041,8 +3042,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         ):
             if len(ids) > num_tokens_to_remove:
                 window_len = min(len(ids), stride + num_tokens_to_remove)
-                overflowing_tokens = ids[-window_len:]
-                ids = ids[:-num_tokens_to_remove]
+                if self.truncation_side == "left":
+                    overflowing_tokens = ids[:window_len]
+                    ids = ids[num_tokens_to_remove:]
+                elif self.truncation_side == "right":
+                    overflowing_tokens = ids[-window_len:]
+                    ids = ids[:-num_tokens_to_remove]
+                else:
+                    raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.")
             else:
                 error_msg = (
                     f"We need to remove {num_tokens_to_remove} to truncate the input "
@@ -3063,14 +3071,30 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             )
             for _ in range(num_tokens_to_remove):
                 if pair_ids is None or len(ids) > len(pair_ids):
-                    ids = ids[:-1]
+                    if self.truncation_side == "right":
+                        ids = ids[:-1]
+                    elif self.truncation_side == "left":
+                        ids = ids[1:]
+                    else:
+                        raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
                 else:
-                    pair_ids = pair_ids[:-1]
+                    if self.truncation_side == "right":
+                        pair_ids = pair_ids[:-1]
+                    elif self.truncation_side == "left":
+                        pair_ids = pair_ids[1:]
+                    else:
+                        raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
         elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
             if len(pair_ids) > num_tokens_to_remove:
                 window_len = min(len(pair_ids), stride + num_tokens_to_remove)
-                overflowing_tokens = pair_ids[-window_len:]
-                pair_ids = pair_ids[:-num_tokens_to_remove]
+                if self.truncation_side == "right":
+                    overflowing_tokens = pair_ids[-window_len:]
+                    pair_ids = pair_ids[:-num_tokens_to_remove]
+                elif self.truncation_side == "left":
+                    overflowing_tokens = pair_ids[:window_len]
+                    pair_ids = pair_ids[num_tokens_to_remove:]
+                else:
+                    raise ValueError("invalid truncation strategy:" + str(self.truncation_side))
             else:
                 logger.error(
                     f"We need to remove {num_tokens_to_remove} to truncate the input "
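
To make the slicing above concrete, here is a small standalone sketch of the single-sequence case using plain lists (no tokenizer involved); the variable names mirror the patch, while the helper itself and the example ids are illustrative only:

def truncate(ids, num_tokens_to_remove, truncation_side="right", stride=0):
    # Same arithmetic as the patch: the overflow window holds the removed tokens
    # plus `stride` extra tokens of context taken from the kept side.
    if num_tokens_to_remove <= 0:
        return ids, []
    window_len = min(len(ids), stride + num_tokens_to_remove)
    if truncation_side == "left":
        return ids[num_tokens_to_remove:], ids[:window_len]
    elif truncation_side == "right":
        return ids[:-num_tokens_to_remove], ids[-window_len:]
    raise ValueError(f"invalid truncation strategy: {truncation_side}, use 'left' or 'right'.")

ids = [101, 2023, 2003, 1037, 3231, 102]
print(truncate(ids, 2, "right", stride=1))  # ([101, 2023, 2003, 1037], [1037, 3231, 102])
print(truncate(ids, 2, "left", stride=1))   # ([2003, 1037, 3231, 102], [101, 2023, 2003])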
@@ -356,6 +356,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
                 "max_length": max_length,
                 "stride": stride,
                 "strategy": truncation_strategy.value,
+                "direction": self.truncation_side,
             }
 
             # _truncation might contain more keys that the target `transformers`
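
On the fast path the setting is simply forwarded to the Rust backend: recent versions of the `tokenizers` library accept a `direction` argument on `enable_truncation`, and the dict above maps onto that call. A rough sketch of the equivalent direct call, assuming a standalone `tokenizers` install and a local tokenizer.json (both are assumptions here):

from tokenizers import Tokenizer

backend = Tokenizer.from_file("tokenizer.json")  # hypothetical path
backend.enable_truncation(max_length=128, stride=0, strategy="longest_first", direction="left")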
@@ -209,6 +209,7 @@ class AutoTokenizerTest(unittest.TestCase):
         self.assertEqual(tokenizer.vocab_size, 30000)
         self.assertEqual(tokenizer.unk_token, "[UNK]")
         self.assertEqual(tokenizer.padding_side, "right")
+        self.assertEqual(tokenizer.truncation_side, "right")
 
     def test_auto_tokenizer_from_local_folder(self):
         tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
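
As with `padding_side`, the new default can presumably also be overridden at load time through an init kwarg instead of mutating the instance afterwards; a hedged example (the checkpoint name is illustrative, and the kwarg pass-through is an assumption not shown in this diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
assert tokenizer.truncation_side == "left"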
@@ -1398,6 +1398,64 @@ class TokenizerTesterMixin:
                 assert sequence_length == padded_sequence_left_length
                 assert encoded_sequence == padded_sequence_left
 
+    def test_right_and_left_truncation(self):
+        tokenizers = self.get_tokenizers(do_lower_case=False)
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                sequence = "This is a test sequence"
+
+                # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
+                truncation_size = 3
+                tokenizer.truncation_side = "right"
+                encoded_sequence = tokenizer.encode(sequence, add_special_tokens=False)
+                sequence_length = len(encoded_sequence)
+                # Remove EOS/BOS tokens
+                truncated_sequence = tokenizer.encode(
+                    sequence, max_length=sequence_length - truncation_size, truncation=True, add_special_tokens=False
+                )
+                truncated_sequence_length = len(truncated_sequence)
+                self.assertEqual(sequence_length, truncated_sequence_length + truncation_size)
+                self.assertEqual(encoded_sequence[:-truncation_size], truncated_sequence)
+
+                # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the truncation flag set to True
+                tokenizer.truncation_side = "left"
+                sequence_length = len(encoded_sequence)
+                truncated_sequence = tokenizer.encode(
+                    sequence, max_length=sequence_length - truncation_size, truncation=True, add_special_tokens=False
+                )
+                truncated_sequence_length = len(truncated_sequence)
+                self.assertEqual(sequence_length, truncated_sequence_length + truncation_size)
+                self.assertEqual(encoded_sequence[truncation_size:], truncated_sequence)
+
+                # RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_truncation'
+                sequence_length = len(encoded_sequence)
+
+                tokenizer.truncation_side = "right"
+                truncated_sequence_right = tokenizer.encode(sequence, truncation=True, add_special_tokens=False)
+                truncated_sequence_right_length = len(truncated_sequence_right)
+                self.assertEqual(sequence_length, truncated_sequence_right_length)
+                self.assertEqual(encoded_sequence, truncated_sequence_right)
+
+                tokenizer.truncation_side = "left"
+                truncated_sequence_left = tokenizer.encode(
+                    sequence, truncation="longest_first", add_special_tokens=False
+                )
+                truncated_sequence_left_length = len(truncated_sequence_left)
+                self.assertEqual(sequence_length, truncated_sequence_left_length)
+                self.assertEqual(encoded_sequence, truncated_sequence_left)
+
+                tokenizer.truncation_side = "right"
+                truncated_sequence_right = tokenizer.encode(sequence, add_special_tokens=False)
+                truncated_sequence_right_length = len(truncated_sequence_right)
+                self.assertEqual(sequence_length, truncated_sequence_right_length)
+                self.assertEqual(encoded_sequence, truncated_sequence_right)
+
+                tokenizer.truncation_side = "left"
+                truncated_sequence_left = tokenizer.encode(sequence, truncation=False, add_special_tokens=False)
+                truncated_sequence_left_length = len(truncated_sequence_left)
+                self.assertEqual(sequence_length, truncated_sequence_left_length)
+                self.assertEqual(encoded_sequence, truncated_sequence_left)
+
     def test_padding_to_max_length(self):
         """We keep this test for backward compatibility but it should be remove when `pad_to_max_length` is deprecated."""
         tokenizers = self.get_tokenizers(do_lower_case=False)
@@ -371,6 +371,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
                 decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens)
                 self.assertIn(decoded, [output, output.lower()])
 
+    @unittest.skip("Not implemented")
+    def test_right_and_left_truncation(self):
+        pass
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
@@ -939,6 +939,10 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
         shutil.rmtree(tmpdirname)
 
+    @unittest.skip("Not implemented")
+    def test_right_and_left_truncation(self):
+        pass
+
     def test_right_and_left_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
@@ -904,6 +904,10 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
         shutil.rmtree(tmpdirname)
 
+    @unittest.skip("Not implemented")
+    def test_right_and_left_truncation(self):
+        pass
+
     def test_right_and_left_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers: