mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00
fix(modeling_utils): Correctly call _init_weights in smart_apply
This commit is contained in:
parent
8e87adc45f
commit
f3d88d70f3
@ -2763,10 +2763,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
def smart_apply(self, fn):
    """Recursively apply an initialization function, giving nested models priority.

    Unlike a plain ``nn.Module.apply``, this walks the children and, whenever a
    child is itself a ``PreTrainedModel`` (a sub-model), dispatches that
    sub-model's own ``_initialize_weights`` recursively instead of the
    top-level ``fn``.  The logic is:

    1. Recursively apply the top-level initialization function (``fn``) to all
       children that are not themselves sub-models.
    2. For sub-models, let their own custom initializer take priority by
       recursing with ``module._initialize_weights``.

    Args:
        fn: Callable invoked on each module (typically a weight initializer).

    Returns:
        ``self``, to allow chaining.
    """
    for module in self.children():
        # We found a sub-model: recursively dispatch its own init function now!
        if isinstance(module, PreTrainedModel):
            # If a module has its own _init_weights function, that one takes priority.
            module.smart_apply(module._initialize_weights)
        else:
            module.smart_apply(fn)
    # Finally apply fn to this module itself (post-order, mirroring nn.Module.apply).
    fn(self)
    return self
Loading…
Reference in New Issue
Block a user