[MptConfig] support from pretrained args (#25116)

* support from pretrained args

* draft addition of tests

* update test

* use parent assertTrue

* Update src/transformers/models/mpt/configuration_mpt.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Author: Arthur, 2023-07-27 16:24:52 +02:00 (committed by GitHub)
Commit: 9cea3e7b80 (parent: a1c4954d25)
2 changed files with 46 additions and 7 deletions

src/transformers/models/mpt/configuration_mpt.py

@@ -101,6 +101,23 @@ class MptAttentionConfig(PretrainedConfig):
                 f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
             )
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        if config_dict.get("model_type") == "mpt":
+            config_dict = config_dict["attn_config"]
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+
 
 class MptConfig(PretrainedConfig):
     """
@@ -180,6 +197,7 @@ class MptConfig(PretrainedConfig):
         "hidden_size": "d_model",
         "num_hidden_layers": "n_layers",
     }
+    is_composition = True
 
     def __init__(
         self,
@@ -204,6 +222,7 @@ class MptConfig(PretrainedConfig):
         initializer_range=0.02,
         **kwargs,
     ):
+        self.attn_config = attn_config
         self.d_model = d_model
         self.n_heads = n_heads
         self.n_layers = n_layers
@@ -222,20 +241,25 @@ class MptConfig(PretrainedConfig):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_cache = use_cache
         self.initializer_range = initializer_range
+        super().__init__(**kwargs)
 
+    @property
+    def attn_config(self):
+        return self._attn_config
+
+    @attn_config.setter
+    def attn_config(self, attn_config):
         if attn_config is None:
-            self.attn_config = MptAttentionConfig()
+            self._attn_config = MptAttentionConfig()
         elif isinstance(attn_config, dict):
-            self.attn_config = MptAttentionConfig(**attn_config)
+            self._attn_config = MptAttentionConfig(**attn_config)
         elif isinstance(attn_config, MptAttentionConfig):
-            self.attn_config = attn_config
+            self._attn_config = attn_config
         else:
             raise ValueError(
                 f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
             )
-        super().__init__(**kwargs)
 
     def to_dict(self):
         """
         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
@@ -245,7 +269,8 @@ class MptConfig(PretrainedConfig):
         """
         output = copy.deepcopy(self.__dict__)
         output["attn_config"] = (
-            self.attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
+            self._attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
         )
+        del output["_attn_config"]
         output["model_type"] = self.__class__.model_type
         return output
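
Note: storing the value privately as `_attn_config` behind an `attn_config` property keeps the public attribute normalized (None and plain dicts are converted to `MptAttentionConfig`), `is_composition = True` marks the config as composed of a sub-config, and the `to_dict` override drops the private storage key so the serialized form keeps its original shape. A minimal sketch of the resulting behavior:

    from transformers import MptConfig
    from transformers.models.mpt.configuration_mpt import MptAttentionConfig

    config = MptConfig()  # attn_config defaults to a fresh MptAttentionConfig
    assert isinstance(config.attn_config, MptAttentionConfig)

    d = config.to_dict()
    assert "_attn_config" not in d             # private storage key is removed
    assert isinstance(d["attn_config"], dict)  # sub-config serialized as a plain dict
    assert d["model_type"] == "mpt"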

tests/models/mpt/test_modeling_mpt.py

@@ -327,6 +327,20 @@ class MptModelTester:
         return config, inputs_dict
 
 
+class MptConfigTester(ConfigTester):
+    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
+        super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)
+
+    def test_attn_config_as_dict(self):
+        config = self.config_class(**self.inputs_dict, attn_config={"attn_impl": "flash", "softmax_scale": None})
+        self.parent.assertTrue(config.attn_config.attn_impl == "flash")
+        self.parent.assertTrue(config.attn_config.softmax_scale is None)
+
+    def run_common_tests(self):
+        self.test_attn_config_as_dict()
+        return super().run_common_tests()
+
+
 @require_torch
 class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
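
Design note: overriding `run_common_tests` (rather than adding a standalone test) means any caller of the common suite, including `MptModelTest.test_config` below, picks up the MPT-specific attn_config check automatically.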
@@ -353,7 +367,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     def setUp(self):
         self.model_tester = MptModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
+        self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)
 
     def test_config(self):
         self.config_tester.run_common_tests()
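
As a standalone snippet, the behavior the new test pins down is equivalent to the following sketch:

    from transformers import MptConfig

    # A plain dict passed as attn_config is converted to an MptAttentionConfig,
    # so its entries are reachable as attributes.
    config = MptConfig(attn_config={"attn_impl": "flash", "softmax_scale": None})
    assert config.attn_config.attn_impl == "flash"
    assert config.attn_config.softmax_scale is None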