From cbfa14823b4ef762ebf138822cacc00c733be845 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 2 Apr 2025 14:58:38 +0100
Subject: [PATCH] No more dtype_byte_size() (#37144)

* No more dtype_byte_size()

* Remove function once again

* Fix rebase cruft

* Trigger tests
---
 src/transformers/modeling_flax_utils.py       | 20 +-------------
 src/transformers/modeling_tf_utils.py         | 22 +---------------
 src/transformers/modeling_utils.py            | 22 +---------------
 ..._sharded_original_checkpoint_to_pytorch.py |  5 ++--
 .../switch_transformers/convert_big_switch.py |  3 +--
 tests/utils/test_modeling_utils.py            | 26 -------------------
 6 files changed, 6 insertions(+), 92 deletions(-)

diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index 2f903968320..c775ee85bbe 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -17,7 +17,6 @@
 import gc
 import json
 import os
-import re
 import warnings
 from functools import partial
 from pickle import UnpicklingError
@@ -83,23 +82,6 @@ ACT2FN = {
 }
 
 
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:
-    ```py
-    >>> dtype_byte_size(np.float32)
-    4
-    ```
-    """
-    if dtype is bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
 def flax_shard_checkpoint(params, max_shard_size="10GB"):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -131,7 +113,7 @@ def flax_shard_checkpoint(params, max_shard_size="10GB"):
     # flatten the weights to chunk
     weights = flatten_dict(params, sep="/")
     for item in weights:
-        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
+        weight_size = weights[item].size * weights[item].dtype.itemsize
 
         # If this weight is going to tip up over the maximal size, we split.
         if current_block_size + weight_size > max_shard_size:
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index d569d97f855..a09bc430a44 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -617,26 +617,6 @@ def input_processing(func, config, **kwargs):
     return output
 
 
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`.
-
-    Example:
-
-    ```py
-    >>> dtype_byte_size(tf.float32)
-    4
-    ```
-    """
-    if dtype == tf.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
 def strip_model_name_and_prefix(name, _prefix=None):
     if _prefix is not None and name.startswith(_prefix):
         name = name[len(_prefix) :]
@@ -678,7 +658,7 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_
     total_size = 0
 
     for item in weights:
-        weight_size = item.numpy().size * dtype_byte_size(item.dtype)
+        weight_size = item.numpy().size * item.dtype.size
 
         # If this weight is going to tip up over the maximal size, we split.
         if current_block_size + weight_size > max_shard_size:
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 44d83c44e3f..331248fbf99 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -385,26 +385,6 @@ def get_state_dict_dtype(state_dict):
         return next(state_dict.values()).dtype
 
 
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`.
-
-    Example:
-
-    ```py
-    >>> dtype_byte_size(torch.float32)
-    4
-    ```
-    """
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
 def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     """
     This is the same as
@@ -5820,7 +5800,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
     for param_name, device in accelerator_device_map.items():
         param = model.get_parameter_or_buffer(param_name)
         # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
+        param_byte_count = math.prod(param.shape) * param.element_size()
 
         if tp_plan_regex is not None:
             generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
diff --git a/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py b/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
index dd995bcbc6b..a84138a6246 100644
--- a/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
+++ b/src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
@@ -19,7 +19,6 @@ import torch
 from torch import nn
 
 from transformers import NllbMoeConfig, NllbMoeModel
-from transformers.modeling_utils import dtype_byte_size
 from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 
 
@@ -86,8 +85,8 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weig
         )
         torch.save(expert_state, save_path)
         sharded_state_dicts.append(expert_state.keys())
-        total_size += sum([value.numel() for key, value in expert_state.items()]) * dtype_byte_size(
-            expert_state[list(expert_state)[0]].dtype
+        total_size += sum([value.numel() for key, value in expert_state.items()]) * (
+            expert_state[list(expert_state)[0]].element_size()
         )
 
     # Add the last block
diff --git a/src/transformers/models/switch_transformers/convert_big_switch.py b/src/transformers/models/switch_transformers/convert_big_switch.py
index 70652c10cf1..6f422439fc7 100644
--- a/src/transformers/models/switch_transformers/convert_big_switch.py
+++ b/src/transformers/models/switch_transformers/convert_big_switch.py
@@ -8,7 +8,6 @@ from flax import serialization
 from flax.traverse_util import flatten_dict, unflatten_dict
 from tensorflow.io import gfile
 
-from transformers.modeling_utils import dtype_byte_size
 from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import (
     rename_keys,
 )
@@ -94,7 +93,7 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, w
         # open tensorstore file
         raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result()
         raw_weights = torch.tensor(raw_weights)
-        weight_size = raw_weights.numel() * dtype_byte_size(raw_weights.dtype)
+        weight_size = raw_weights.numel() * raw_weights.element_size()
 
         # use the renaming pattern from the small conversion scripts
         key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights)
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 872369c1752..4ec2c497440 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -116,7 +116,6 @@ if is_torch_available():
     from transformers.modeling_utils import (
         _find_disjoint,
         _find_identical,
-        dtype_byte_size,
     )
     from transformers.pytorch_utils import isin_mps_friendly
 
@@ -704,31 +703,6 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
         self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
 
-    def test_torch_dtype_byte_sizes(self):
-        torch_dtypes_and_bytes = [
-            (torch.double, 8),
-            (torch.float64, 8),
-            (torch.float, 4),
-            (torch.float32, 4),
-            (torch.half, 2),
-            (torch.float16, 2),
-            (torch.bfloat16, 2),
-            (torch.long, 8),
-            (torch.int64, 8),
-            (torch.int, 4),
-            (torch.int32, 4),
-            (torch.short, 2),
-            (torch.int16, 2),
-            (torch.uint8, 1),
-            (torch.int8, 1),
-            (torch.float8_e4m3fn, 1),
-            (torch.float8_e5m2, 1),
-            (torch.bool, 0.125),
-        ]
-
-        for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
-            self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
-
     def test_no_super_init_config_and_model(self):
         config = NoSuperInitConfig(attribute=32)
         model = NoSuperInitModel(config)
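
Note (not part of the patch): a minimal sketch of the framework-native size queries that the removed `dtype_byte_size()` helpers are replaced with above. NumPy/Flax dtypes expose `itemsize`, TensorFlow `DType`s expose `size`, and PyTorch tensors expose `element_size()`. The TensorFlow lines are commented out so the snippet runs with only `numpy` and `torch` installed.

```py
# Sketch of the native per-element byte sizes used in place of dtype_byte_size().
import numpy as np
import torch

# NumPy / Flax dtypes (flax_shard_checkpoint): dtype.itemsize
assert np.dtype(np.float32).itemsize == 4
assert np.dtype(np.float16).itemsize == 2

# PyTorch tensors (caching_allocator_warmup and the conversion scripts): Tensor.element_size()
assert torch.zeros(1, dtype=torch.bfloat16).element_size() == 2
assert torch.zeros(1, dtype=torch.int64).element_size() == 8

# TensorFlow (tf_shard_checkpoint): tf.DType.size
# import tensorflow as tf
# assert tf.float32.size == 4

# One behavioural difference: the removed helper reported torch.bool as 1/8 byte,
# whereas the native APIs report 1 byte per element (bools are stored one per byte).
assert np.dtype(bool).itemsize == 1
assert torch.zeros(1, dtype=torch.bool).element_size() == 1
```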