NOTE(review): the original patch was flattened onto one line, so leading whitespace is
reconstructed here (assumes `smart_apply` is defined at 12 spaces of indentation) -- confirm
against modeling_utils.py, or apply with `git apply --ignore-whitespace`. The post-image blob
id on the `index` line predates this revision of the hunk.

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index e6b7031ab37..ead4ec80757 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2763,9 +2763,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             def smart_apply(self, fn):
                 for module in self.children():
                     # We found a sub-model: recursively dispatch its own init function now!
+                    # A sub-model's own initializer takes priority over the `fn` inherited from its
+                    # parent: from that child downwards we apply `module._initialize_weights` instead.
+                    # Always recurse with `smart_apply`, never plain `apply`: `apply` would push one
+                    # fixed function through the whole subtree and skip this per-sub-model switch.
                     if isinstance(module, PreTrainedModel):
                         module.smart_apply(module._initialize_weights)
                     else:
                         module.smart_apply(fn)
                 fn(self)
                 return self