Update dtype_byte_size to handle torch.float8_e4m3fn/float8_e5m2 types (#30488)

* Update modeling_utils/dtype_byte_size to handle float8 types

* Add a test for dtype_byte_size

* Format

* Fix bool
This commit is contained in:
Michael Goin 2024-04-26 06:26:43 -04:00 committed by GitHub
parent 59e715f71c
commit 20081c743e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 2 deletions

View File

@ -324,7 +324,7 @@ def dtype_byte_size(dtype):
"""
if dtype == torch.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])

View File

@ -101,7 +101,12 @@ if is_torch_available():
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
)
from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
from transformers.modeling_utils import (
_find_disjoint,
_find_identical,
dtype_byte_size,
shard_checkpoint,
)
# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
@ -465,6 +470,31 @@ class ModelUtilsTest(TestCasePlus):
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
def test_torch_dtype_byte_sizes(self):
torch_dtypes_and_bytes = [
(torch.double, 8),
(torch.float64, 8),
(torch.float, 4),
(torch.float32, 4),
(torch.half, 2),
(torch.float16, 2),
(torch.bfloat16, 2),
(torch.long, 8),
(torch.int64, 8),
(torch.int, 4),
(torch.int32, 4),
(torch.short, 2),
(torch.int16, 2),
(torch.uint8, 1),
(torch.int8, 1),
(torch.float8_e4m3fn, 1),
(torch.float8_e5m2, 1),
(torch.bool, 0.125),
]
for torch_dtype, bytes_per_element in torch_dtypes_and_bytes:
self.assertEqual(dtype_byte_size(torch_dtype), bytes_per_element)
def test_no_super_init_config_and_model(self):
config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)