mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Improve performance of load_state_dict
(#37902)
Improve performance of load_state_dict
This commit is contained in:
parent
410aa01901
commit
ee25d57ed1
@ -507,13 +507,14 @@ def load_state_dict(
|
||||
)
|
||||
state_dict = {}
|
||||
for k in f.keys():
|
||||
k_dtype = f.get_slice(k).get_dtype()
|
||||
if k_dtype in str_to_torch_dtype:
|
||||
dtype = str_to_torch_dtype[k_dtype]
|
||||
else:
|
||||
raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
|
||||
if map_location == "meta":
|
||||
state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta")
|
||||
_slice = f.get_slice(k)
|
||||
k_dtype = _slice.get_dtype()
|
||||
if k_dtype in str_to_torch_dtype:
|
||||
dtype = str_to_torch_dtype[k_dtype]
|
||||
else:
|
||||
raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
|
||||
state_dict[k] = torch.empty(size=_slice.get_shape(), dtype=dtype, device="meta")
|
||||
else:
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
return state_dict
|
||||
|
Loading…
Reference in New Issue
Block a user