Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-20 21:18:21 +06:00
Make sure custom configs work with Transformers (#15569)
* Make sure custom configs work with Transformers

* Apply code review suggestions
commit 1f60bc46f3 (parent 7732d0fe7a)
@@ -368,7 +368,7 @@ class PretrainedConfig(PushToHubMixin):
     @property
     def name_or_path(self) -> str:
-        return self._name_or_path
+        return getattr(self, "_name_or_path", None)
 
     @name_or_path.setter
     def name_or_path(self, value):
@@ -621,10 +621,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         weights instead.
         """
         output_embeddings = self.get_output_embeddings()
-        if output_embeddings is not None and self.config.tie_word_embeddings:
+        if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
             self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
 
-        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
+        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
                 self = getattr(self, self.base_model_prefix)
             self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
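The two hunks above share one pattern: direct attribute reads on the config are replaced by getattr() with an explicit default, so that a user-written config whose __init__ never calls super().__init__() (and therefore never sets defaults such as tie_word_embeddings or _name_or_path) no longer raises AttributeError inside tie_weights() or name_or_path. A minimal sketch of the failure mode, using hypothetical stand-in classes rather than the real transformers objects:

# Hypothetical stand-ins: a base class that sets defaults in __init__,
# and a subclass that never calls super().__init__().
class BaseConfig:
    def __init__(self, **kwargs):
        self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)
        self._name_or_path = kwargs.pop("name_or_path", "")


class NoSuperInitLikeConfig(BaseConfig):
    def __init__(self, attribute=1, **kwargs):
        self.attribute = attribute  # no super().__init__(**kwargs)


config = NoSuperInitLikeConfig(attribute=32)

# Direct access, as in the old code, fails:
try:
    config.tie_word_embeddings
except AttributeError as err:
    print("direct access fails:", err)

# The patched code falls back to a sensible default instead:
print(getattr(config, "tie_word_embeddings", True))  # True
print(getattr(config, "_name_or_path", None))        # None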
@@ -59,14 +59,14 @@ from transformers.testing_utils import (
 
 sys.path.append(str(Path(__file__).parent.parent / "utils"))
 
-from test_module.custom_configuration import CustomConfig  # noqa E402
+from test_module.custom_configuration import CustomConfig, NoSuperInitConfig  # noqa E402
 
 
 if is_torch_available():
     import torch
     from torch import nn
 
-    from test_module.custom_modeling import CustomModel
+    from test_module.custom_modeling import CustomModel, NoSuperInitModel
     from transformers import (
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
@@ -2091,6 +2091,15 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+    def test_no_super_init_config_and_model(self):
+        config = NoSuperInitConfig(attribute=32)
+        model = NoSuperInitModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+
+            model = NoSuperInitModel.from_pretrained(tmp_dir)
+
 
 @require_torch
 @is_staging_test
@@ -7,3 +7,10 @@ class CustomConfig(PretrainedConfig):
     def __init__(self, attribute=1, **kwargs):
         self.attribute = attribute
         super().__init__(**kwargs)
+
+
+class NoSuperInitConfig(PretrainedConfig):
+    model_type = "custom"
+
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
@@ -2,7 +2,7 @@ import torch
 
 from transformers import PreTrainedModel
 
-from .custom_configuration import CustomConfig
+from .custom_configuration import CustomConfig, NoSuperInitConfig
 
 
 class CustomModel(PreTrainedModel):
@@ -18,3 +18,18 @@ class CustomModel(PreTrainedModel):
 
     def _init_weights(self, module):
         pass
+
+
+class NoSuperInitModel(PreTrainedModel):
+    config_class = NoSuperInitConfig
+    base_model_prefix = "custom"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.linear = torch.nn.Linear(config.attribute, config.attribute)
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def _init_weights(self, module):
+        pass
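Taken together, the new fixtures and the test above exercise a config whose __init__ deliberately skips super().__init__(). A self-contained sketch of the same round trip, assuming only that transformers and torch are installed (the classes mirror the fixtures but are defined inline here for illustration):

import tempfile

import torch
from transformers import PretrainedConfig, PreTrainedModel


class NoSuperInitConfig(PretrainedConfig):
    model_type = "custom"

    def __init__(self, attribute=1, **kwargs):
        # Deliberately does not call super().__init__(**kwargs).
        self.attribute = attribute


class NoSuperInitModel(PreTrainedModel):
    config_class = NoSuperInitConfig
    base_model_prefix = "custom"

    def __init__(self, config):
        super().__init__(config)
        self.linear = torch.nn.Linear(config.attribute, config.attribute)

    def forward(self, x):
        return self.linear(x)

    def _init_weights(self, module):
        pass


config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Before this change, saving/loading could fail with AttributeError
    # on config attributes that were never set.
    model.save_pretrained(tmp_dir)
    reloaded = NoSuperInitModel.from_pretrained(tmp_dir)
    print(type(reloaded.config).__name__)  # NoSuperInitConfig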