From 32e0db8a693d32963e4a0da83bc3ad87bc820835 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Thu, 9 Jan 2025 17:46:50 +0100
Subject: [PATCH] [`tokenizers`] Ensure that add_prefix_space is propagated to
 backend_tokenizer.pre_tokenizer (#35593)

* Ensure that add_prefix_space is propagated to backend_tokenizer.pre_tokenizer in PreTrainedTokenizerFast, rather than relying on subclasses to take care of this.

* Simplify setting self.add_prefix_space, ensure pre_tok exists

* Wrap in try-except to catch 'Custom PreTokenizer cannot be serialized'

  https://github.com/huggingface/tokenizers/blob/862d1a346a99183017b1eb5ad1aa3133b466784f/bindings/python/src/pre_tokenizers.rs#L672 produces the exception. These exceptions are triggered by the roformer tests, as RoFormerTokenizerFast uses a custom PreTokenizer.

* Propagate add_prefix_space in T5TokenizerFast to superclass
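A minimal sketch of the intended user-facing behavior (the checkpoint name is
only illustrative; any fast tokenizer backed by a ByteLevel pre-tokenizer,
such as RoBERTa, should behave the same way):

```python
from transformers import AutoTokenizer

# `add_prefix_space` is now propagated to the backend pre-tokenizer by
# PreTrainedTokenizerFast itself, so both flags below stay in sync.
tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)

print(tokenizer.add_prefix_space)                                  # True
print(tokenizer.backend_tokenizer.pre_tokenizer.add_prefix_space)  # True

# With the prefix space, a sentence-initial word is tokenized the same
# way as the identical word appearing mid-sentence.
print(tokenizer.tokenize("hello"))  # ['Ġhello'] instead of ['hello']
```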
---
 .../models/bart/tokenization_bart_fast.py           | 10 +---------
 .../blenderbot/tokenization_blenderbot_fast.py      | 10 +---------
 .../models/codegen/tokenization_codegen_fast.py     | 10 ----------
 .../models/deberta/tokenization_deberta_fast.py     | 11 -----------
 .../models/gpt2/tokenization_gpt2_fast.py           | 11 -----------
 .../models/gpt_neox/tokenization_gpt_neox_fast.py   | 11 +----------
 .../layoutlmv3/tokenization_layoutlmv3_fast.py      | 10 +---------
 .../models/led/tokenization_led_fast.py             | 10 +---------
 .../longformer/tokenization_longformer_fast.py      | 10 +---------
 .../models/markuplm/tokenization_markuplm_fast.py   | 10 +---------
 .../models/mvp/tokenization_mvp_fast.py             | 10 +---------
 .../models/roberta/tokenization_roberta_fast.py     | 10 +---------
 src/transformers/models/t5/tokenization_t5_fast.py  |  1 +
 .../models/whisper/tokenization_whisper_fast.py     |  9 +--------
 src/transformers/tokenization_utils_fast.py         | 13 +++++++++++++
 tests/test_tokenization_common.py                   | 12 ++++++++++++
 16 files changed, 36 insertions(+), 122 deletions(-)

diff --git a/src/transformers/models/bart/tokenization_bart_fast.py b/src/transformers/models/bart/tokenization_bart_fast.py
index 4586ab4797e..f7c92b9d22c 100644
--- a/src/transformers/models/bart/tokenization_bart_fast.py
+++ b/src/transformers/models/bart/tokenization_bart_fast.py
@@ -16,7 +16,7 @@
 import json
 from typing import List, Optional, Tuple
 
-from tokenizers import pre_tokenizers, processors
+from tokenizers import processors
 
 from ...tokenization_utils_base import AddedToken, BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -157,14 +157,6 @@ class BartTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
 
-        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
-        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
-            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
-            pre_tok_state["add_prefix_space"] = add_prefix_space
-            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
-
-        self.add_prefix_space = add_prefix_space
-
         # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
         tokenizer_component = "post_processor"
         tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py b/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
index f649246517d..8667fe76349 100644
--- a/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
+++ b/src/transformers/models/blenderbot/tokenization_blenderbot_fast.py
@@ -17,7 +17,7 @@
 import json
 from typing import List, Optional, Tuple
 
-from tokenizers import pre_tokenizers, processors
+from tokenizers import processors
 
 from ...tokenization_utils_base import AddedToken, BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -160,14 +160,6 @@ class BlenderbotTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
 
-        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
-        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
-            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
-            pre_tok_state["add_prefix_space"] = add_prefix_space
-            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
-
-        self.add_prefix_space = add_prefix_space
-
         tokenizer_component = "post_processor"
         tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
         if tokenizer_component_instance:
diff --git a/src/transformers/models/codegen/tokenization_codegen_fast.py b/src/transformers/models/codegen/tokenization_codegen_fast.py
index fcfe1d2795b..86782cf8070 100644
--- a/src/transformers/models/codegen/tokenization_codegen_fast.py
+++ b/src/transformers/models/codegen/tokenization_codegen_fast.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
 
-import json
 import re
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
@@ -29,7 +28,6 @@ if TYPE_CHECKING:
 if is_tf_available():
     import tensorflow as tf
 
-from tokenizers import pre_tokenizers
 
 from ...tokenization_utils_base import BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -137,14 +135,6 @@ class CodeGenTokenizerFast(PreTrainedTokenizerFast):
                 " so that the fast tokenizer works correctly."
             )
 
-        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
-        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
-            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
-            pre_tok_state["add_prefix_space"] = add_prefix_space
-            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
-
-        self.add_prefix_space = add_prefix_space
-
     def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
         is_split_into_words = kwargs.get("is_split_into_words", False)
         assert self.add_prefix_space or not is_split_into_words, (
diff --git a/src/transformers/models/deberta/tokenization_deberta_fast.py b/src/transformers/models/deberta/tokenization_deberta_fast.py
index 39c64d90e53..368f29b522b 100644
--- a/src/transformers/models/deberta/tokenization_deberta_fast.py
+++ b/src/transformers/models/deberta/tokenization_deberta_fast.py
@@ -14,11 +14,8 @@
 # limitations under the License.
"""Fast Tokenization class for model DeBERTa.""" -import json from typing import List, Optional, Tuple -from tokenizers import pre_tokenizers - from ...tokenization_utils_base import AddedToken, BatchEncoding from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging @@ -132,14 +129,6 @@ class DebertaTokenizerFast(PreTrainedTokenizerFast): ) self.add_bos_token = kwargs.pop("add_bos_token", False) - pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) - if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space: - pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type")) - pre_tok_state["add_prefix_space"] = add_prefix_space - self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) - - self.add_prefix_space = add_prefix_space - @property def mask_token(self) -> str: """ diff --git a/src/transformers/models/gpt2/tokenization_gpt2_fast.py b/src/transformers/models/gpt2/tokenization_gpt2_fast.py index 07b48faad4e..81e67a818de 100644 --- a/src/transformers/models/gpt2/tokenization_gpt2_fast.py +++ b/src/transformers/models/gpt2/tokenization_gpt2_fast.py @@ -14,11 +14,8 @@ # limitations under the License. """Tokenization classes for OpenAI GPT.""" -import json from typing import Optional, Tuple -from tokenizers import pre_tokenizers - from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging @@ -109,14 +106,6 @@ class GPT2TokenizerFast(PreTrainedTokenizerFast): self.add_bos_token = kwargs.pop("add_bos_token", False) - pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) - if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space: - pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type")) - pre_tok_state["add_prefix_space"] = add_prefix_space - self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) - - self.add_prefix_space = add_prefix_space - def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding: is_split_into_words = kwargs.get("is_split_into_words", False) assert self.add_prefix_space or not is_split_into_words, ( diff --git a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py index 1df53f3776d..d2ea1c3984f 100644 --- a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py +++ b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py @@ -14,10 +14,9 @@ # limitations under the License. 
"""Tokenization classes for GPTNeoX.""" -import json from typing import List, Optional, Tuple -from tokenizers import pre_tokenizers, processors +from tokenizers import processors from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging @@ -122,14 +121,6 @@ class GPTNeoXTokenizerFast(PreTrainedTokenizerFast): self._add_eos_token = add_eos_token self.update_post_processor() - pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) - if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space: - pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type")) - pre_tok_state["add_prefix_space"] = add_prefix_space - self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) - - self.add_prefix_space = add_prefix_space - @property def add_eos_token(self): return self._add_eos_token diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py index 934f04937b0..ff67d233ffe 100644 --- a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py +++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py @@ -20,7 +20,7 @@ and _encode_plus, in which the Rust tokenizer is used. import json from typing import Dict, List, Optional, Tuple, Union -from tokenizers import pre_tokenizers, processors +from tokenizers import processors from ...tokenization_utils_base import ( BatchEncoding, @@ -162,14 +162,6 @@ class LayoutLMv3TokenizerFast(PreTrainedTokenizerFast): **kwargs, ) - pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) - if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space: - pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type")) - pre_tok_state["add_prefix_space"] = add_prefix_space - self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) - - self.add_prefix_space = add_prefix_space - tokenizer_component = "post_processor" tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) if tokenizer_component_instance: diff --git a/src/transformers/models/led/tokenization_led_fast.py b/src/transformers/models/led/tokenization_led_fast.py index 5b36f513f5a..06e959e8754 100644 --- a/src/transformers/models/led/tokenization_led_fast.py +++ b/src/transformers/models/led/tokenization_led_fast.py @@ -17,7 +17,7 @@ import json from typing import Dict, List, Optional, Tuple, Union -from tokenizers import pre_tokenizers, processors +from tokenizers import processors from ...tokenization_utils_base import AddedToken, BatchEncoding, EncodedInput from ...tokenization_utils_fast import PreTrainedTokenizerFast @@ -157,14 +157,6 @@ class LEDTokenizerFast(PreTrainedTokenizerFast): **kwargs, ) - pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__()) - if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space: - pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type")) - pre_tok_state["add_prefix_space"] = add_prefix_space - self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state) - - self.add_prefix_space = add_prefix_space - # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__` tokenizer_component = "post_processor" tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) diff --git a/src/transformers/models/longformer/tokenization_longformer_fast.py 
index 3d3ca97a6f6..b8111b3d8a2 100644
--- a/src/transformers/models/longformer/tokenization_longformer_fast.py
+++ b/src/transformers/models/longformer/tokenization_longformer_fast.py
@@ -17,7 +17,7 @@
 import json
 from typing import List, Optional, Tuple
 
-from tokenizers import pre_tokenizers, processors
+from tokenizers import processors
 
 from ...tokenization_utils_base import AddedToken, BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -155,14 +155,6 @@ class LongformerTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
 
-        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
-        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
-            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
-            pre_tok_state["add_prefix_space"] = add_prefix_space
-            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
-
-        self.add_prefix_space = add_prefix_space
-
         tokenizer_component = "post_processor"
         tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
         if tokenizer_component_instance:
diff --git a/src/transformers/models/markuplm/tokenization_markuplm_fast.py b/src/transformers/models/markuplm/tokenization_markuplm_fast.py
index a7ef344f4e3..ec6808348ab 100644
--- a/src/transformers/models/markuplm/tokenization_markuplm_fast.py
+++ b/src/transformers/models/markuplm/tokenization_markuplm_fast.py
@@ -21,7 +21,7 @@ import json
 from functools import lru_cache
 from typing import Dict, List, Optional, Tuple, Union
 
-from tokenizers import pre_tokenizers, processors
+from tokenizers import processors
 
 from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
 from ...tokenization_utils_base import (
@@ -207,14 +207,6 @@ class MarkupLMTokenizerFast(PreTrainedTokenizerFast):
 
         self.tags_dict = tags_dict
 
-        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
-        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
-            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
-            pre_tok_state["add_prefix_space"] = add_prefix_space
-            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
-
-        self.add_prefix_space = add_prefix_space
-
         tokenizer_component = "post_processor"
         tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
         if tokenizer_component_instance:
diff --git a/src/transformers/models/mvp/tokenization_mvp_fast.py b/src/transformers/models/mvp/tokenization_mvp_fast.py
index a66b4e178e8..ae226812c83 100644
--- a/src/transformers/models/mvp/tokenization_mvp_fast.py
+++ b/src/transformers/models/mvp/tokenization_mvp_fast.py
@@ -16,7 +16,7 @@
 import json
 from typing import List, Optional, Tuple
 
-from tokenizers import pre_tokenizers, processors
+from tokenizers import processors
 
 from ...tokenization_utils_base import AddedToken, BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -160,14 +160,6 @@ class MvpTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
 
-        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
-        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
-            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
-            pre_tok_state["add_prefix_space"] = add_prefix_space
-            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
-
-        self.add_prefix_space = add_prefix_space
-
         # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
         tokenizer_component = "post_processor"
         tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
diff --git a/src/transformers/models/roberta/tokenization_roberta_fast.py b/src/transformers/models/roberta/tokenization_roberta_fast.py
index 336148f4138..cf288f4d8c7 100644
--- a/src/transformers/models/roberta/tokenization_roberta_fast.py
+++ b/src/transformers/models/roberta/tokenization_roberta_fast.py
@@ -17,7 +17,7 @@
 import json
 from typing import List, Optional, Tuple
 
-from tokenizers import pre_tokenizers, processors
+from tokenizers import processors
 
 from ...tokenization_utils_base import AddedToken, BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -154,14 +154,6 @@ class RobertaTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
 
-        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
-        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
-            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
-            pre_tok_state["add_prefix_space"] = add_prefix_space
-            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
-
-        self.add_prefix_space = add_prefix_space
-
         tokenizer_component = "post_processor"
         tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
         if tokenizer_component_instance:
diff --git a/src/transformers/models/t5/tokenization_t5_fast.py b/src/transformers/models/t5/tokenization_t5_fast.py
index 68d750aca56..8eb652728bf 100644
--- a/src/transformers/models/t5/tokenization_t5_fast.py
+++ b/src/transformers/models/t5/tokenization_t5_fast.py
@@ -124,6 +124,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
             pad_token=pad_token,
             extra_ids=extra_ids,
             additional_special_tokens=additional_special_tokens,
+            add_prefix_space=add_prefix_space,
             **kwargs,
         )
 
diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py
index 9cc0b7c530f..9a2c6525440 100644
--- a/src/transformers/models/whisper/tokenization_whisper_fast.py
+++ b/src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -22,7 +22,7 @@ from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import numpy as np
-from tokenizers import AddedToken, pre_tokenizers, processors
+from tokenizers import AddedToken, processors
 
 from ...tokenization_utils_base import BatchEncoding
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -128,19 +128,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
 
         self.add_bos_token = kwargs.pop("add_bos_token", False)
 
-        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
-        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
-            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
-            pre_tok_state["add_prefix_space"] = add_prefix_space
-            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
-
         if normalizer_file is not None:
             with open(normalizer_file, encoding="utf-8") as vocab_handle:
                 self.english_spelling_normalizer = json.load(vocab_handle)
         else:
             self.english_spelling_normalizer = None
 
-        self.add_prefix_space = add_prefix_space
         self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
 
         self.language = language
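The hoisted logic in `tokenization_utils_fast.py` below round-trips the
backend pre-tokenizer through its serialized state. A minimal standalone
sketch of that round-trip, using only the public `tokenizers` API (printed
values assume a stock ByteLevel pre-tokenizer):

```python
import json

from tokenizers import pre_tokenizers

byte_level = pre_tokenizers.ByteLevel(add_prefix_space=False)
state = json.loads(byte_level.__getstate__())
print(state["type"])              # "ByteLevel"
print(state["add_prefix_space"])  # False

# Rebuild with the flag flipped, as PreTrainedTokenizerFast now does once
# for all subclasses:
pre_tok_class = getattr(pre_tokenizers, state.pop("type"))
state["add_prefix_space"] = True
rebuilt = pre_tok_class(**state)
print(rebuilt.add_prefix_space)   # True
```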
diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py
index cc7edbd5328..925069f2c2f 100644
--- a/src/transformers/tokenization_utils_fast.py
+++ b/src/transformers/tokenization_utils_fast.py
@@ -102,6 +102,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
         from_slow = kwargs.pop("from_slow", False)
         added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
+        self.add_prefix_space = kwargs.get("add_prefix_space", False)
 
         if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
             raise ValueError(
@@ -206,6 +207,18 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
         if tokens:
             self.add_tokens(tokens)
 
+        try:
+            pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+            if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space:
+                pre_tok_class = getattr(pre_tokenizers_fast, pre_tok_state.pop("type"))
+                pre_tok_state["add_prefix_space"] = self.add_prefix_space
+                self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+        except Exception:
+            # We'll get an error if there is no pre_tokenizer, or if it's a custom pre_tokenizer that can
+            # not be serialized. In those cases, we just ignore the error as there's no pre_tokenizer
+            # for which we need to update the `add_prefix_space` attribute.
+            pass
+
     @property
     def is_fast(self) -> bool:
         return True
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index d6957757dc5..9bf90efd4b5 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -4684,3 +4684,15 @@ class TokenizerTesterMixin:
 
         with self.assertRaises(AttributeError, msg="conflicts with the method"):
             get_tokenizer_func(get_vocab=True)
+
+    @parameterized.expand([(True,), (False,)])
+    def test_rust_tokenizer_add_prefix_space(self, add_prefix_space):
+        if not self.test_rust_tokenizer:
+            self.skipTest(reason="test_rust_tokenizer is set to False")
+
+        for tokenizer, pretrained_name, _ in self.tokenizers_list:
+            fast_tokenizer = tokenizer.from_pretrained(pretrained_name, add_prefix_space=add_prefix_space)
+            self.assertEqual(fast_tokenizer.add_prefix_space, add_prefix_space)
+            # Only the ByteLevel pre-tokenizer has the `add_prefix_space` attribute, we have to ensure that it's set correctly
+            if hasattr(fast_tokenizer.backend_tokenizer.pre_tokenizer, "add_prefix_space"):
+                self.assertEqual(fast_tokenizer.backend_tokenizer.pre_tokenizer.add_prefix_space, add_prefix_space)
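A quick way to exercise the new try/except path is a tokenizer with a custom
Python pre-tokenizer, per the commit message. A sketch, assuming the RoFormer
checkpoint commonly used in the transformers docs is available:

```python
from transformers import AutoTokenizer

# RoFormerTokenizerFast installs a custom Python PreTokenizer, whose
# __getstate__ raises "Custom PreTokenizer cannot be serialized". The new
# try/except swallows that error, so construction still succeeds.
tokenizer = AutoTokenizer.from_pretrained("junnyu/roformer_chinese_base")
print(type(tokenizer).__name__)  # RoFormerTokenizerFast
```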