Improve performance of load_state_dict (#37902)

Improve performance of load_state_dict
This commit is contained in:
woctordho 2025-05-01 22:35:17 +08:00 committed by GitHub
parent 410aa01901
commit ee25d57ed1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -507,13 +507,14 @@ def load_state_dict(
     )
     state_dict = {}
     for k in f.keys():
-        k_dtype = f.get_slice(k).get_dtype()
-        if k_dtype in str_to_torch_dtype:
-            dtype = str_to_torch_dtype[k_dtype]
-        else:
-            raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
         if map_location == "meta":
-            state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
+            _slice = f.get_slice(k)
+            k_dtype = _slice.get_dtype()
+            if k_dtype in str_to_torch_dtype:
+                dtype = str_to_torch_dtype[k_dtype]
+            else:
+                raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
+            state_dict[k] = torch.empty(size=_slice.get_shape(), dtype=dtype, device="meta")
         else:
             state_dict[k] = f.get_tensor(k)
     return state_dict