mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Define warmup allocator for torchao quantization (#37764)
* torchao allocator * add comment --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent
a41b6d9b5c
commit
f466603963
@ -277,6 +277,45 @@ class TorchAoHfQuantizer(HfQuantizer):
|
||||
return False
|
||||
return _is_torchao_serializable
|
||||
|
||||
def get_cuda_warm_up_factor(self):
|
||||
"""
|
||||
This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
|
||||
- A factor of 2 means we pre-allocate the full memory footprint of the model.
|
||||
- A factor of 4 means we pre-allocate half of that, and so on
|
||||
|
||||
However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give the correct size for quantized weights (like int4 or int8)
|
||||
That's because TorchAO internally represents quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the torch_dtype
|
||||
not the actual bit-width of the quantized data.
|
||||
|
||||
To correct for this:
|
||||
- Use a division factor of 8 for int4 weights
|
||||
- Use a division factor of 4 for int8 weights
|
||||
"""
|
||||
if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
|
||||
from torchao.core.config import AOBaseConfig
|
||||
|
||||
quant_type = self.quantization_config.quant_type
|
||||
# For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
|
||||
if isinstance(quant_type, AOBaseConfig):
|
||||
# Extract size digit using fuzzy match on the class name
|
||||
config_name = quant_type.__class__.__name__
|
||||
size_digit = fuzzy_match_size(config_name)
|
||||
|
||||
if size_digit == "4":
|
||||
return 8
|
||||
else:
|
||||
return 4
|
||||
|
||||
# Original mapping for non-AOBaseConfig types
|
||||
map_to_target_dtype = {
|
||||
"int4_weight_only": 8,
|
||||
"int8_weight_only": 4,
|
||||
"int8_dynamic_activation_int8_weight": 4,
|
||||
"autoquant": 4,
|
||||
}
|
||||
|
||||
return map_to_target_dtype[self.quantization_config.quant_type]
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
supported_quant_types_for_training = [
|
||||
|
Loading…
Reference in New Issue
Block a user