Mirror of https://github.com/huggingface/transformers.git
Update dtype_byte_size to handle torch.float8_e4m3fn/float8_e5m2 types (#30488)
* Update modeling_utils/dtype_byte_size to handle float8 types
* Add a test for dtype_byte_size
* Format
* Fix bool
parent 59e715f71c
commit 20081c743e
@@ -324,7 +324,7 @@ def dtype_byte_size(dtype):
     """
     if dtype == torch.bool:
         return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
     if bit_search is None:
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
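For context (not part of the diff): the float8 dtype names do not end in digits, which is why the old end-anchored pattern missed them and `dtype_byte_size` raised a ValueError. A minimal sketch of the before/after behavior, assuming a PyTorch build that exposes the float8 dtypes:

import re

import torch

# Old pattern: the digits must sit at the very end of the dtype string, so
# "torch.float8_e4m3fn" and "torch.float8_e5m2" produce no match at all.
print(re.search(r"[^\d](\d+)$", str(torch.float8_e4m3fn)))  # None

# New pattern: take the first digit group, optionally followed by "_", so the
# "8" in "float8" is extracted and the byte size works out to one byte.
print(re.search(r"[^\d](\d+)_?", str(torch.float8_e4m3fn)).groups()[0])  # 8

# Wider dtypes are unaffected: "torch.bfloat16" still yields 16.
print(re.search(r"[^\d](\d+)_?", str(torch.bfloat16)).groups()[0])  # 16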
@@ -101,7 +101,12 @@ if is_torch_available():
         _prepare_4d_attention_mask,
         _prepare_4d_causal_attention_mask,
     )
-    from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
+    from transformers.modeling_utils import (
+        _find_disjoint,
+        _find_identical,
+        dtype_byte_size,
+        shard_checkpoint,
+    )

     # Fake pretrained models for tests
     class BaseModel(PreTrainedModel):
@@ -465,6 +470,31 @@ class ModelUtilsTest(TestCasePlus):
                         module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
                     )

+    def test_torch_dtype_byte_sizes(self):
+        torch_dtypes_and_bytes = [
+            (torch.double, 8),
+            (torch.float64, 8),
+            (torch.float, 4),
+            (torch.float32, 4),
+            (torch.half, 2),
+            (torch.float16, 2),
+            (torch.bfloat16, 2),
+            (torch.long, 8),
+            (torch.int64, 8),
+            (torch.int, 4),
+            (torch.int32, 4),
+            (torch.short, 2),
+            (torch.int16, 2),
+            (torch.uint8, 1),
+            (torch.int8, 1),
+            (torch.float8_e4m3fn, 1),
+            (torch.float8_e5m2, 1),
+            (torch.bool, 0.125),
+        ]
+
+        for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
+            self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
+
     def test_no_super_init_config_and_model(self):
         config = NoSuperInitConfig(attribute=32)
         model = NoSuperInitModel(config)