No more dtype_byte_size() (#37144)

* No more dtype_byte_size()

* Remove function once again

* Fix rebase cruft

* Trigger tests
Matt, 2025-04-02 14:58:38 +01:00 (committed by GitHub)
parent 7613cf1a45
commit cbfa14823b
6 changed files with 6 additions and 92 deletions


@@ -17,7 +17,6 @@
 import gc
 import json
 import os
-import re
 import warnings
 from functools import partial
 from pickle import UnpicklingError
@@ -83,23 +82,6 @@ ACT2FN = {
 }
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:
-    ```py
-    >>> dtype_byte_size(np.float32)
-    4
-    ```
-    """
-    if dtype is bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
 def flax_shard_checkpoint(params, max_shard_size="10GB"):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -131,7 +113,7 @@ def flax_shard_checkpoint(params, max_shard_size="10GB"):
     # flatten the weights to chunk
     weights = flatten_dict(params, sep="/")
     for item in weights:
-        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
+        weight_size = weights[item].size * weights[item].dtype.itemsize
         # If this weight is going to tip up over the maximal size, we split.
         if current_block_size + weight_size > max_shard_size:
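For context on the Flax change: `np.dtype.itemsize` is the per-element byte count, and JAX/Flax weights carry NumPy dtypes, so `weights[item].dtype.itemsize` is a drop-in replacement for the removed helper. A quick sanity sketch (not part of this diff, assuming only NumPy is installed):

```py
import numpy as np

# NumPy dtypes already expose the per-element byte count that the removed
# helper used to parse out of the dtype name.
assert np.dtype(np.float32).itemsize == 4
assert np.dtype(np.float16).itemsize == 2
assert np.dtype(np.int64).itemsize == 8

# One behavioural note: the old helper returned 1/8 for bool, while itemsize
# reports a full byte per element.
assert np.dtype(bool).itemsize == 1
```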


@@ -617,26 +617,6 @@ def input_processing(func, config, **kwargs):
     return output
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`.
-    Example:
-    ```py
-    >>> dtype_byte_size(tf.float32)
-    4
-    ```
-    """
-    if dtype == tf.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
 def strip_model_name_and_prefix(name, _prefix=None):
     if _prefix is not None and name.startswith(_prefix):
         name = name[len(_prefix) :]
@@ -678,7 +658,7 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_
     total_size = 0
     for item in weights:
-        weight_size = item.numpy().size * dtype_byte_size(item.dtype)
+        weight_size = item.numpy().size * item.dtype.size
         # If this weight is going to tip up over the maximal size, we split.
         if current_block_size + weight_size > max_shard_size:
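On the TF side, `tf.DType.size` is (as far as I can tell) the per-element byte count, which is what the removed helper computed from the dtype name. A minimal sketch, assuming a standard TensorFlow install; the bool value is worth double-checking since the old helper returned 1/8 there:

```py
import tensorflow as tf

# tf.DType.size reports bytes per element, so item.dtype.size can stand in
# for the removed regex-based helper when sizing TF weights.
assert tf.float32.size == 4
assert tf.bfloat16.size == 2
assert tf.int64.size == 8

# As with the other backends, bool now counts as one byte per element.
assert tf.bool.size == 1
```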


@@ -385,26 +385,6 @@ def get_state_dict_dtype(state_dict):
     return next(state_dict.values()).dtype
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`.
-    Example:
-    ```py
-    >>> dtype_byte_size(torch.float32)
-    4
-    ```
-    """
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
 def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     """
     This is the same as
@@ -5820,7 +5800,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
     for param_name, device in accelerator_device_map.items():
         param = model.get_parameter_or_buffer(param_name)
         # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
+        param_byte_count = math.prod(param.shape) * param.element_size()
         if tp_plan_regex is not None:
             generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
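For the PyTorch path, `torch.Tensor.element_size()` already gives bytes per element, so the warmup arithmetic no longer needs the helper. A small sketch of the same computation; `param_byte_footprint` is a hypothetical name for illustration, not part of transformers:

```py
import math

import torch
from torch import nn


def param_byte_footprint(model: nn.Module) -> int:
    # Same arithmetic as the updated warmup line: element count times
    # torch.Tensor.element_size().
    return sum(math.prod(p.shape) * p.element_size() for p in model.parameters())


# 16 * 32 weights + 32 biases, 2 bytes each in float16.
model = nn.Linear(16, 32, dtype=torch.float16)
assert param_byte_footprint(model) == (16 * 32 + 32) * 2
```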


@@ -19,7 +19,6 @@ import torch
 from torch import nn
 from transformers import NllbMoeConfig, NllbMoeModel
-from transformers.modeling_utils import dtype_byte_size
 from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
@@ -86,8 +85,8 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weig
             )
             torch.save(expert_state, save_path)
             sharded_state_dicts.append(expert_state.keys())
-            total_size += sum([value.numel() for key, value in expert_state.items()]) * dtype_byte_size(
-                expert_state[list(expert_state)[0]].dtype
+            total_size += sum([value.numel() for key, value in expert_state.items()]) * (
+                expert_state[list(expert_state)[0]].element_size()
             )
     # Add the last block
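The conversion script keeps the same shape: total element count times the byte width of the first tensor in the expert dict, now taken from `element_size()`. A small sketch with a made-up state dict (names are illustrative only):

```py
import torch

# Hypothetical stand-in for one expert's state dict.
expert_state = {
    "wi.weight": torch.zeros(8, 4, dtype=torch.float32),
    "wo.weight": torch.zeros(4, 8, dtype=torch.float32),
}

# Same arithmetic as the updated script: element count times the byte width
# of the first tensor's dtype.
first_tensor = expert_state[list(expert_state)[0]]
total_size = sum(value.numel() for value in expert_state.values()) * first_tensor.element_size()
assert total_size == (8 * 4 + 4 * 8) * 4
```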


@@ -8,7 +8,6 @@ from flax import serialization
 from flax.traverse_util import flatten_dict, unflatten_dict
 from tensorflow.io import gfile
-from transformers.modeling_utils import dtype_byte_size
 from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import (
     rename_keys,
 )
@@ -94,7 +93,7 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, w
         # open tensorstore file
         raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result()
         raw_weights = torch.tensor(raw_weights)
-        weight_size = raw_weights.numel() * dtype_byte_size(raw_weights.dtype)
+        weight_size = raw_weights.numel() * raw_weights.element_size()
         # use the renaming pattern from the small conversion scripts
         key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights)
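The per-weight byte size feeds an accumulate-and-split loop in these conversion scripts. A simplified illustration of that pattern (a toy `split_by_size`, not the library's `shard_on_the_fly` implementation), using `element_size()` the same way:

```py
import torch


def split_by_size(tensors, max_shard_size):
    # Toy sharding loop: start a new shard once adding the next tensor would
    # push the running byte count past max_shard_size.
    shards, current, current_size = [], [], 0
    for t in tensors:
        weight_size = t.numel() * t.element_size()
        if current and current_size + weight_size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(t)
        current_size += weight_size
    if current:
        shards.append(current)
    return shards


# Five 1 KiB float32 tensors with a 2 KiB shard budget -> shards of 2, 2, 1.
tensors = [torch.zeros(256, dtype=torch.float32) for _ in range(5)]
assert [len(s) for s in split_by_size(tensors, max_shard_size=2048)] == [2, 2, 1]
```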


@@ -116,7 +116,6 @@ if is_torch_available():
     from transformers.modeling_utils import (
         _find_disjoint,
         _find_identical,
-        dtype_byte_size,
     )
     from transformers.pytorch_utils import isin_mps_friendly
@@ -704,31 +703,6 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
         self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
-    def test_torch_dtype_byte_sizes(self):
-        torch_dtypes_and_bytes = [
-            (torch.double, 8),
-            (torch.float64, 8),
-            (torch.float, 4),
-            (torch.float32, 4),
-            (torch.half, 2),
-            (torch.float16, 2),
-            (torch.bfloat16, 2),
-            (torch.long, 8),
-            (torch.int64, 8),
-            (torch.int, 4),
-            (torch.int32, 4),
-            (torch.short, 2),
-            (torch.int16, 2),
-            (torch.uint8, 1),
-            (torch.int8, 1),
-            (torch.float8_e4m3fn, 1),
-            (torch.float8_e5m2, 1),
-            (torch.bool, 0.125),
-        ]
-        for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
-            self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
     def test_no_super_init_config_and_model(self):
         config = NoSuperInitConfig(attribute=32)
         model = NoSuperInitModel(config)
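For reference, the expectations from the removed test can be checked against `element_size()` directly. A sketch, assuming a torch build that ships the float8 dtypes; one behavioural difference is that `torch.bool` reports 1 byte per element here, where the old helper returned 1/8:

```py
import torch

# Bytes per element for the dtypes the removed test covered.
dtype_to_bytes = {
    torch.float64: 8,
    torch.float32: 4,
    torch.float16: 2,
    torch.bfloat16: 2,
    torch.int64: 8,
    torch.int32: 4,
    torch.int16: 2,
    torch.uint8: 1,
    torch.int8: 1,
    torch.float8_e4m3fn: 1,
    torch.float8_e5m2: 1,
    torch.bool: 1,
}
for dtype, expected in dtype_to_bytes.items():
    assert torch.empty((), dtype=dtype).element_size() == expected
```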