[core] support tensor-valued _extra_state values in from_pretrained (#38155)
Support tensor-valued _extra_state values.

TransformerEngine uses the pytorch get/set_extra_state API to store FP8 layer config information as a bytes tensor in the _extra_state entry of the state dict. Recent changes to from_pretrained broke this functionality, so loading a model that relies on this API no longer works. This PR fixes the save/load pretrained functions for extra state entries that use a pytorch tensor, and adds a (currently x-failing) test for a dictionary-valued extra state.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
parent badc71b9f6
commit bab40c6838
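For background, torch's extra-state hooks let a module stash serializable data in its state dict under a `<prefix>._extra_state` key. Below is a minimal sketch of how FP8 metadata can surface as a bytes tensor, mirroring the approach the commit message describes; the `FP8Layer` class and its payload are illustrative, not TransformerEngine's actual code:

import torch

class FP8Layer(torch.nn.Module):
    """Toy stand-in for a TransformerEngine FP8 layer (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.fp8_meta = b"serialized fp8 config"  # placeholder payload

    def get_extra_state(self):
        # Serialize the config bytes into a uint8 tensor.
        return torch.frombuffer(bytearray(self.fp8_meta), dtype=torch.uint8)

    def set_extra_state(self, state):
        self.fp8_meta = bytes(state.tolist())

layer = FP8Layer()
# The hook adds an `_extra_state` entry alongside parameters and buffers:
print(sorted(layer.state_dict().keys()))
# -> ['_extra_state', 'linear.bias', 'linear.weight']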
@@ -5577,8 +5577,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     def get_parameter_or_buffer(self, target: str):
         """
         Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
-        `get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a
-        leaf of the model.
+        `get_parameter()` and `get_buffer()` in a single handy function. If the target is an `_extra_state` attribute,
+        it will return the extra state provided by the module. Note that it only works if `target` is a leaf of the model.
         """
         try:
             return self.get_parameter(target)
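To see the new docstring contract in action, here is a hedged, self-contained usage sketch; `TinyConfig`, `TinyModule`, and `TinyModel` are illustrative names, not part of the patch:

import torch
from transformers import PretrainedConfig, PreTrainedModel

class TinyConfig(PretrainedConfig):
    pass

class TinyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.some_counter = 0

    def get_extra_state(self):
        return torch.tensor(self.some_counter)

    def set_extra_state(self, state):
        self.some_counter = state.item()

class TinyModel(PreTrainedModel):
    config_class = TinyConfig

    def __init__(self, config):
        super().__init__(config)
        self.my_layer = TinyModule()

model = TinyModel(TinyConfig())
model.my_layer.some_counter = 3

# Parameters and buffers resolve exactly as before...
weight = model.get_parameter_or_buffer("my_layer.linear.weight")
# ...and `_extra_state` targets now resolve through the module's hook.
assert model.get_parameter_or_buffer("my_layer._extra_state").item() == 3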
@@ -5588,7 +5588,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             return self.get_buffer(target)
         except AttributeError:
             pass
-        raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
+        module, param_name = get_module_from_name(self, target)
+        if (
+            param_name == "_extra_state"
+            and getattr(module.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+            is not torch.nn.Module.get_extra_state
+        ):
+            return module.get_extra_state()
+
+        raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")


 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
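The class-level `getattr` comparison is how the patch detects an overridden hook: the base `torch.nn.Module.get_extra_state` raises if called, so it must only be invoked when a subclass has actually replaced it. A standalone sketch of that check (helper name is illustrative):

import torch

class WithExtraState(torch.nn.Module):
    def get_extra_state(self):
        return torch.tensor(7)

    def set_extra_state(self, state):
        pass

def overrides_extra_state(module: torch.nn.Module) -> bool:
    # True only when the module's class replaces the base implementation.
    return (
        getattr(module.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
        is not torch.nn.Module.get_extra_state
    )

assert overrides_extra_state(WithExtraState())
assert not overrides_extra_state(torch.nn.Linear(2, 2))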
@@ -2815,3 +2815,86 @@ class TestTensorSharing(TestCasePlus):
         shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
         self.assertEqual(shared_names, [{"a", "b"}])
         self.assertEqual(identical_names, [])
+
+
+@require_torch
+class TestSaveAndLoadModelWithExtraState(TestCasePlus):
+    """
+    This test checks that a model that uses the torch extra state API can be saved and loaded.
+    https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state.
+
+    Currently, only tensor-valued extra_states are supported.
+    """
+
+    def test_save_and_load_model_with_tensor_extra_state(self):
+        class MyConfig(PretrainedConfig):
+            def __init__(self, **kwargs):
+                super().__init__(**kwargs)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.some_counter = 0
+                self.linear = torch.nn.Linear(320, 320)
+
+            def get_extra_state(self):
+                return torch.tensor(self.some_counter)
+
+            def set_extra_state(self, state):
+                self.some_counter = state.item()
+
+        class MyModel(PreTrainedModel):
+            config_class = MyConfig
+
+            def __init__(self, config: MyConfig):
+                super().__init__(config)
+                self.my_layer = MyModule()
+
+            def forward(self, hidden_states, attention_mask):
+                return self.my_layer(hidden_states, attention_mask)
+
+        config = MyConfig()
+        model = MyModel(config)
+        model.my_layer.some_counter = 42
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            model = MyModel.from_pretrained(tmpdirname)
+            self.assertEqual(model.my_layer.some_counter, 42)
+
+    @mark.xfail(reason="save_pretrained and from_pretrained currently only support tensor extra_state")
+    def test_save_and_load_model_with_dict_extra_state(self):
+        class MyConfig(PretrainedConfig):
+            def __init__(self, **kwargs):
+                super().__init__(**kwargs)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.some_counter = 0
+                self.linear = torch.nn.Linear(320, 320)
+
+            def get_extra_state(self):
+                return {"some_counter": self.some_counter}
+
+            def set_extra_state(self, state):
+                self.some_counter = state["some_counter"]
+
+        class MyModel(PreTrainedModel):
+            config_class = MyConfig
+
+            def __init__(self, config: MyConfig):
+                super().__init__(config)
+                self.my_layer = MyModule()
+
+            def forward(self, hidden_states, attention_mask):
+                return self.my_layer(hidden_states, attention_mask)
+
+        config = MyConfig()
+        model = MyModel(config)
+        model.my_layer.some_counter = 42
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            model = MyModel.from_pretrained(tmpdirname)
+            self.assertEqual(model.my_layer.some_counter, 42)
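One plausible reason the dictionary case has to x-fail for now: the default safetensors serialization path only accepts tensor values, so a dict-valued extra state has no direct on-disk representation. A small sketch of that constraint (file names are throwaway):

import os
import tempfile

import torch
from safetensors.torch import save_file

with tempfile.TemporaryDirectory() as tmp:
    # Tensor-valued entries serialize fine...
    save_file({"extra": torch.tensor([42])}, os.path.join(tmp, "ok.safetensors"))

    # ...but a dict value is rejected, since safetensors stores tensors only.
    try:
        save_file({"extra": {"some_counter": 42}}, os.path.join(tmp, "bad.safetensors"))
    except Exception as err:
        print(f"{type(err).__name__}: {err}")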