No more dtype_byte_size() (#37144)

* No more dtype_byte_size()

* Remove function once again

* Fix rebase cruft

* Trigger tests
Matt 2025-04-02 14:58:38 +01:00 committed by GitHub
parent 7613cf1a45
commit cbfa14823b
6 changed files with 6 additions and 92 deletions
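The deleted helper parsed the bit width out of a dtype's name with a regex and divided by eight. Each framework already exposes that number directly, which is what the one-line replacements in this diff use. A minimal sketch of the equivalents (not part of the PR, and assuming NumPy, TensorFlow and PyTorch are installed):

```py
# Sketch of the framework-native attributes that replace dtype_byte_size().
import numpy as np
import tensorflow as tf
import torch

# NumPy (and Flax params, whose dtypes are NumPy dtypes): bytes per element via `itemsize`.
assert np.dtype(np.float32).itemsize == 4

# TensorFlow: a tf.DType reports its byte width as `size`.
assert tf.float32.size == 4

# PyTorch: tensors and parameters report bytes per element via `element_size()`.
assert torch.zeros(1, dtype=torch.float16).element_size() == 2
```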

View File

@@ -17,7 +17,6 @@
 import gc
 import json
 import os
-import re
 import warnings
 from functools import partial
 from pickle import UnpicklingError
@@ -83,23 +82,6 @@ ACT2FN = {
 }
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:
-    ```py
-    >>> dtype_byte_size(np.float32)
-    4
-    ```
-    """
-    if dtype is bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
 def flax_shard_checkpoint(params, max_shard_size="10GB"):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -131,7 +113,7 @@ def flax_shard_checkpoint(params, max_shard_size="10GB"):
     # flatten the weights to chunk
     weights = flatten_dict(params, sep="/")
     for item in weights:
-        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
+        weight_size = weights[item].size * weights[item].dtype.itemsize
         # If this weight is going to tip up over the maximal size, we split.
         if current_block_size + weight_size > max_shard_size:
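In the Flax utils, each flattened weight is an array whose `.dtype` is a NumPy dtype, so `.dtype.itemsize` yields the same bytes-per-parameter figure the old helper computed. A small sketch, assuming `flax` is installed and using dummy NumPy arrays in place of real params:

```py
import numpy as np
from flax.traverse_util import flatten_dict

# Dummy stand-in for a Flax param tree.
params = {"dense": {"kernel": np.zeros((4, 8), dtype=np.float32), "bias": np.zeros(8, dtype=np.float16)}}

weights = flatten_dict(params, sep="/")
for item in weights:
    # element count * bytes per element = bytes occupied by this weight
    weight_size = weights[item].size * weights[item].dtype.itemsize
    print(item, weight_size)  # dense/kernel -> 128, dense/bias -> 16
```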

View File

@@ -617,26 +617,6 @@ def input_processing(func, config, **kwargs):
     return output
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`.
-    Example:
-    ```py
-    >>> dtype_byte_size(tf.float32)
-    4
-    ```
-    """
-    if dtype == tf.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
 def strip_model_name_and_prefix(name, _prefix=None):
     if _prefix is not None and name.startswith(_prefix):
         name = name[len(_prefix) :]
@@ -678,7 +658,7 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_
     total_size = 0
     for item in weights:
-        weight_size = item.numpy().size * dtype_byte_size(item.dtype)
+        weight_size = item.numpy().size * item.dtype.size
         # If this weight is going to tip up over the maximal size, we split.
         if current_block_size + weight_size > max_shard_size:
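On the TensorFlow side, each `item` is a variable whose `dtype` is a `tf.DType`, and `tf.DType.size` is the byte width per element that `dtype_byte_size` used to derive from the dtype name. A short sketch, assuming TensorFlow is available:

```py
import tensorflow as tf

weights = [
    tf.Variable(tf.zeros((4, 8), dtype=tf.float32), name="kernel"),
    tf.Variable(tf.zeros((8,), dtype=tf.float16), name="bias"),
]

for item in weights:
    # numpy().size = element count, dtype.size = bytes per element
    weight_size = item.numpy().size * item.dtype.size
    print(item.name, weight_size)  # kernel:0 -> 128, bias:0 -> 16
```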

View File

@@ -385,26 +385,6 @@ def get_state_dict_dtype(state_dict):
         return next(state_dict.values()).dtype
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`.
-    Example:
-    ```py
-    >>> dtype_byte_size(torch.float32)
-    4
-    ```
-    """
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
 def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     """
     This is the same as
@@ -5820,7 +5800,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
     for param_name, device in accelerator_device_map.items():
         param = model.get_parameter_or_buffer(param_name)
         # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
+        param_byte_count = math.prod(param.shape) * param.element_size()
         if tp_plan_regex is not None:
             generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
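For PyTorch parameters, `Tensor.element_size()` returns bytes per element, so the warmup's byte count is simply element count times element size. A sketch of the same arithmetic on a toy module (the module here is illustrative, not from the PR):

```py
import math

import torch
from torch import nn

model = nn.Linear(8, 4, dtype=torch.bfloat16)  # toy stand-in for a PreTrainedModel

for param_name, param in model.named_parameters():
    # math.prod(shape) = element count, element_size() = bytes per element
    param_byte_count = math.prod(param.shape) * param.element_size()
    print(param_name, param_byte_count)  # weight -> 64, bias -> 8
```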

View File

@@ -19,7 +19,6 @@ import torch
 from torch import nn
 from transformers import NllbMoeConfig, NllbMoeModel
-from transformers.modeling_utils import dtype_byte_size
 from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
@@ -86,8 +85,8 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weig
             )
             torch.save(expert_state, save_path)
             sharded_state_dicts.append(expert_state.keys())
-            total_size += sum([value.numel() for key, value in expert_state.items()]) * dtype_byte_size(
-                expert_state[list(expert_state)[0]].dtype
+            total_size += sum([value.numel() for key, value in expert_state.items()]) * (
+                expert_state[list(expert_state)[0]].element_size()
             )
     # Add the last block
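This script sizes a whole expert state dict by multiplying the total element count by the `element_size()` of its first tensor, which assumes every tensor in `expert_state` shares one dtype, exactly as the old `dtype_byte_size` version did. A per-tensor sum, sketched below with dummy tensors, drops that assumption and gives the same total when the dtypes do match:

```py
import torch

# Dummy stand-in for one expert's state dict (all float32 here).
expert_state = {"wi": torch.zeros(16, 32), "wo": torch.zeros(32, 16)}

# Per-tensor sum: no single-dtype assumption needed.
total_size = sum(value.numel() * value.element_size() for value in expert_state.values())

# Matches the script's form (total element count * first tensor's element size) for uniform dtypes.
assert total_size == sum(v.numel() for v in expert_state.values()) * expert_state["wi"].element_size()
print(total_size)  # 4096 bytes = 1024 elements * 4 bytes
```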

View File

@@ -8,7 +8,6 @@ from flax import serialization
 from flax.traverse_util import flatten_dict, unflatten_dict
 from tensorflow.io import gfile
-from transformers.modeling_utils import dtype_byte_size
 from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import (
     rename_keys,
 )
@@ -94,7 +93,7 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, w
         # open tensorstore file
         raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result()
         raw_weights = torch.tensor(raw_weights)
-        weight_size = raw_weights.numel() * dtype_byte_size(raw_weights.dtype)
+        weight_size = raw_weights.numel() * raw_weights.element_size()
         # use the renaming pattern from the small conversion scripts
         key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights)
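Here `raw_weights` is a dense CPU tensor, so `numel() * element_size()` is its payload size in bytes; PyTorch also exposes the same figure as `Tensor.nbytes`, which makes a handy cross-check (a sketch, not part of the script):

```py
import torch

raw_weights = torch.zeros(128, 64, dtype=torch.bfloat16)

weight_size = raw_weights.numel() * raw_weights.element_size()
assert weight_size == raw_weights.nbytes == 128 * 64 * 2  # 16384 bytes
```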

View File

@@ -116,7 +116,6 @@ if is_torch_available():
     from transformers.modeling_utils import (
         _find_disjoint,
         _find_identical,
-        dtype_byte_size,
     )
     from transformers.pytorch_utils import isin_mps_friendly
@@ -704,31 +703,6 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
         self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
-    def test_torch_dtype_byte_sizes(self):
-        torch_dtypes_and_bytes = [
-            (torch.double, 8),
-            (torch.float64, 8),
-            (torch.float, 4),
-            (torch.float32, 4),
-            (torch.half, 2),
-            (torch.float16, 2),
-            (torch.bfloat16, 2),
-            (torch.long, 8),
-            (torch.int64, 8),
-            (torch.int, 4),
-            (torch.int32, 4),
-            (torch.short, 2),
-            (torch.int16, 2),
-            (torch.uint8, 1),
-            (torch.int8, 1),
-            (torch.float8_e4m3fn, 1),
-            (torch.float8_e5m2, 1),
-            (torch.bool, 0.125),
-        ]
-        for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
-            self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
     def test_no_super_init_config_and_model(self):
         config = NoSuperInitConfig(attribute=32)
         model = NoSuperInitModel(config)
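With the helper gone, its unit test is removed as well. If an equivalent check were ever wanted, it could lean on PyTorch's own bookkeeping instead; the sketch below is hypothetical (not part of the PR). Note that `element_size()` counts whole bytes, so `torch.bool` reports 1 rather than the old helper's 0.125:

```py
import torch

# Hypothetical replacement check built on element_size().
dtype_to_bytes = {
    torch.float64: 8,
    torch.float32: 4,
    torch.float16: 2,
    torch.bfloat16: 2,
    torch.int64: 8,
    torch.int32: 4,
    torch.int16: 2,
    torch.uint8: 1,
    torch.int8: 1,
    torch.bool: 1,  # whole bytes; dtype_byte_size() returned 1 / 8 here
}

for dtype, expected in dtype_to_bytes.items():
    assert torch.empty(0, dtype=dtype).element_size() == expected
```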