Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Parent: 30580f035b
Commit: 14b597f518
@@ -732,6 +732,8 @@ def _infer_parameter_dtype(
    if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
        casting_dtype = torch.float32
    # Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
    elif hf_quantizer is not None:
        casting_dtype = model.config._pre_quantization_dtype
    else:
        casting_dtype = old_param.dtype
    return old_param is not None and old_param.is_contiguous(), casting_dtype
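For context, the branch order in this hunk resolves the casting dtype by priority: a `keep_in_fp32_modules` regex match forces float32, an attached quantizer falls back to the model's pre-quantization dtype, and otherwise the dtype already materialized on the (meta) model is kept. Below is a minimal standalone sketch of that priority order, not the transformers implementation; `pick_casting_dtype` and its parameters are hypothetical stand-ins for the arguments visible in the hunk.

# Illustrative sketch (not the transformers implementation) of the same priority
# order for picking a casting dtype, built only from the pieces visible in the hunk.
import re
from typing import Optional, Tuple

import torch


def pick_casting_dtype(
    param_name: str,
    old_param: Optional[torch.Tensor],
    keep_in_fp32_regex: Optional[re.Pattern] = None,       # hypothetical: compiled regex of fp32-only modules
    pre_quantization_dtype: Optional[torch.dtype] = None,  # hypothetical stand-in for model.config._pre_quantization_dtype
) -> Tuple[bool, torch.dtype]:
    """Return (old_param_is_contiguous, casting_dtype), mirroring the branch order above."""
    if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name):
        # Modules forced to stay in fp32 take precedence over everything else.
        casting_dtype = torch.float32
    elif pre_quantization_dtype is not None:
        # With a quantizer attached, cast to the dtype the model had before quantization.
        casting_dtype = pre_quantization_dtype
    else:
        # Otherwise keep the dtype already instantiated on the (meta) model;
        # this branch assumes old_param exists, as the original code does.
        casting_dtype = old_param.dtype
    return old_param is not None and old_param.is_contiguous(), casting_dtype


# Example: a layer-norm weight matched by the fp32 regex is always cast to float32.
contiguous, dtype = pick_casting_dtype(
    "model.layers.0.input_layernorm.weight",
    old_param=torch.empty(16, dtype=torch.bfloat16),
    keep_in_fp32_regex=re.compile(r"layernorm"),
)
assert contiguous and dtype == torch.float32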
@@ -754,7 +756,6 @@ def _load_state_dict_into_meta_model(
    keep_in_fp32_modules: Optional[List[str]] = None,
    unexpected_keys: Optional[List[str]] = None,  # passing `unexpected` for cleanup from quantization items
    device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    weights_only=True,
) -> Tuple[Optional[Dict], Optional[Dict]]:
    """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
    device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
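The docstring refers to a `meta_state_dict` kept on the meta device. As a hedged illustration of that idea (plain PyTorch, not the transformers helper), tensors created under `torch.device("meta")` carry shapes and dtypes but no storage, which is what makes shape and dtype inference cheap:

# Minimal sketch of the meta-device idea the docstring refers to: meta tensors
# carry shapes and dtypes but no data, so a model skeleton can be inspected
# without allocating any memory.
import torch
from torch import nn

with torch.device("meta"):
    skeleton = nn.Linear(4096, 4096, dtype=torch.bfloat16)

meta_state_dict = skeleton.state_dict()
for name, tensor in meta_state_dict.items():
    # Shapes and dtypes are available even though nothing was materialized.
    print(name, tuple(tensor.shape), tensor.dtype, tensor.device)
# -> weight (4096, 4096) torch.bfloat16 meta
# -> bias (4096,) torch.bfloat16 meta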
@@ -4883,7 +4884,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                keep_in_fp32_modules=keep_in_fp32_modules,
                unexpected_keys=unexpected_keys,
                device_mesh=device_mesh,
                weights_only=weights_only,
            )
        else:
            assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
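In the else branch, `check_support_param_buffer_assignment` decides whether parameters and buffers can be assigned rather than copied during loading. The sketch below only illustrates the underlying PyTorch mechanism (`load_state_dict(..., assign=True)`, available in recent PyTorch releases), not the transformers helper itself; the module and state dict are made up for the example.

# Hedged sketch of the "param/buffer assignment" idea: load_state_dict can
# either copy into existing parameters or, with assign=True, replace them with
# the loaded tensors -- the latter is what a meta-initialized module needs.
import torch
from torch import nn

with torch.device("meta"):
    model_to_load = nn.Linear(8, 8)          # skeleton with meta parameters

real_weights = nn.Linear(8, 8).state_dict()  # made-up real weights for the example

# Copying into storage-less meta tensors is not possible, so assign instead.
model_to_load.load_state_dict(real_weights, assign=True)
assert model_to_load.weight.device.type == "cpu"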