Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
parent 994cad2790
commit 7652804d23
@@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model(
     bin_state_dict = None
     if shard_file.endswith(".safetensors"):
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
-    else:
+    elif shard_file.endswith(".bin"):
         map_location = "cpu"
         if (
             device_map is not None
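The old bare `else:` treated every non-safetensors shard as a pickled .bin file; the fix matches the `.bin` extension explicitly. A minimal sketch of the corrected dispatch in isolation, using the public safetensors and PyTorch APIs (`open_shard` is a hypothetical helper name, not the surrounding transformers code):

import torch
from safetensors import safe_open

def open_shard(shard_file: str, device: str = "cpu"):
    """Dispatch on the shard extension, mirroring the fixed branch above."""
    if shard_file.endswith(".safetensors"):
        # Lazy handle: tensors are materialized per key on access.
        return safe_open(shard_file, framework="pt", device=device)
    elif shard_file.endswith(".bin"):
        # Eager load of a pickled PyTorch state dict.
        return torch.load(shard_file, map_location=device, weights_only=True)
    raise ValueError(f"unrecognized shard format: {shard_file}")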
@@ -848,6 +848,13 @@ def _load_state_dict_into_meta_model(
 
     is_quantized = hf_quantizer is not None
 
+    # get full state dict
+    if is_quantized:
+        if shard_file.endswith(".safetensors"):
+            full_state_dict = load_state_dict(shard_file, map_location="cpu")
+        elif shard_file.endswith(".bin"):
+            full_state_dict = bin_state_dict
+
     for serialized_param_name, empty_param in state_dict.items():
         # serialized_param_name is the raw, serialized name
         # fixed_param_name is the model's equivalent
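When a quantizer is active, this new block materializes the shard's tensors once on CPU instead of handing the quantizer the meta-device `state_dict`. A hedged sketch of the same idea (`materialize_shard` is a hypothetical name, and `safetensors.torch.load_file` stands in for transformers' own `load_state_dict` helper):

from safetensors.torch import load_file

def materialize_shard(shard_file, bin_state_dict=None):
    # safetensors shards are re-read eagerly onto CPU; .bin shards were
    # already deserialized earlier in the function, so the cached dict
    # is reused instead of re-reading the file.
    if shard_file.endswith(".safetensors"):
        return load_file(shard_file, device="cpu")
    elif shard_file.endswith(".bin"):
        return bin_state_dict
    return None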
@@ -912,7 +919,7 @@ def _load_state_dict_into_meta_model(
                 model,
                 param,
                 fixed_param_name,
-                state_dict,
+                full_state_dict,
                 param_device=param_device,
                 device_map=device_map,
             )
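This call site (its callee's name falls outside the hunk's context lines) gets the same swap as the `create_quantized_param` call in the next hunk: the positional `state_dict` argument becomes the materialized `full_state_dict`.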
@@ -928,7 +935,7 @@ def _load_state_dict_into_meta_model(
                 )
             else:
                 hf_quantizer.create_quantized_param(
-                    model, param, fixed_param_name, param_device, state_dict, unexpected_keys
+                    model, param, fixed_param_name, param_device, full_state_dict, unexpected_keys
                 )
                 # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
                 # and then cast it to CPU to avoid excessive memory usage on each GPU
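Why the swap matters: during meta-device loading, `state_dict` maps names to placeholder tensors (note the `empty_param` loop variable above), which carry shape and dtype but no storage, so a quantizer has no values to read from them. A small self-contained illustration of that distinction (an inference from the `empty_param` naming, not code from this diff):

import torch

meta_w = torch.empty(4, 4, device="meta")   # placeholder: shape/dtype only, no storage
real_w = torch.randn(4, 4)                  # materialized CPU tensor

print(meta_w.is_meta)   # True  -> no values for a quantizer to read
print(real_w.is_meta)   # False -> usable by create_quantized_param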