mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix disk offload for full safetensors checkpoints (#20497)
This commit is contained in:
parent
4aa630eeab
commit
ab9fe45236
@ -597,6 +597,9 @@ def _load_state_dict_into_meta_model(
|
||||
# in int/uint/bool and not cast them.
|
||||
if dtype is not None and torch.is_floating_point(param):
|
||||
param = param.to(dtype)
|
||||
# For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
|
||||
if is_safetensors and dtype is None and torch.is_floating_point(param):
|
||||
param = param.to(torch.float32)
|
||||
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
@ -2452,6 +2455,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if offload_state_dict is None:
|
||||
offload_state_dict = True
|
||||
|
||||
is_sharded_safetensors = is_safetensors and sharded_metadata is not None
|
||||
# Retrieve missing & unexpected_keys
|
||||
model_state_dict = model.state_dict()
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
@ -2567,12 +2571,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
|
||||
if device_map is not None and is_safetensors:
|
||||
param_device_map = expand_device_map(device_map, sharded_metadata["all_checkpoint_keys"])
|
||||
param_device_map = expand_device_map(device_map, original_loaded_keys)
|
||||
|
||||
str_dtype = str(dtype).replace("torch.", "")
|
||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||
if sharded_metadata is None:
|
||||
archive_file = (
|
||||
resolved_archive_file[0]
|
||||
if isinstance(resolved_archive_file, (list, tuple))
|
||||
else resolved_archive_file
|
||||
)
|
||||
weight_map = {p: archive_file for p in original_loaded_keys}
|
||||
else:
|
||||
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
|
||||
offload_index = {
|
||||
p: {"safetensors_file": os.path.join(folder, f), "weight_name": p, "dtype": str_dtype}
|
||||
for p, f in sharded_metadata["weight_map"].items()
|
||||
p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
|
||||
for p, f in weight_map.items()
|
||||
if param_device_map[p] == "disk"
|
||||
}
|
||||
|
||||
@ -2606,7 +2619,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
state_dict_folder = None
|
||||
state_dict_index = None
|
||||
|
||||
if is_safetensors:
|
||||
if is_sharded_safetensors:
|
||||
disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
|
||||
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user