only load state dict when the checkpoint is not None (#16673)

This commit is contained in:
Laura Hanu 2022-04-08 18:42:04 +01:00 committed by GitHub
parent d57da99237
commit f4d4f0a1ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1792,7 +1792,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# load pt weights early so that we know which dtype to init the model under
if from_pt:
-if not is_sharded:
+if not is_sharded and state_dict is None:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
# set dtype to instantiate the model under: