Make sure custom configs work with Transformers (#15569)

* Make sure custom configs work with Transformers

* Apply code review suggestions
Sylvain Gugger 2022-02-09 10:04:44 -05:00 committed by GitHub
parent 7732d0fe7a
commit 1f60bc46f3
5 changed files with 37 additions and 6 deletions


@@ -368,7 +368,7 @@ class PretrainedConfig(PushToHubMixin):
     @property
     def name_or_path(self) -> str:
-        return self._name_or_path
+        return getattr(self, "_name_or_path", None)
 
     @name_or_path.setter
     def name_or_path(self, value):
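The getattr fallback above matters because a custom config that never calls super().__init__() will not have _name_or_path set on the instance. A minimal sketch of that situation (the MinimalConfig name is illustrative, not part of this commit):

from transformers import PretrainedConfig

class MinimalConfig(PretrainedConfig):
    model_type = "custom"

    def __init__(self, attribute=1, **kwargs):
        self.attribute = attribute
        # Deliberately skips super().__init__(**kwargs), so defaults such as
        # _name_or_path are never set on the instance.

config = MinimalConfig()
# Previously this raised AttributeError; with the fallback it returns None.
print(config.name_or_path)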


@@ -621,10 +621,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         weights instead.
         """
         output_embeddings = self.get_output_embeddings()
-        if output_embeddings is not None and self.config.tie_word_embeddings:
+        if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
             self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
 
-        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
+        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
                 self = getattr(self, self.base_model_prefix)
             self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
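The fallback values chosen here mirror PretrainedConfig's usual defaults (tie_word_embeddings=True, is_encoder_decoder=False, tie_encoder_decoder=False), so tie_weights() behaves the same whether or not the config ran that initializer. A rough sketch of the pattern, using a hypothetical bare config object:

class BareConfig:
    """Stand-in for a config that never ran PretrainedConfig.__init__."""
    pass

config = BareConfig()

# getattr falls back to the library's usual defaults when the attribute is absent.
assert getattr(config, "tie_word_embeddings", True) is True
assert getattr(config, "is_encoder_decoder", False) is False
assert getattr(config, "tie_encoder_decoder", False) is False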


@@ -59,14 +59,14 @@ from transformers.testing_utils import (
 sys.path.append(str(Path(__file__).parent.parent / "utils"))
 
-from test_module.custom_configuration import CustomConfig  # noqa E402
+from test_module.custom_configuration import CustomConfig, NoSuperInitConfig  # noqa E402
 
 
 if is_torch_available():
     import torch
     from torch import nn
 
-    from test_module.custom_modeling import CustomModel
+    from test_module.custom_modeling import CustomModel, NoSuperInitModel
     from transformers import (
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
@@ -2091,6 +2091,15 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+    def test_no_super_init_config_and_model(self):
+        config = NoSuperInitConfig(attribute=32)
+        model = NoSuperInitModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+
+            model = NoSuperInitModel.from_pretrained(tmp_dir)
 
 
 @require_torch
 @is_staging_test
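The new test only checks that the save/load round trip does not raise. A hand-run variant that also compares the reloaded weights might look like the sketch below; the weight comparison is an illustrative extension, not part of the committed test, and it assumes the repository's utils directory is on sys.path as in the test file:

import tempfile

import torch

from test_module.custom_configuration import NoSuperInitConfig
from test_module.custom_modeling import NoSuperInitModel

config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    reloaded = NoSuperInitModel.from_pretrained(tmp_dir)

# The reloaded parameters should match what was saved.
torch.testing.assert_close(model.linear.weight, reloaded.linear.weight)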


@@ -7,3 +7,10 @@ class CustomConfig(PretrainedConfig):
     def __init__(self, attribute=1, **kwargs):
         self.attribute = attribute
         super().__init__(**kwargs)
+
+
+class NoSuperInitConfig(PretrainedConfig):
+    model_type = "custom"
+
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute


@@ -2,7 +2,7 @@ import torch
 from transformers import PreTrainedModel
 
-from .custom_configuration import CustomConfig
+from .custom_configuration import CustomConfig, NoSuperInitConfig
 
 
 class CustomModel(PreTrainedModel):
@@ -18,3 +18,18 @@ class CustomModel(PreTrainedModel):
     def _init_weights(self, module):
         pass
+
+
+class NoSuperInitModel(PreTrainedModel):
+    config_class = NoSuperInitConfig
+    base_model_prefix = "custom"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.linear = torch.nn.Linear(config.attribute, config.attribute)
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def _init_weights(self, module):
+        pass
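As a quick illustration of how this toy model behaves (not part of the diff), a forward pass just applies the single linear layer defined above:

import torch

config = NoSuperInitConfig(attribute=4)
model = NoSuperInitModel(config)

# The forward pass is a plain linear projection from attribute to attribute.
out = model(torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])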