Support shared tensors (#23871)

* Support shared storage

* Really be sure we have the same storage

* Make style

* - Refactor storage identifier mechanism
 - Group everything into a single for loop

* Make style

* PR

* make style

* Update src/transformers/pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Thomas Wang 2023-05-31 15:42:30 +02:00 committed by GitHub
parent 68d53bc717
commit d68d6665f9
2 changed files with 29 additions and 12 deletions
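
For context, "shared tensors" here means two state_dict entries backed by the same underlying storage, as happens with weight tying. A minimal sketch of the situation (the module and key names below are hypothetical, not taken from the diff):

import torch
from torch import nn

# Tie a linear head to an embedding: two state_dict keys, one storage.
embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight

state_dict = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight}
print(state_dict["embed.weight"].data_ptr() == state_dict["lm_head.weight"].data_ptr())  # True

Before this change, `shard_checkpoint` sized and placed each key independently, so two keys pointing at the same storage could land in different shards; the diff below keeps them in the same block.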

src/transformers/modeling_utils.py

@@ -41,6 +41,7 @@ from .pytorch_utils import ( # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
+    id_tensor_storage,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -304,26 +305,31 @@ def shard_checkpoint(
     """
     max_shard_size = convert_file_size_to_int(max_shard_size)
 
-    sharded_state_dicts = []
-    current_block = {}
-    current_block_size = 0
+    sharded_state_dicts = [{}]
+    last_block_size = 0
     total_size = 0
+    storage_id_to_block = {}
 
     for key, weight in state_dict.items():
+        storage_id = id_tensor_storage(weight)
+
+        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
+        if storage_id in storage_id_to_block:
+            block_id = storage_id_to_block[storage_id]
+            sharded_state_dicts[block_id][key] = weight
+            continue
+
         weight_size = weight.numel() * dtype_byte_size(weight.dtype)
 
         # If this weight is going to tip up over the maximal size, we split.
-        if current_block_size + weight_size > max_shard_size:
-            sharded_state_dicts.append(current_block)
-            current_block = {}
-            current_block_size = 0
+        if last_block_size + weight_size > max_shard_size:
+            sharded_state_dicts.append({})
+            last_block_size = 0
 
-        current_block[key] = weight
-        current_block_size += weight_size
+        sharded_state_dicts[-1][key] = weight
+        last_block_size += weight_size
         total_size += weight_size
-
-    # Add the last block
-    sharded_state_dicts.append(current_block)
+        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1
 
     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
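
With the loop above, any key whose storage has already been seen is routed to the block recorded in `storage_id_to_block` instead of being placed by size. A short usage sketch, reusing the tied `state_dict` from the earlier example and assuming the (shards, index) return pair of this version of `shard_checkpoint`:

from transformers.modeling_utils import shard_checkpoint

shards, index = shard_checkpoint(state_dict, max_shard_size="100MB")
# "embed.weight" and "lm_head.weight" now always share a shard, because both
# keys resolve to the same storage identifier.
print(index is None, list(shards.values())[0].keys())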

src/transformers/pytorch_utils.py

@@ -16,6 +16,7 @@ from typing import Callable, List, Optional, Set, Tuple, Union
 
 import torch
 from packaging import version
+from safetensors.torch import storage_ptr, storage_size
 from torch import nn
 
 from .utils import logging
@@ -277,3 +278,13 @@ def meshgrid(
     if indexing != "ij":
         raise ValueError('torch.meshgrid only supports `indexing="ij"` for torch<1.10.')
     return torch.meshgrid(*tensors)
+
+
+def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
+    """
+    Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
+    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+    non-overlapping lifetimes may have the same id.
+    """
+    return tensor.device, storage_ptr(tensor), storage_size(tensor)
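
A quick illustration of the new helper (the tensors below are made up for the example; `id_tensor_storage` is importable from `transformers.pytorch_utils` once this commit is in):

import torch
from transformers.pytorch_utils import id_tensor_storage

base = torch.zeros(4, 4)
view = base[1:]       # a view: new tensor object, same underlying storage
copy = base.clone()   # a copy: its own storage

print(id_tensor_storage(base) == id_tensor_storage(view))  # True  - same (device, ptr, size)
print(id_tensor_storage(base) == id_tensor_storage(copy))  # False - different storage pointer

Because the identifier is the (device, storage pointer, storage size) triple, it treats views and tied parameters as one storage while distinguishing genuinely separate allocations.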