diff --git a/src/transformers/models/roformer/tokenization_roformer_fast.py b/src/transformers/models/roformer/tokenization_roformer_fast.py index 360b76b843d..bed5935e90f 100644 --- a/src/transformers/models/roformer/tokenization_roformer_fast.py +++ b/src/transformers/models/roformer/tokenization_roformer_fast.py @@ -122,15 +122,19 @@ class RoFormerTokenizerFast(PreTrainedTokenizerFast): **kwargs, ) - pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) if ( - pre_tok_state.get("lowercase", do_lower_case) != do_lower_case - or pre_tok_state.get("strip_accents", strip_accents) != strip_accents + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents ): - pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) - pre_tok_state["lowercase"] = do_lower_case - pre_tok_state["strip_accents"] = strip_accents - self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + # Make sure we correctly set the custom PreTokenizer + vocab = self.backend_tokenizer.get_vocab() + self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab)) self.do_lower_case = do_lower_case diff --git a/tests/models/roformer/test_tokenization_roformer.py b/tests/models/roformer/test_tokenization_roformer.py index 2d674100f02..3af411b6a80 100644 --- a/tests/models/roformer/test_tokenization_roformer.py +++ b/tests/models/roformer/test_tokenization_roformer.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import unittest from transformers import RoFormerTokenizer, RoFormerTokenizerFast @@ -71,6 +72,12 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_training_new_tokenizer_with_special_tokens_change(self): pass - # can't serialise custom PreTokenizer def test_save_slow_from_fast_and_reload_fast(self): - pass + for cls in [RoFormerTokenizer, RoFormerTokenizerFast]: + original = cls.from_pretrained("alchemab/antiberta2") + self.assertEqual(original.encode("生活的真谛是"), [1, 4, 4, 4, 4, 4, 4, 2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + original.save_pretrained(tmp_dir) + new = cls.from_pretrained(tmp_dir) + self.assertEqual(new.encode("生活的真谛是"), [1, 4, 4, 4, 4, 4, 4, 2])