diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 552d182190b..588f7f6134b 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -41,6 +41,7 @@ from .pytorch_utils import ( # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
+    id_tensor_storage,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -304,26 +305,31 @@ def shard_checkpoint(
     """
     max_shard_size = convert_file_size_to_int(max_shard_size)
 
-    sharded_state_dicts = []
-    current_block = {}
-    current_block_size = 0
+    sharded_state_dicts = [{}]
+    last_block_size = 0
     total_size = 0
+    storage_id_to_block = {}
 
     for key, weight in state_dict.items():
+        storage_id = id_tensor_storage(weight)
+
+        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
+        if storage_id in storage_id_to_block:
+            block_id = storage_id_to_block[storage_id]
+            sharded_state_dicts[block_id][key] = weight
+            continue
+
         weight_size = weight.numel() * dtype_byte_size(weight.dtype)
 
         # If this weight is going to tip up over the maximal size, we split.
-        if current_block_size + weight_size > max_shard_size:
-            sharded_state_dicts.append(current_block)
-            current_block = {}
-            current_block_size = 0
+        if last_block_size + weight_size > max_shard_size:
+            sharded_state_dicts.append({})
+            last_block_size = 0
 
-        current_block[key] = weight
-        current_block_size += weight_size
+        sharded_state_dicts[-1][key] = weight
+        last_block_size += weight_size
         total_size += weight_size
-
-    # Add the last block
-    sharded_state_dicts.append(current_block)
+        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1
 
     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index 87c3e986d99..6c5fc5f4a41 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -16,6 +16,7 @@ from typing import Callable, List, Optional, Set, Tuple, Union
 
 import torch
 from packaging import version
+from safetensors.torch import storage_ptr, storage_size
 from torch import nn
 
 from .utils import logging
@@ -277,3 +278,13 @@ def meshgrid(
     if indexing != "ij":
         raise ValueError('torch.meshgrid only supports `indexing="ij"` for torch<1.10.')
     return torch.meshgrid(*tensors)
+
+
+def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
+    """
+    Unique identifier for a tensor storage. Multiple different tensors can share the same underlying storage. For
+    example, "meta" tensors all share the same storage, and thus their identifiers will all be equal. This identifier is
+    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+    non-overlapping lifetimes may have the same id.
+    """
+    return tensor.device, storage_ptr(tensor), storage_size(tensor)
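
For context, a minimal sketch of the behavior this patch targets. It assumes the diff above has been applied to a local transformers checkout (so the patched `shard_checkpoint` and the new `id_tensor_storage` are importable) and that `safetensors` is installed; the tensor names, sizes, and the 200-byte shard limit are illustrative only. Two keys backed by the same storage (the tied-weights case) map to the same storage identifier and therefore land in the same shard instead of being duplicated or split across files.

import torch

from transformers.modeling_utils import shard_checkpoint
from transformers.pytorch_utils import id_tensor_storage

# Illustrative state dict: "lm_head.weight" reuses the embedding tensor, so the
# two keys point at one underlying storage, exactly as tied weights do.
embeddings = torch.randn(4, 8)  # 4 * 8 * 4 bytes = 128 bytes in float32
state_dict = {
    "embed.weight": embeddings,
    "lm_head.weight": embeddings,       # same storage, different key
    "dense.weight": torch.randn(4, 8),  # independent storage
}

# Tensors that share a storage share the (device, pointer, size) identifier ...
assert id_tensor_storage(state_dict["embed.weight"]) == id_tensor_storage(
    state_dict["lm_head.weight"]
)

# ... so even with a shard limit (200 bytes here) that forces a split, the tied
# pair stays together in one shard instead of being written twice.
shards, index = shard_checkpoint(state_dict, max_shard_size=200)
for shard in shards.values():
    assert ("embed.weight" in shard) == ("lm_head.weight" in shard)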