diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index 21dc2498d9a..5b23a6173d8 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -277,6 +277,45 @@ class TorchAoHfQuantizer(HfQuantizer):
             return False
         return _is_torchao_serializable
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on.
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give the correct size for quantized weights (such as int4 or int8).
+        That's because TorchAO internally represents quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the torch_dtype,
+        not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
+            from torchao.core.config import AOBaseConfig
+
+            quant_type = self.quantization_config.quant_type
+            # The autoquant case is handled by the string-based map_to_target_dtype mapping below
+            if isinstance(quant_type, AOBaseConfig):
+                # Extract the bit-width digit via a fuzzy match on the config class name
+                config_name = quant_type.__class__.__name__
+                size_digit = fuzzy_match_size(config_name)
+
+                if size_digit == "4":
+                    return 8
+                else:
+                    return 4
+
+        # Original mapping for non-AOBaseConfig types
+        map_to_target_dtype = {
+            "int4_weight_only": 8,
+            "int8_weight_only": 4,
+            "int8_dynamic_activation_int8_weight": 4,
+            "autoquant": 4,
+        }
+
+        return map_to_target_dtype[self.quantization_config.quant_type]
+
     @property
     def is_trainable(self) -> bool:
         supported_quant_types_for_training = [
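
To make the factor semantics concrete, here is a minimal sketch of the arithmetic from the docstring. It assumes the warmup allocates `byte_count // factor` half-precision elements (2 bytes each), which is how a factor of 2 maps to the full reported footprint; the variable names are illustrative, not the actual `caching_allocator_warmup` internals.

```python
# Reported footprint: what sum(p.numel() * p.element_size()) gives when every
# weight is counted at the torch_dtype width (e.g. bfloat16 -> 2 bytes/element).
numel = 1024 * 1024
byte_count = numel * 2  # 2 MiB reported for a 1M-weight bf16 tensor

# Assumed warmup rule: allocate byte_count // factor fp16 elements (2 bytes each).
for label, factor in [("unquantized", 2), ("int8", 4), ("int4", 8)]:
    warmup_bytes = (byte_count // factor) * 2
    print(f"{label:>11}: {warmup_bytes / 2**20:.1f} MiB pre-allocated")

# unquantized: 2.0 MiB (full footprint)
# int8:        1.0 MiB (~1 byte per weight)
# int4:        0.5 MiB (~0.5 bytes per weight)
```

So the int4 factor of 8 lands on roughly half a byte per weight and the int8 factor of 4 on roughly one byte per weight, matching the true quantized footprint rather than the torch_dtype footprint.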
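
For the `AOBaseConfig` branch, here is a sketch of what the fuzzy class-name match could look like. The real `fuzzy_match_size` helper lives elsewhere in transformers; this stand-in (`_fuzzy_match_size`, regex included) is only an assumption about its behavior.

```python
import re
from typing import Optional


def _fuzzy_match_size(config_name: str) -> Optional[str]:
    """Hypothetical stand-in: pull the bit-width digit out of a torchao
    config class name, e.g. 'Int4WeightOnlyConfig' -> '4'."""
    match = re.search(r"(\d)(?=weight)", config_name.lower())
    return match.group(1) if match else None


# torchao >= 0.10-style config class names, with the factor
# get_cuda_warm_up_factor would derive from each match:
assert _fuzzy_match_size("Int4WeightOnlyConfig") == "4"                   # -> factor 8
assert _fuzzy_match_size("Int8WeightOnlyConfig") == "8"                   # -> factor 4
assert _fuzzy_match_size("Int8DynamicActivationInt8WeightConfig") == "8"  # -> factor 4
```

Matching on the class name rather than on a hard-coded list keeps the method working for any `AOBaseConfig` subclass that encodes its bit-width in the name, while everything else (including autoquant) falls through to the factor-4 default or the string mapping.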