[TokenizationRoformerFast] Fix the save and loading (#28527)

* cleanup

* add a test

* update the test

* style

* revert part that allows to pickle the tokenizer
This commit is contained in:
Arthur 2024-01-16 16:37:15 +01:00 committed by GitHub
parent 716df5fb7e
commit 96d0883103
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 9 deletions

View File

@ -122,15 +122,19 @@ class RoFormerTokenizerFast(PreTrainedTokenizerFast):
**kwargs,
)
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
if (
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
normalizer_state.get("lowercase", do_lower_case) != do_lower_case
or normalizer_state.get("strip_accents", strip_accents) != strip_accents
):
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
pre_tok_state["lowercase"] = do_lower_case
pre_tok_state["strip_accents"] = strip_accents
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
normalizer_state["lowercase"] = do_lower_case
normalizer_state["strip_accents"] = strip_accents
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
# Make sure we correctly set the custom PreTokenizer
vocab = self.backend_tokenizer.get_vocab()
self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
self.do_lower_case = do_lower_case

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
@ -71,6 +72,12 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_training_new_tokenizer_with_special_tokens_change(self):
pass
# can't serialise custom PreTokenizer
def test_save_slow_from_fast_and_reload_fast(self):
pass
for cls in [RoFormerTokenizer, RoFormerTokenizerFast]:
original = cls.from_pretrained("alchemab/antiberta2")
self.assertEqual(original.encode("生活的真谛是"), [1, 4, 4, 4, 4, 4, 4, 2])
with tempfile.TemporaryDirectory() as tmp_dir:
original.save_pretrained(tmp_dir)
new = cls.from_pretrained(tmp_dir)
self.assertEqual(new.encode("生活的真谛是"), [1, 4, 4, 4, 4, 4, 4, 2])