[Sentencepiece] make sure legacy does not require protobuf (#25684)

make sure legacy does not require `protobuf`
Arthur 2023-08-25 14:41:04 +02:00 committed by GitHub
parent 0770ce6cfb
commit dd8b7d28ae
3 changed files with 21 additions and 12 deletions

src/transformers/convert_slow_tokenizer.py

@@ -27,9 +27,10 @@ from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_
 from tokenizers.models import BPE, Unigram, WordPiece

 from .utils import is_protobuf_available, requires_backends
+from .utils.import_utils import PROTOBUF_IMPORT_ERROR


-def import_protobuf():
+def import_protobuf(error_message=""):
     if is_protobuf_available():
         import google.protobuf
@@ -37,7 +38,9 @@ def import_protobuf():
             from transformers.utils import sentencepiece_model_pb2
         else:
             from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
-    return sentencepiece_model_pb2
+        return sentencepiece_model_pb2
+    else:
+        raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))


 class SentencePieceExtractor:
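For context, a minimal sketch (not part of the commit) of how the new `error_message` argument surfaces in an environment where `protobuf` is not installed; the exact wording of `PROTOBUF_IMPORT_ERROR` is abbreviated here.

```python
# Sketch only: behaviour of import_protobuf() after this change when protobuf
# is NOT installed. The description string is illustrative.
from transformers.convert_slow_tokenizer import import_protobuf

try:
    # The caller passes a short description of who needs protobuf; it is
    # interpolated into PROTOBUF_IMPORT_ERROR and raised up front, instead of
    # failing later with an opaque error on sentencepiece_model_pb2.
    model_pb2 = import_protobuf("LlamaTokenizer (with `self.legacy = False`)")
except ImportError as e:
    print(e)  # explains that protobuf is required and how to install it
```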

src/transformers/models/llama/tokenization_llama.py

@@ -162,14 +162,17 @@ class LlamaTokenizer(PreTrainedTokenizer):
     # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
     def get_spm_processor(self):
         tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        if self.legacy:  # no dependency on protobuf
+            tokenizer.Load(self.vocab_file)
+            return tokenizer
+
         with open(self.vocab_file, "rb") as f:
             sp_model = f.read()
-            model_pb2 = import_protobuf()
+            model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
             model = model_pb2.ModelProto.FromString(sp_model)
-            if not self.legacy:
-                normalizer_spec = model_pb2.NormalizerSpec()
-                normalizer_spec.add_dummy_prefix = False
-                model.normalizer_spec.MergeFrom(normalizer_spec)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
             sp_model = model.SerializeToString()
             tokenizer.LoadFromSerializedProto(sp_model)
         return tokenizer
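In practice this means a tokenizer created with `legacy=True` loads its SentencePiece model directly and never touches protobuf. A minimal usage sketch, assuming a real SentencePiece model file stands in for the placeholder path:

```python
# Usage sketch, not part of the commit. "path/to/tokenizer.model" is a
# placeholder for an actual SentencePiece model file.
from transformers import LlamaTokenizer

# Legacy mode: loaded via sentencepiece only, no protobuf dependency.
tok = LlamaTokenizer(vocab_file="path/to/tokenizer.model", legacy=True)

# Non-legacy mode: needs protobuf; if it is missing, the new import_protobuf()
# error message above explains why and points to the installation instructions.
tok = LlamaTokenizer(vocab_file="path/to/tokenizer.model", legacy=False)
```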

src/transformers/models/t5/tokenization_t5.py

@@ -195,14 +195,17 @@ class T5Tokenizer(PreTrainedTokenizer):

     def get_spm_processor(self):
         tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        if self.legacy:  # no dependency on protobuf
+            tokenizer.Load(self.vocab_file)
+            return tokenizer
+
         with open(self.vocab_file, "rb") as f:
             sp_model = f.read()
-            model_pb2 = import_protobuf()
+            model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
             model = model_pb2.ModelProto.FromString(sp_model)
-            if not self.legacy:
-                normalizer_spec = model_pb2.NormalizerSpec()
-                normalizer_spec.add_dummy_prefix = False
-                model.normalizer_spec.MergeFrom(normalizer_spec)
+            normalizer_spec = model_pb2.NormalizerSpec()
+            normalizer_spec.add_dummy_prefix = False
+            model.normalizer_spec.MergeFrom(normalizer_spec)
             sp_model = model.SerializeToString()
             tokenizer.LoadFromSerializedProto(sp_model)
         return tokenizer
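For reference, a standalone sketch of what the non-legacy branch above does: parse the SentencePiece model through protobuf, disable the automatic dummy prefix, and reload the patched model. It assumes `protobuf` and `sentencepiece` are installed; `"tokenizer.model"` and the description string are placeholders.

```python
# Standalone sketch of the non-legacy path (not part of the commit).
import sentencepiece as spm
from transformers.convert_slow_tokenizer import import_protobuf

model_pb2 = import_protobuf("this sketch")  # raises a clear ImportError without protobuf

with open("tokenizer.model", "rb") as f:
    sp_model = f.read()

# Parse the SentencePiece model, turn off the automatic dummy prefix ("▁"),
# and load the patched proto into a fresh processor.
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)

tokenizer = spm.SentencePieceProcessor()
tokenizer.LoadFromSerializedProto(model.SerializeToString())
```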