Fix bnb regression due to empty state dict (#36663)

fix
This commit is contained in:
Marc Sun 2025-03-12 11:40:46 +01:00 committed by GitHub
parent 994cad2790
commit 7652804d23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -833,7 +833,7 @@ def _load_state_dict_into_meta_model(
bin_state_dict = None
if shard_file.endswith(".safetensors"):
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
else:
elif shard_file.endswith(".bin"):
map_location = "cpu"
if (
device_map is not None
@ -848,6 +848,13 @@ def _load_state_dict_into_meta_model(
is_quantized = hf_quantizer is not None
# get full state dict
if is_quantized:
if shard_file.endswith(".safetensors"):
full_state_dict = load_state_dict(shard_file, map_location="cpu")
elif shard_file.endswith(".bin"):
full_state_dict = bin_state_dict
for serialized_param_name, empty_param in state_dict.items():
# serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent
@ -912,7 +919,7 @@ def _load_state_dict_into_meta_model(
model,
param,
fixed_param_name,
state_dict,
full_state_dict,
param_device=param_device,
device_map=device_map,
)
@ -928,7 +935,7 @@ def _load_state_dict_into_meta_model(
)
else:
hf_quantizer.create_quantized_param(
model, param, fixed_param_name, param_device, state_dict, unexpected_keys
model, param, fixed_param_name, param_device, full_state_dict, unexpected_keys
)
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU