Fix dtype in randomly initialized head (#19690)

This commit is contained in:
Sylvain Gugger 2022-10-17 15:54:23 -04:00 committed by GitHub
parent 07f6690206
commit 344e2664d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -2446,9 +2446,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
param = model_state_dict[key]
if param.device == torch.device("meta"):
if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init: