add comment on recursive weights loading

This commit is contained in:
Rémi Louf 2019-10-10 10:02:03 +02:00
parent 770b15b58c
commit 851ef592c5

View File

@ -383,6 +383,8 @@ class PreTrainedModel(nn.Module):
if metadata is not None:
state_dict._metadata = metadata
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(