Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 13:20:12 +06:00)
Define warmup allocator for torchao quantization (#37764)
* torchao allocator

* add comment

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent a41b6d9b5c
commit f466603963
@@ -277,6 +277,45 @@ class TorchAoHfQuantizer(HfQuantizer):
             return False
         return _is_torchao_serializable
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup:
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on.
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype, not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        if self.quantization_config._get_ao_version() > version.Version("0.9.0"):
+            from torchao.core.config import AOBaseConfig
+
+            quant_type = self.quantization_config.quant_type
+            # The autoquant case is handled by the string mapping below (map_to_target_dtype)
+            if isinstance(quant_type, AOBaseConfig):
+                # Extract the bit-width digit using a fuzzy match on the config class name
+                config_name = quant_type.__class__.__name__
+                size_digit = fuzzy_match_size(config_name)
+
+                if size_digit == "4":
+                    return 8
+                else:
+                    return 4
+
+        # Original mapping for non-AOBaseConfig quant types
+        map_to_target_dtype = {
+            "int4_weight_only": 8,
+            "int8_weight_only": 4,
+            "int8_dynamic_activation_int8_weight": 4,
+            "autoquant": 4,
+        }
+
+        return map_to_target_dtype[self.quantization_config.quant_type]
+
     @property
     def is_trainable(self) -> bool:
         supported_quant_types_for_training = [
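For intuition about the factor semantics documented above, here is a minimal self-contained sketch of how such a factor translates into a pre-allocation size. The helper name estimate_warmup_bytes is hypothetical; this is not the actual caching_allocator_warmup implementation, only the arithmetic the docstring describes:

import torch

def estimate_warmup_bytes(model: torch.nn.Module, factor: int) -> int:
    """Hypothetical helper mirroring the documented factor semantics:
    factor 2 -> pre-allocate the full estimated footprint,
    factor 4 -> half of it, factor 8 -> a quarter, and so on."""
    # Naive footprint estimate; for TorchAO-quantized params element_size()
    # reflects the torch_dtype, not the true bit-width, which is exactly
    # why the corrective factors of 8 (int4) and 4 (int8) exist.
    estimated = sum(p.numel() * p.element_size() for p in model.parameters())
    return 2 * estimated // factor

model = torch.nn.Linear(1024, 1024, bias=False)  # fp32: 4 MiB of weights
assert estimate_warmup_bytes(model, 2) == 4 * 1024 * 1024  # full footprint
assert estimate_warmup_bytes(model, 4) == 2 * 1024 * 1024  # half of it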
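The corrective factors of 8 and 4 follow from that convention. Assuming a hypothetical fp16 torch_dtype (so element_size() reports 2 bytes) and ignoring quantization metadata such as scales and zero-points, the arithmetic works out as follows:

numel = 1024 * 1024    # parameters in some weight matrix
estimated = numel * 2  # element_size() reports 2 bytes for fp16

# True packed payload of the quantized weights:
int4_bytes = numel // 2  # two int4 values per byte
int8_bytes = numel       # one byte per value

# Factor semantics: pre-allocated = 2 * estimated // factor
assert 2 * estimated // 8 == int4_bytes  # int4 -> divide by 8
assert 2 * estimated // 4 == int8_bytes  # int8 -> divide by 4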
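fuzzy_match_size itself is defined elsewhere in the quantizer utilities; as a rough illustration of the behavior this diff relies on (the real helper may well differ), it extracts the bit-width digit from an AOBaseConfig class name:

import re

def fuzzy_match_size_sketch(config_name: str) -> str | None:
    """Illustrative stand-in for fuzzy_match_size: pull the bit-width
    digit ("4", "8", ...) out of a TorchAO config class name."""
    match = re.search(r"int(\d)", config_name.lower())
    return match.group(1) if match else None

assert fuzzy_match_size_sketch("Int4WeightOnlyConfig") == "4"
assert fuzzy_match_size_sketch("Int8DynamicActivationInt8WeightConfig") == "8"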