Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Fix dtype in randomly initialized head (#19690)
parent 07f6690206
commit 344e2664d4
@@ -2446,9 +2446,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 param = model_state_dict[key]
                 if param.device == torch.device("meta"):
                     if not load_in_8bit:
-                        set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
+                        set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
                     else:
-                        set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size()))
+                        set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
 
         # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
         if _fast_init:
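For context, a minimal sketch of what the added dtype= argument fixes: when a checkpoint is loaded with low_cpu_mem_usage, missing head weights exist only on the meta device and must be materialized with torch.empty(), which falls back to the default dtype (normally torch.float32) when none is given, so a randomly initialized head would not match a model loaded in, say, float16. The sketch is plain PyTorch, outside the transformers loading path; param below is a hypothetical meta-device head weight, and param.dtype stands in for the dtype variable used in the patched code.

import torch

# Hypothetical head weight before materialization: the shape and dtype
# are known, but the tensor lives on the meta device with no storage.
param = torch.empty(10, 10, dtype=torch.float16, device="meta")

old = torch.empty(*param.size())                     # before this commit
new = torch.empty(*param.size(), dtype=param.dtype)  # after this commit

print(old.dtype)  # torch.float32 -- mismatches the fp16 model
print(new.dtype)  # torch.float16 -- matches the model dtype

Passing the target dtype keeps randomly initialized heads consistent with the rest of a half-precision model.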