mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[TokenizationRoformerFast
] Fix the save and loading (#28527)
* cleanup * add a test * update the test * style * revert part that allows to pickle the tokenizer
This commit is contained in:
parent
716df5fb7e
commit
96d0883103
@ -122,15 +122,19 @@ class RoFormerTokenizerFast(PreTrainedTokenizerFast):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||
normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||
if (
|
||||
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
||||
normalizer_state.get("lowercase", do_lower_case) != do_lower_case
|
||||
or normalizer_state.get("strip_accents", strip_accents) != strip_accents
|
||||
):
|
||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
||||
pre_tok_state["lowercase"] = do_lower_case
|
||||
pre_tok_state["strip_accents"] = strip_accents
|
||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
||||
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
|
||||
normalizer_state["lowercase"] = do_lower_case
|
||||
normalizer_state["strip_accents"] = strip_accents
|
||||
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
|
||||
|
||||
# Make sure we correctly set the custom PreTokenizer
|
||||
vocab = self.backend_tokenizer.get_vocab()
|
||||
self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
|
||||
@ -71,6 +72,12 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_training_new_tokenizer_with_special_tokens_change(self):
|
||||
pass
|
||||
|
||||
# can't serialise custom PreTokenizer
|
||||
def test_save_slow_from_fast_and_reload_fast(self):
|
||||
pass
|
||||
for cls in [RoFormerTokenizer, RoFormerTokenizerFast]:
|
||||
original = cls.from_pretrained("alchemab/antiberta2")
|
||||
self.assertEqual(original.encode("生活的真谛是"), [1, 4, 4, 4, 4, 4, 4, 2])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
original.save_pretrained(tmp_dir)
|
||||
new = cls.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new.encode("生活的真谛是"), [1, 4, 4, 4, 4, 4, 4, 2])
|
||||
|
Loading…
Reference in New Issue
Block a user