From bab40c6838c97f56022c0f3340b27aff89692b4d Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Wed, 28 May 2025 07:38:42 -0600 Subject: [PATCH] [core] support tensor-valued _extra_state values in `from_pretrained` (#38155) Support tensor-valued _extra_state values. TransformerEngine uses the PyTorch get/set_extra_state API to store FP8 layer config information as a bytes tensor in the _extra_state entry of the state dict. With recent changes to from_pretrained, this functionality has broken, and loading a model that uses this API doesn't appear to work. This PR fixes the save/load pretrained functions for extra state entries that use a PyTorch tensor, and adds a (currently x-failing) test for a dictionary extra state. Signed-off-by: Peter St. John --- src/transformers/modeling_utils.py | 14 +++-- tests/utils/test_modeling_utils.py | 83 ++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8923b9d5c9a..bd09c1ae57d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5577,8 +5577,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi def get_parameter_or_buffer(self, target: str): """ Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines - `get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a - leaf of the model. + `get_parameter()` and `get_buffer()` in a single handy function. If the target is an `_extra_state` attribute, + it will return the extra state provided by the module. Note that it only work if `target` is a leaf of the model. 
""" try: return self.get_parameter(target) @@ -5588,7 +5588,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi return self.get_buffer(target) except AttributeError: pass - raise AttributeError(f"`{target}` is neither a parameter nor a buffer.") + module, param_name = get_module_from_name(self, target) + if ( + param_name == "_extra_state" + and getattr(module.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + return module.get_extra_state() + + raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.") PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index ca4e1cc3d42..cd0edd94571 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -2815,3 +2815,86 @@ class TestTensorSharing(TestCasePlus): shared_names, identical_names = _find_identical([{"a", "b"}], state_dict) self.assertEqual(shared_names, [{"a", "b"}]) self.assertEqual(identical_names, []) + + +@require_torch +class TestSaveAndLoadModelWithExtraState(TestCasePlus): + """ + This test checks that a model can be saved and loaded that uses the torch extra state API. + https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state. + + Currently, only tensor-valued extra_states are supported. 
+ """ + + def test_save_and_load_model_with_tensor_extra_state(self): + class MyConfig(PretrainedConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.some_counter = 0 + self.linear = torch.nn.Linear(320, 320) + + def get_extra_state(self): + return torch.tensor(self.some_counter) + + def set_extra_state(self, state): + self.some_counter = state.item() + + class MyModel(PreTrainedModel): + config_class = MyConfig + + def __init__(self, config: MyConfig): + super().__init__(config) + self.my_layer = MyModule() + + def forward(self, hidden_states, attention_mask): + return self.my_layer(hidden_states, attention_mask) + + config = MyConfig() + model = MyModel(config) + model.my_layer.some_counter = 42 + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = MyModel.from_pretrained(tmpdirname) + self.assertEqual(model.my_layer.some_counter, 42) + + @mark.xfail(reason="save and from_pretrained currently only supports tensor extra_state") + def test_save_and_load_model_with_dict_extra_state(self): + class MyConfig(PretrainedConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.some_counter = 0 + self.linear = torch.nn.Linear(320, 320) + + def get_extra_state(self): + return {"some_counter": self.some_counter} + + def set_extra_state(self, state): + self.some_counter = state["some_counter"] + + class MyModel(PreTrainedModel): + config_class = MyConfig + + def __init__(self, config: MyConfig): + super().__init__(config) + self.my_layer = MyModule() + + def forward(self, hidden_states, attention_mask): + return self.my_layer(hidden_states, attention_mask) + + config = MyConfig() + model = MyModel(config) + model.my_layer.some_counter = 42 + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + 
model = MyModel.from_pretrained(tmpdirname) + self.assertEqual(model.my_layer.some_counter, 42)