diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 8923b9d5c9a..bd09c1ae57d 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -5577,8 +5577,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     def get_parameter_or_buffer(self, target: str):
         """
         Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
-        `get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a
-        leaf of the model.
+        `get_parameter()` and `get_buffer()` in a single handy function. If the target is an `_extra_state` attribute,
+        it will return the extra state provided by the module. Note that it only works if `target` is a leaf of the
+        model.
+
+        Args:
+            target (`str`):
+                Fully-qualified name of the parameter, buffer or `_extra_state` entry to look up.
+
+        Returns:
+            The parameter or buffer named `target`, or the value returned by the owning module's
+            `get_extra_state()` when `target` names an `_extra_state` entry.
+
+        Raises:
+            AttributeError: If `target` is neither a parameter, a buffer, nor an extra state.
         """
         try:
             return self.get_parameter(target)
@@ -5588,7 +5600,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             return self.get_buffer(target)
         except AttributeError:
             pass
-        raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
+        # `_extra_state` is neither a parameter nor a buffer, so it needs its own lookup. The identity check below
+        # skips modules that merely inherit the default `torch.nn.Module.get_extra_state` implementation.
+        module, param_name = get_module_from_name(self, target)
+        if (
+            param_name == "_extra_state"
+            and getattr(module.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+            is not torch.nn.Module.get_extra_state
+        ):
+            return module.get_extra_state()
+
+        raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")
 
 
 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index ca4e1cc3d42..cd0edd94571 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -2815,3 +2815,96 @@ class TestTensorSharing(TestCasePlus):
         shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
         self.assertEqual(shared_names, [{"a", "b"}])
         self.assertEqual(identical_names, [])
+
+
+@require_torch
+class TestSaveAndLoadModelWithExtraState(TestCasePlus):
+    """
+    This test checks that a model that uses the torch extra state API can be saved and loaded.
+    https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state.
+
+    Currently, only tensor-valued extra_states are supported.
+    """
+
+    def test_save_and_load_model_with_tensor_extra_state(self):
+        """A tensor-valued extra state must survive a `save_pretrained`/`from_pretrained` round trip."""
+
+        class MyConfig(PretrainedConfig):
+            def __init__(self, **kwargs):
+                super().__init__(**kwargs)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.some_counter = 0
+                self.linear = torch.nn.Linear(320, 320)
+
+            def get_extra_state(self):
+                return torch.tensor(self.some_counter)
+
+            def set_extra_state(self, state):
+                self.some_counter = state.item()
+
+            def forward(self, hidden_states, attention_mask=None):
+                return self.linear(hidden_states)
+
+        class MyModel(PreTrainedModel):
+            config_class = MyConfig
+
+            def __init__(self, config: MyConfig):
+                super().__init__(config)
+                self.my_layer = MyModule()
+
+            def forward(self, hidden_states, attention_mask):
+                return self.my_layer(hidden_states, attention_mask)
+
+        config = MyConfig()
+        model = MyModel(config)
+        model.my_layer.some_counter = 42
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            model = MyModel.from_pretrained(tmpdirname)
+            self.assertEqual(model.my_layer.some_counter, 42)
+
+    @mark.xfail(reason="save and from_pretrained currently only supports tensor extra_state")
+    def test_save_and_load_model_with_dict_extra_state(self):
+        """A dict-valued extra state is not supported yet, so this round trip is currently expected to fail."""
+
+        class MyConfig(PretrainedConfig):
+            def __init__(self, **kwargs):
+                super().__init__(**kwargs)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.some_counter = 0
+                self.linear = torch.nn.Linear(320, 320)
+
+            def get_extra_state(self):
+                return {"some_counter": self.some_counter}
+
+            def set_extra_state(self, state):
+                self.some_counter = state["some_counter"]
+
+            def forward(self, hidden_states, attention_mask=None):
+                return self.linear(hidden_states)
+
+        class MyModel(PreTrainedModel):
+            config_class = MyConfig
+
+            def __init__(self, config: MyConfig):
+                super().__init__(config)
+                self.my_layer = MyModule()
+
+            def forward(self, hidden_states, attention_mask):
+                return self.my_layer(hidden_states, attention_mask)
+
+        config = MyConfig()
+        model = MyModel(config)
+        model.my_layer.some_counter = 42
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            model = MyModel.from_pretrained(tmpdirname)
+            self.assertEqual(model.my_layer.some_counter, 42)