Fix module initialization for root module under Zero3 (#33632)

* Use all state dict keys when checking if root module is initialized.

* Apply style corrections

* Add comment explaining change.

* Change comment phrasing.
This commit is contained in:
Ben Schneider 2024-10-03 08:41:50 -04:00 committed by GitHub
parent 4df3ccddb7
commit 95a2f5f6c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -617,6 +617,9 @@ def set_initialized_submodules(model, state_dict_keys):
not_initialized_submodules = {}
for module_name, module in model.named_modules():
loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
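# The root module's name is "", so the prefix filter above only keeps keys starting with ".".
# When checking whether the root module is loaded, all state_dict_keys must be used instead.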
if module_name == "":
loaded_keys = set(state_dict_keys)
if loaded_keys.issuperset(module.state_dict()):
module._is_hf_initialized = True
else: