Fix MT5 init (#12591)

This commit is contained in:
Sylvain Gugger 2021-07-08 11:12:18 -04:00 committed by GitHub
parent 4da568c152
commit 75e63dbf70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,24 +29,22 @@ from ...file_utils import (
if is_sentencepiece_available():
from ..t5.tokenization_t5 import T5Tokenizer
else:
from ...utils.dummy_sentencepiece_objects import T5Tokenizer
MT5Tokenizer = T5Tokenizer
MT5Tokenizer = T5Tokenizer
if is_tokenizers_available():
from ..t5.tokenization_t5_fast import T5TokenizerFast
else:
from ...utils.dummy_tokenizers_objects import T5TokenizerFast
MT5TokenizerFast = T5TokenizerFast
MT5TokenizerFast = T5TokenizerFast
_import_structure = {
"configuration_mt5": ["MT5Config"],
}
if is_sentencepiece_available():
_import_structure["."] = ["T5Tokenizer"] # Fake to get the same objects in both side.
if is_tokenizers_available():
_import_structure["."] = ["T5TokenizerFast"] # Fake to get the same objects in both side.
if is_torch_available():
_import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]
@ -57,16 +55,6 @@ if is_tf_available():
if TYPE_CHECKING:
from .configuration_mt5 import MT5Config
if is_sentencepiece_available():
from ..t5.tokenization_t5 import T5Tokenizer
MT5Tokenizer = T5Tokenizer
if is_tokenizers_available():
from ..t5.tokenization_t5_fast import T5TokenizerFast
MT5TokenizerFast = T5TokenizerFast
if is_torch_available():
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
@ -76,20 +64,7 @@ if TYPE_CHECKING:
else:
import sys
class _MT5LazyModule(_LazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
def __getattr__(self, name):
if name == "MT5Tokenizer":
return MT5Tokenizer
elif name == "MT5TokenizerFast":
return MT5TokenizerFast
else:
return super().__getattr__(name)
sys.modules[__name__] = _MT5LazyModule(
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,