Define warmup allocator for torchao quantization (#37764)

* torchao allocator

* add comment

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Mohamed Mekkouri 2025-04-28 10:45:55 +02:00 committed by GitHub
parent a41b6d9b5c
commit f466603963
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -277,6 +277,45 @@ class TorchAoHfQuantizer(HfQuantizer):
return False
return _is_torchao_serializable
def get_cuda_warm_up_factor(self):
"""
This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
- A factor of 2 means we pre-allocate the full memory footprint of the model.
- A factor of 4 means we pre-allocate half of that, and so on
However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give the correct size for quantized weights (like int4 or int8)
That's because TorchAO internally represents quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the torch_dtype
not the actual bit-width of the quantized data.
To correct for this:
- Use a division factor of 8 for int4 weights
- Use a division factor of 4 for int8 weights
"""
if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
# For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
if size_digit == "4":
return 8
else:
return 4
# Original mapping for non-AOBaseConfig types
map_to_target_dtype = {
"int4_weight_only": 8,
"int8_weight_only": 4,
"int8_dynamic_activation_int8_weight": 4,
"autoquant": 4,
}
return map_to_target_dtype[self.quantization_config.quant_type]
@property
def is_trainable(self) -> bool:
supported_quant_types_for_training = [