Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Support shared tensors (#23871)
* Support shared storage
* Really be sure we have the same storage
* Make style
* Refactor the storage identifier mechanism; group everything into a single for loop
* Make style
* PR
* make style
* Update src/transformers/pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent 68d53bc717
commit d68d6665f9
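For context, the "shared tensors" in the title are state-dict entries whose data lives in the same underlying storage, most commonly tied input/output embeddings or views of another tensor. A minimal illustrative sketch (not part of the commit) of the situation the new sharding logic has to handle:

```python
import torch
import torch.nn as nn

# Weight tying: the LM head reuses the embedding matrix, so two state-dict keys
# end up pointing at the very same storage.
embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight

state_dict = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight}

# Both entries report the same data pointer. Counting their sizes independently
# would double-count the storage, and a naive sharder could split the tied pair
# across two files.
print(embed.weight.data_ptr() == lm_head.weight.data_ptr())  # True
```

The diff below keys every weight by its storage so that such aliased entries are kept together and counted only once.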
src/transformers/modeling_utils.py

@@ -41,6 +41,7 @@ from .pytorch_utils import ( # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
+    id_tensor_storage,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -304,26 +305,31 @@ def shard_checkpoint(
     """
     max_shard_size = convert_file_size_to_int(max_shard_size)

-    sharded_state_dicts = []
-    current_block = {}
-    current_block_size = 0
+    sharded_state_dicts = [{}]
+    last_block_size = 0
     total_size = 0
+    storage_id_to_block = {}

     for key, weight in state_dict.items():
+        storage_id = id_tensor_storage(weight)
+
+        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
+        if storage_id in storage_id_to_block:
+            block_id = storage_id_to_block[storage_id]
+            sharded_state_dicts[block_id][key] = weight
+            continue
+
         weight_size = weight.numel() * dtype_byte_size(weight.dtype)

         # If this weight is going to tip up over the maximal size, we split.
-        if current_block_size + weight_size > max_shard_size:
-            sharded_state_dicts.append(current_block)
-            current_block = {}
-            current_block_size = 0
+        if last_block_size + weight_size > max_shard_size:
+            sharded_state_dicts.append({})
+            last_block_size = 0

-        current_block[key] = weight
-        current_block_size += weight_size
+        sharded_state_dicts[-1][key] = weight
+        last_block_size += weight_size
         total_size += weight_size
+        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

-    # Add the last block
-    sharded_state_dicts.append(current_block)
-
     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
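Net effect of the hunk above: each weight is keyed by `id_tensor_storage`, so every state-dict entry that aliases an already-seen storage is routed into the same block via `storage_id_to_block` and skips the size accounting (`continue`), meaning shared storage is counted only once against `max_shard_size`. A hedged usage sketch, assuming a transformers version that includes this commit and the `shard_checkpoint` return shape of that era (a dict of shard-file names to state dicts, plus an index or `None`):

```python
import torch
from transformers.modeling_utils import shard_checkpoint

tied = torch.randn(1024, 1024)                    # ~4 MB in float32
state_dict = {
    "model.embed_tokens.weight": tied,
    "lm_head.weight": tied,                       # aliases the embedding storage
    "model.layer.weight": torch.randn(1024, 1024),
}

shards, index = shard_checkpoint(state_dict, max_shard_size="5MB")

# The two tied keys land in the same shard file and their storage is counted once,
# so reloading the shards can restore the tie instead of materializing two copies.
for filename, shard in shards.items():
    print(filename, sorted(shard.keys()))
```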
src/transformers/pytorch_utils.py

@@ -16,6 +16,7 @@ from typing import Callable, List, Optional, Set, Tuple, Union

 import torch
 from packaging import version
+from safetensors.torch import storage_ptr, storage_size
 from torch import nn

 from .utils import logging

@@ -277,3 +278,13 @@ def meshgrid(
     if indexing != "ij":
         raise ValueError('torch.meshgrid only supports `indexing="ij"` for torch<1.10.')
     return torch.meshgrid(*tensors)
+
+
+def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
+    """
+    Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
+    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+    non-overlapping lifetimes may have the same id.
+    """
+    return tensor.device, storage_ptr(tensor), storage_size(tensor)
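The identifier is just `(device, storage pointer, storage size)`, so any two tensors viewing the same memory compare equal while a fresh copy does not. A small illustrative sketch (not part of the commit):

```python
import torch
from transformers.pytorch_utils import id_tensor_storage

base = torch.arange(12.0)
view = base.view(3, 4)    # different shape, same underlying storage
copy = base.clone()       # same values, freshly allocated storage

print(id_tensor_storage(base) == id_tensor_storage(view))  # True: shared storage
print(id_tensor_storage(base) == id_tensor_storage(copy))  # False: distinct storage
```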