diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 199adf825b7..fe257b56944 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -507,13 +507,14 @@ def load_state_dict( ) state_dict = {} for k in f.keys(): - k_dtype = f.get_slice(k).get_dtype() - if k_dtype in str_to_torch_dtype: - dtype = str_to_torch_dtype[k_dtype] - else: - raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}") if map_location == "meta": - state_dict[k] = torch.empty(size=f.get_slice(k).get_shape(), dtype=dtype, device="meta") + _slice = f.get_slice(k) + k_dtype = _slice.get_dtype() + if k_dtype in str_to_torch_dtype: + dtype = str_to_torch_dtype[k_dtype] + else: + raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}") + state_dict[k] = torch.empty(size=_slice.get_shape(), dtype=dtype, device="meta") else: state_dict[k] = f.get_tensor(k) return state_dict