mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
LlamaTokenizer should be picklable (#24681)
* LlamaTokenizer should be picklable * make fixup
This commit is contained in:
parent
9a5d468ba0
commit
fb3b22c3b9
@ -98,12 +98,13 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["sp_model"] = None
|
||||
state["sp_model_proto"] = self.sp_model.serialized_model_proto()
|
||||
return state
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__dict__ = d
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(self.vocab_file)
|
||||
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -285,6 +286,13 @@ class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
padding=False,
|
||||
)
|
||||
|
||||
def test_picklable(self):
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
shutil.copyfile(SAMPLE_VOCAB, f.name)
|
||||
tokenizer = LlamaTokenizer(f.name, keep_accents=True)
|
||||
pickled_tokenizer = pickle.dumps(tokenizer)
|
||||
pickle.loads(pickled_tokenizer)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
Loading…
Reference in New Issue
Block a user