Mirror of https://github.com/huggingface/transformers.git
[MptConfig] support from pretrained args (#25116)

* support from pretrained args

* draft addition of tests

* update test

* use parent assertTrue

* Update src/transformers/models/mpt/configuration_mpt.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
commit 9cea3e7b80 (parent a1c4954d25)
src/transformers/models/mpt/configuration_mpt.py

@@ -101,6 +101,23 @@ class MptAttentionConfig(PretrainedConfig):
                 f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
             )

+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        if config_dict.get("model_type") == "mpt":
+            config_dict = config_dict["attn_config"]
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+

 class MptConfig(PretrainedConfig):
     """
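The override follows the pattern other nested sub-configs in the library use for `from_pretrained`: when pointed at a full MPT checkpoint, it first pulls the nested `attn_config` entry out of the top-level config dict. A minimal usage sketch; the checkpoint name is an assumption, not part of this commit:

from transformers.models.mpt.configuration_mpt import MptAttentionConfig

# Assumed checkpoint: any Hub repo whose config.json has model_type "mpt" and a
# nested "attn_config" entry (e.g. "mosaicml/mpt-7b") should hit the new branch.
attn_config = MptAttentionConfig.from_pretrained("mosaicml/mpt-7b")
print(type(attn_config).__name__)  # MptAttentionConfig, built from the nested entry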
@@ -180,6 +197,7 @@ class MptConfig(PretrainedConfig):
         "hidden_size": "d_model",
         "num_hidden_layers": "n_layers",
     }
+    is_composition = True

     def __init__(
         self,
@@ -204,6 +222,7 @@ class MptConfig(PretrainedConfig):
         initializer_range=0.02,
         **kwargs,
     ):
+        self.attn_config = attn_config
         self.d_model = d_model
         self.n_heads = n_heads
         self.n_layers = n_layers
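With `is_composition = True` and the assignment above routed through the `attn_config` property introduced in the next hunk, the constructor accepts `None`, a plain dict, or an `MptAttentionConfig` instance. A minimal sketch of the resulting behavior (the dict values mirror the new test further down):

from transformers import MptConfig
from transformers.models.mpt.configuration_mpt import MptAttentionConfig

config_default = MptConfig()  # None -> a default MptAttentionConfig
config_from_dict = MptConfig(attn_config={"attn_impl": "flash", "softmax_scale": None})
config_from_instance = MptConfig(attn_config=MptAttentionConfig(attn_impl="flash"))

# All three end up holding an MptAttentionConfig behind the property; an
# unsupported type raises the ValueError from the setter.
assert isinstance(config_default.attn_config, MptAttentionConfig)
assert isinstance(config_from_dict.attn_config, MptAttentionConfig)
assert isinstance(config_from_instance.attn_config, MptAttentionConfig)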
@@ -222,20 +241,25 @@ class MptConfig(PretrainedConfig):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_cache = use_cache
         self.initializer_range = initializer_range
+        super().__init__(**kwargs)
+
+    @property
+    def attn_config(self):
+        return self._attn_config
+
+    @attn_config.setter
+    def attn_config(self, attn_config):
         if attn_config is None:
-            self.attn_config = MptAttentionConfig()
+            self._attn_config = MptAttentionConfig()
         elif isinstance(attn_config, dict):
-            self.attn_config = MptAttentionConfig(**attn_config)
+            self._attn_config = MptAttentionConfig(**attn_config)
         elif isinstance(attn_config, MptAttentionConfig):
-            self.attn_config = attn_config
+            self._attn_config = attn_config
         else:
             raise ValueError(
                 f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
             )

-        super().__init__(**kwargs)
-
     def to_dict(self):
         """
         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
@@ -245,7 +269,8 @@ class MptConfig(PretrainedConfig):
         """
         output = copy.deepcopy(self.__dict__)
         output["attn_config"] = (
-            self.attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
+            self._attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
         )
+        del output["_attn_config"]
         output["model_type"] = self.__class__.model_type
         return output
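Because the backing attribute is now `_attn_config`, `to_dict` re-serializes it under the public `attn_config` key and drops the private one, so the saved JSON keeps the same shape as before. A minimal sketch of what that implies:

from transformers import MptConfig

config = MptConfig(attn_config={"attn_impl": "flash", "softmax_scale": None})
serialized = config.to_dict()

assert isinstance(serialized["attn_config"], dict)  # nested sub-config becomes a plain dict
assert "_attn_config" not in serialized             # private backing attribute is removed
assert serialized["model_type"] == "mpt"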
tests/models/mpt/test_modeling_mpt.py

@@ -327,6 +327,20 @@ class MptModelTester:
         return config, inputs_dict


+class MptConfigTester(ConfigTester):
+    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
+        super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)
+
+    def test_attn_config_as_dict(self):
+        config = self.config_class(**self.inputs_dict, attn_config={"attn_impl": "flash", "softmax_scale": None})
+        self.parent.assertTrue(config.attn_config.attn_impl == "flash")
+        self.parent.assertTrue(config.attn_config.softmax_scale is None)
+
+    def run_common_tests(self):
+        self.test_attn_config_as_dict()
+        return super().run_common_tests()
+
+
 @require_torch
 class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
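`MptConfigTester` keeps `ConfigTester`'s interface, so the only change needed in `MptModelTest` is the swap in `setUp` shown in the next hunk; `run_common_tests` then runs the MPT-specific dict check before the shared config checks. A standalone sketch of what `test_attn_config_as_dict` verifies, with plain asserts in place of `self.parent.assertTrue`:

from transformers import MptConfig

config = MptConfig(attn_config={"attn_impl": "flash", "softmax_scale": None})

# The dict is promoted to an MptAttentionConfig and the overridden values survive.
assert config.attn_config.attn_impl == "flash"
assert config.attn_config.softmax_scale is None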
@@ -353,7 +367,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):

     def setUp(self):
         self.model_tester = MptModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
+        self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)

     def test_config(self):
         self.config_tester.run_common_tests()
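With the tester swapped in, the existing `test_config` entry point now exercises the new dict check as part of `run_common_tests`. A sketch of running just that test from a source checkout of transformers with the test dependencies installed; the dotted test path is an assumption about the repository layout:

import unittest

# Equivalent to: python -m pytest tests/models/mpt/test_modeling_mpt.py -k test_config
suite = unittest.TestLoader().loadTestsFromName(
    "tests.models.mpt.test_modeling_mpt.MptModelTest.test_config"
)
unittest.TextTestRunner(verbosity=2).run(suite)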