Make sure custom configs work with Transformers (#15569)

* Make sure custom configs work with Transformers
* Apply code review suggestions

This commit is contained in:
parent 7732d0fe7a
commit 1f60bc46f3
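
In short: a custom PretrainedConfig subclass that never calls super().__init__() was previously broken, because core code paths read attributes that only PretrainedConfig.__init__ sets (config._name_or_path, config.tie_word_embeddings, and friends). The hunks below replace those direct attribute reads with getattr(...) fallbacks and add a NoSuperInitConfig / NoSuperInitModel test fixture that exercises a full save/load round trip.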
@@ -368,7 +368,7 @@ class PretrainedConfig(PushToHubMixin):
 
     @property
     def name_or_path(self) -> str:
-        return self._name_or_path
+        return getattr(self, "_name_or_path", None)
 
     @name_or_path.setter
     def name_or_path(self, value):
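
Why the fallback matters: _name_or_path is only assigned inside PretrainedConfig.__init__, so a custom config that skips super().__init__() used to raise AttributeError whenever the name_or_path property was read. A minimal sketch of the failure mode (the class name is hypothetical, mirroring the test fixture added below):

from transformers import PretrainedConfig

class SkipsSuperConfig(PretrainedConfig):  # hypothetical stand-in for NoSuperInitConfig
    model_type = "custom"

    def __init__(self, attribute=1, **kwargs):
        self.attribute = attribute  # deliberately no super().__init__(**kwargs)

config = SkipsSuperConfig()
# Old getter: AttributeError, since _name_or_path was never set.
# New getter: getattr(self, "_name_or_path", None) returns None instead.
print(config.name_or_path)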
@@ -621,10 +621,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             weights instead.
         """
         output_embeddings = self.get_output_embeddings()
-        if output_embeddings is not None and self.config.tie_word_embeddings:
+        if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
             self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
 
-        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
+        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
                 self = getattr(self, self.base_model_prefix)
             self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
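
The pattern above is worth calling out: each getattr falls back to the same default that PretrainedConfig.__init__ would have set (tie_word_embeddings defaults to True, is_encoder_decoder and tie_encoder_decoder to False), so fully initialized configs behave exactly as before, while attribute-less configs no longer crash tie_weights(). A minimal sketch of the idea:

# Stand-in object for a config missing these attributes entirely.
class BareConfig:
    pass

bare = BareConfig()
# Same defaults as PretrainedConfig.__init__, so behavior only changes
# for configs that never set the attributes in the first place.
assert getattr(bare, "tie_word_embeddings", True) is True
assert getattr(bare, "is_encoder_decoder", False) is False
assert getattr(bare, "tie_encoder_decoder", False) is False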
@@ -59,14 +59,14 @@ from transformers.testing_utils import (
 
 sys.path.append(str(Path(__file__).parent.parent / "utils"))
 
-from test_module.custom_configuration import CustomConfig  # noqa E402
+from test_module.custom_configuration import CustomConfig, NoSuperInitConfig  # noqa E402
 
 
 if is_torch_available():
     import torch
     from torch import nn
 
-    from test_module.custom_modeling import CustomModel
+    from test_module.custom_modeling import CustomModel, NoSuperInitModel
     from transformers import (
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
@@ -2091,6 +2091,15 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+    def test_no_super_init_config_and_model(self):
+        config = NoSuperInitConfig(attribute=32)
+        model = NoSuperInitModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+
+            model = NoSuperInitModel.from_pretrained(tmp_dir)
+
 
 @require_torch
 @is_staging_test
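
For context, a hedged sketch of what this new test exercises end to end, assuming the repository's utils/ directory is on sys.path so that test_module is importable (as the sys.path.append above arranges):

import tempfile

from test_module.custom_configuration import NoSuperInitConfig
from test_module.custom_modeling import NoSuperInitModel

config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    # save_pretrained serializes the config to config.json; before this
    # commit it could fail on configs that never ran super().__init__().
    model.save_pretrained(tmp_dir)
    # The round trip completing without an exception is the point of the test.
    reloaded = NoSuperInitModel.from_pretrained(tmp_dir)

print(type(reloaded).__name__)  # NoSuperInitModel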
@@ -7,3 +7,10 @@ class CustomConfig(PretrainedConfig):
     def __init__(self, attribute=1, **kwargs):
         self.attribute = attribute
         super().__init__(**kwargs)
+
+
+class NoSuperInitConfig(PretrainedConfig):
+    model_type = "custom"
+
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
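
Note the deliberate omission: unlike CustomConfig just above it, NoSuperInitConfig never calls super().__init__(**kwargs), which is exactly the user error this commit hardens the library against. A quick sketch of what that leaves unset at runtime:

config = NoSuperInitConfig(attribute=8)
print("attribute" in config.__dict__)      # True  -- set explicitly in __init__
print("_name_or_path" in config.__dict__)  # False -- super().__init__ never ran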
@@ -2,7 +2,7 @@ import torch
 
 from transformers import PreTrainedModel
 
-from .custom_configuration import CustomConfig
+from .custom_configuration import CustomConfig, NoSuperInitConfig
 
 
 class CustomModel(PreTrainedModel):
@@ -18,3 +18,18 @@ class CustomModel(PreTrainedModel):
 
     def _init_weights(self, module):
         pass
+
+
+class NoSuperInitModel(PreTrainedModel):
+    config_class = NoSuperInitConfig
+    base_model_prefix = "custom"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.linear = torch.nn.Linear(config.attribute, config.attribute)
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def _init_weights(self, module):
+        pass
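
A short usage sketch for the fixture model, with input shapes inferred from the Linear(config.attribute, config.attribute) layer above:

import torch

config = NoSuperInitConfig(attribute=4)
model = NoSuperInitModel(config)

# Batch of 2 vectors with feature dim equal to config.attribute.
out = model(torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])

Constructing the model at all is part of the exercise: PreTrainedModel.__init__ reads attributes off the config (e.g. config.name_or_path), which is where no-super-init configs used to blow up before the getattr fallbacks above.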