Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
No more dtype_byte_size() (#37144)

* No more dtype_byte_size()
* Remove function once again
* Fix rebase cruft
* Trigger tests

commit cbfa14823b (parent 7613cf1a45)
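The removed helper parsed the byte width out of a dtype's string name with a regex. Each framework already exposes that width directly, which is what the diffs below switch to. A minimal sketch of the three native equivalents (assuming numpy, tensorflow, and torch are installed):

```py
import numpy as np
import tensorflow as tf
import torch

# numpy (and therefore JAX/Flax): dtype.itemsize is bytes per element
assert np.dtype(np.float32).itemsize == 4
# tensorflow: DType.size is bytes per element
assert tf.float32.size == 4
# torch: Tensor.element_size() is bytes per element of the tensor's dtype
assert torch.empty(0, dtype=torch.float16).element_size() == 2
```

One behavioral nuance: the old helper special-cased bool as 1/8 byte, while the native APIs report the actual one-byte-per-element storage, so bool weights are now accounted at their real memory footprint.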
@@ -17,7 +17,6 @@
 import gc
 import json
 import os
-import re
 import warnings
 from functools import partial
 from pickle import UnpicklingError
@@ -83,23 +82,6 @@ ACT2FN = {
 }
 
 
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:
-    ```py
-    >>> dtype_byte_size(np.float32)
-    4
-    ```
-    """
-    if dtype is bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
 def flax_shard_checkpoint(params, max_shard_size="10GB"):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -131,7 +113,7 @@ def flax_shard_checkpoint(params, max_shard_size="10GB"):
     # flatten the weights to chunk
     weights = flatten_dict(params, sep="/")
     for item in weights:
-        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
+        weight_size = weights[item].size * weights[item].dtype.itemsize
 
         # If this weight is going to tip up over the maximal size, we split.
         if current_block_size + weight_size > max_shard_size:
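For the Flax sharding arithmetic above, `size` counts elements and `dtype.itemsize` gives bytes per element; a small illustration (hypothetical shape):

```py
import numpy as np

weight = np.zeros((1024, 768), dtype=np.float16)
# bytes = element count * bytes per element
assert weight.size * weight.dtype.itemsize == 1024 * 768 * 2
```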
@@ -617,26 +617,6 @@ def input_processing(func, config, **kwargs):
     return output
 
 
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`.
-
-    Example:
-
-    ```py
-    >>> dtype_byte_size(tf.float32)
-    4
-    ```
-    """
-    if dtype == tf.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
 def strip_model_name_and_prefix(name, _prefix=None):
     if _prefix is not None and name.startswith(_prefix):
         name = name[len(_prefix) :]
@@ -678,7 +658,7 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_
     total_size = 0
 
     for item in weights:
-        weight_size = item.numpy().size * dtype_byte_size(item.dtype)
+        weight_size = item.numpy().size * item.dtype.size
 
         # If this weight is going to tip up over the maximal size, we split.
         if current_block_size + weight_size > max_shard_size:
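The TensorFlow counterpart relies on `tf.DType.size`, which likewise reports bytes per element; a quick check with an illustrative variable:

```py
import tensorflow as tf

item = tf.Variable(tf.zeros((1024, 768), dtype=tf.float16))
assert item.numpy().size * item.dtype.size == 1024 * 768 * 2
```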
@@ -385,26 +385,6 @@ def get_state_dict_dtype(state_dict):
     return next(state_dict.values()).dtype
 
 
-def dtype_byte_size(dtype):
-    """
-    Returns the size (in bytes) occupied by one parameter of type `dtype`.
-
-    Example:
-
-    ```py
-    >>> dtype_byte_size(torch.float32)
-    4
-    ```
-    """
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
 def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     """
     This is the same as
@@ -5820,7 +5800,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
     for param_name, device in accelerator_device_map.items():
         param = model.get_parameter_or_buffer(param_name)
         # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
+        param_byte_count = math.prod(param.shape) * param.element_size()
 
         if tp_plan_regex is not None:
             generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
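In the PyTorch warmup path, `math.prod(param.shape)` is the element count and `Tensor.element_size()` the per-element width, so the product is equivalent to `param.numel() * param.element_size()`:

```py
import math
import torch

param = torch.empty((4096, 4096), dtype=torch.bfloat16)
assert math.prod(param.shape) * param.element_size() == param.numel() * 2
```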
@@ -19,7 +19,6 @@ import torch
 from torch import nn
 
 from transformers import NllbMoeConfig, NllbMoeModel
-from transformers.modeling_utils import dtype_byte_size
 from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 
 
@@ -86,8 +85,8 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, num_experts, dtype, weig
         )
         torch.save(expert_state, save_path)
         sharded_state_dicts.append(expert_state.keys())
-        total_size += sum([value.numel() for key, value in expert_state.items()]) * dtype_byte_size(
-            expert_state[list(expert_state)[0]].dtype
+        total_size += sum([value.numel() for key, value in expert_state.items()]) * (
+            expert_state[list(expert_state)[0]].element_size()
         )
 
     # Add the last block
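The conversion script sums element counts across the expert state dict and multiplies once by the shared element size; a minimal sketch with a toy state dict (hypothetical keys):

```py
import torch

expert_state = {"w1": torch.zeros(16, 16), "w2": torch.zeros(16)}
# all values share a dtype, so one element_size() covers the whole dict
total_size = sum(v.numel() for v in expert_state.values()) * (
    expert_state[list(expert_state)[0]].element_size()
)
assert total_size == (16 * 16 + 16) * 4
```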
@@ -8,7 +8,6 @@ from flax import serialization
 from flax.traverse_util import flatten_dict, unflatten_dict
 from tensorflow.io import gfile
 
-from transformers.modeling_utils import dtype_byte_size
 from transformers.models.switch_transformers.convert_switch_transformers_original_flax_checkpoint_to_pytorch import (
     rename_keys,
 )
@@ -94,7 +93,7 @@ def shard_on_the_fly(switch_checkpoint_path, dump_path, max_shard_size, dtype, w
         # open tensorstore file
         raw_weights = ts.open(unflatten_dict(all_layers[key])).result().read().result()
         raw_weights = torch.tensor(raw_weights)
-        weight_size = raw_weights.numel() * dtype_byte_size(raw_weights.dtype)
+        weight_size = raw_weights.numel() * raw_weights.element_size()
 
         # use the renaming pattern from the small conversion scripts
         key, raw_weights = rename_base_flax_keys(tuple(key.split("/")), raw_weights)
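Since the raw tensorstore array is rewrapped with `torch.tensor(...)`, `element_size()` reflects the dtype after conversion rather than the on-disk layout; for example:

```py
import numpy as np
import torch

raw_weights = torch.tensor(np.zeros((8, 8), dtype=np.float32))
assert raw_weights.numel() * raw_weights.element_size() == 8 * 8 * 4
```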
@@ -116,7 +116,6 @@ if is_torch_available():
     from transformers.modeling_utils import (
         _find_disjoint,
         _find_identical,
-        dtype_byte_size,
     )
     from transformers.pytorch_utils import isin_mps_friendly
 
@@ -704,31 +703,6 @@ class ModelUtilsTest(TestCasePlus):
         model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
         self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
 
-    def test_torch_dtype_byte_sizes(self):
-        torch_dtypes_and_bytes = [
-            (torch.double, 8),
-            (torch.float64, 8),
-            (torch.float, 4),
-            (torch.float32, 4),
-            (torch.half, 2),
-            (torch.float16, 2),
-            (torch.bfloat16, 2),
-            (torch.long, 8),
-            (torch.int64, 8),
-            (torch.int, 4),
-            (torch.int32, 4),
-            (torch.short, 2),
-            (torch.int16, 2),
-            (torch.uint8, 1),
-            (torch.int8, 1),
-            (torch.float8_e4m3fn, 1),
-            (torch.float8_e5m2, 1),
-            (torch.bool, 0.125),
-        ]
-
-        for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
-            self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
-
    def test_no_super_init_config_and_model(self):
        config = NoSuperInitConfig(attribute=32)
        model = NoSuperInitModel(config)
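The deleted test pinned dtype_byte_size's outputs, including the 1/8-byte bool special case. A rough native-API equivalent (note that bool now reports its true one-byte storage) might look like:

```py
import torch

for dtype, nbytes in [(torch.float32, 4), (torch.bfloat16, 2), (torch.int8, 1), (torch.bool, 1)]:
    assert torch.empty(0, dtype=dtype).element_size() == nbytes
```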