Fix disk offload for full safetensors checkpoints (#20497)

commit ab9fe45236 (parent 4aa630eeab)
Author: Sylvain Gugger
Date:   2022-11-29 14:58:30 -05:00
Committed by: GitHub
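Disk offload failed for full (non-sharded) safetensors checkpoints: `sharded_metadata` is `None` in that case, yet the offload index was built from `sharded_metadata["all_checkpoint_keys"]` and `sharded_metadata["weight_map"]`. The fix builds the weight map from `original_loaded_keys` and the single checkpoint file when no sharding metadata is present, defaults the offload dtype to `"float32"` when `dtype` is `None` (upcasting float16/bfloat16 safetensors weights to fp32 for parity with the PyTorch loading path), and introduces an `is_sharded_safetensors` flag so shard-only logic runs only for genuinely sharded checkpoints.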


@@ -597,6 +597,9 @@ def _load_state_dict_into_meta_model(
         # in int/uint/bool and not cast them.
         if dtype is not None and torch.is_floating_point(param):
             param = param.to(dtype)
+        # For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
+        if is_safetensors and dtype is None and torch.is_floating_point(param):
+            param = param.to(torch.float32)
 
         if device_map is None:
             param_device = "cpu"
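As a minimal sketch of the casting rule above (the helper name and the standalone setting are illustrative, not part of the commit):

import torch

def cast_loaded_param(param, dtype=None, is_safetensors=True):
    # An explicit dtype wins: cast floating-point weights, leave int/uint/bool alone.
    if dtype is not None and torch.is_floating_point(param):
        return param.to(dtype)
    # No dtype requested: upcast fp16/bf16 safetensors weights to fp32, mirroring
    # the behaviour the comment above attributes to the PyTorch loading path.
    if is_safetensors and dtype is None and torch.is_floating_point(param):
        return param.to(torch.float32)
    return param

half = torch.ones(2, dtype=torch.float16)
assert cast_loaded_param(half).dtype == torch.float32                   # upcast by default
assert cast_loaded_param(half, dtype=torch.float16).dtype == torch.float16
assert cast_loaded_param(torch.ones(2, dtype=torch.int8)).dtype == torch.int8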
@@ -2452,6 +2455,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if offload_state_dict is None:
             offload_state_dict = True
 
+        is_sharded_safetensors = is_safetensors and sharded_metadata is not None
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
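`is_sharded_safetensors` is true only when the safetensors checkpoint ships with an index, i.e. when `sharded_metadata` is available; the last hunk below switches on it so that shard-specific code no longer runs for single-file checkpoints.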
@@ -2567,12 +2571,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
         if device_map is not None and is_safetensors:
-            param_device_map = expand_device_map(device_map, sharded_metadata["all_checkpoint_keys"])
+            param_device_map = expand_device_map(device_map, original_loaded_keys)
 
-            str_dtype = str(dtype).replace("torch.", "")
+            str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
+            if sharded_metadata is None:
+                archive_file = (
+                    resolved_archive_file[0]
+                    if isinstance(resolved_archive_file, (list, tuple))
+                    else resolved_archive_file
+                )
+                weight_map = {p: archive_file for p in original_loaded_keys}
+            else:
+                weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
             offload_index = {
-                p: {"safetensors_file": os.path.join(folder, f), "weight_name": p, "dtype": str_dtype}
-                for p, f in sharded_metadata["weight_map"].items()
+                p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
+                for p, f in weight_map.items()
                 if param_device_map[p] == "disk"
             }
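To trace the new single-file branch, here is a toy evaluation with hypothetical file and parameter names (none of these values appear in the commit):

# Hypothetical inputs: a full (non-sharded) checkpoint and a device map that
# offloads one parameter to disk.
resolved_archive_file = "model.safetensors"
original_loaded_keys = ["embed.weight", "lm_head.weight"]
param_device_map = {"embed.weight": "disk", "lm_head.weight": 0}

str_dtype = "float32"  # dtype is None, so the new default applies
archive_file = (
    resolved_archive_file[0]
    if isinstance(resolved_archive_file, (list, tuple))
    else resolved_archive_file
)
weight_map = {p: archive_file for p in original_loaded_keys}
offload_index = {
    p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
    for p, f in weight_map.items()
    if param_device_map[p] == "disk"
}
# offload_index == {"embed.weight": {"safetensors_file": "model.safetensors",
#                                    "weight_name": "embed.weight", "dtype": "float32"}}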
@@ -2606,7 +2619,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             state_dict_folder = None
             state_dict_index = None
 
-        if is_safetensors:
+        if is_sharded_safetensors:
             disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
             disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
         else:
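For a single-file checkpoint there are no shards that live entirely on disk, and `get_disk_only_shard_files` is called with `sharded_metadata=sharded_metadata`; gating on `is_sharded_safetensors` instead of `is_safetensors` therefore keeps this branch from running with `sharded_metadata=None`.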