Remove T5 dependency from mT5 model (#20949)

make mt5 independent from t5
This commit is contained in:
Sujay 2023-01-05 00:21:54 +05:30 committed by GitHub
parent 9dcc881fa6
commit 15e17c99f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1860 additions and 26 deletions

View File

@ -1806,7 +1806,9 @@ else:
"MPNetPreTrainedModel",
]
)
_import_structure["models.mt5"].extend(["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"])
_import_structure["models.mt5"].extend(
["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model", "MT5PreTrainedModel"]
)
_import_structure["models.mvp"].extend(
[
"MVP_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -4922,7 +4924,7 @@ if TYPE_CHECKING:
MPNetModel,
MPNetPreTrainedModel,
)
from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model, MT5PreTrainedModel
from .models.mvp import (
MVP_PRETRAINED_MODEL_ARCHIVE_LIST,
MvpForCausalLM,

View File

@ -51,7 +51,13 @@ try:
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]
_import_structure["modeling_mt5"] = [
"MT5EncoderModel",
"MT5ForConditionalGeneration",
"MT5Model",
"MT5PreTrainedModel",
"MT5Stack",
]
try:
if not is_tf_available():
@ -79,7 +85,7 @@ if TYPE_CHECKING:
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model, MT5PreTrainedModel, MT5Stack
try:
if not is_tf_available():

File diff suppressed because it is too large Load Diff

View File

@ -4003,6 +4003,13 @@ class MT5Model(metaclass=DummyObject):
requires_backends(self, ["torch"])
class MT5PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
MVP_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@ -40,6 +40,7 @@ PRIVATE_MODELS = [
"LongT5Stack",
"RealmBertModel",
"T5Stack",
"MT5Stack",
"SwitchTransformersStack",
"TFDPRSpanPredictor",
"MaskFormerSwinModel",