Correct warm-up with fp8 (#37670)

* start clean warmup for quantizers

* style

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Authored by Cyril Vallez on 2025-04-22 13:12:49 +02:00, committed by GitHub
parent 6614209b96
commit 9608908639
3 changed files with 19 additions and 2 deletions


@@ -4866,7 +4866,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # Warmup cuda to load the weights much faster on devices
         if device_map is not None and not is_hqq_or_quark:
             expanded_device_map = expand_device_map(device_map, expected_keys)
-            caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
+            caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)
 
         error_msgs = []
         # Iterate on all the shards to load the weights
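For intuition, a hedged sketch of the arguments involved (the parameter names and device placement below are made up for illustration): the expanded device map assigns every individual parameter to a device, and the warm-up uses it to size one pre-allocation per accelerator. With this change the quantizer, or None, is passed through as-is instead of being collapsed into a hard-coded factor at the call site.

# Illustrative only: the kind of per-parameter map that caching_allocator_warmup receives.
expanded_device_map = {
    "model.embed_tokens.weight": "cuda:0",
    "model.layers.0.self_attn.q_proj.weight": "cuda:0",
    "model.layers.1.self_attn.q_proj.weight": "cuda:1",
    "lm_head.weight": "cuda:1",
}
# caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)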
@@ -5871,7 +5871,7 @@ def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
     return torch.device(device).type not in ["meta", "cpu"]
 
 
-def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
+def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, hf_quantizer: Optional[HfQuantizer]):
     """This function warms up the caching allocator based on the size of the model tensors that will reside on each
     device. It allows having one large call to Malloc, instead of recursively calling it later when loading
     the model, which is actually the loading speed bottleneck.
@@ -5890,6 +5890,8 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
     - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
     However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
     """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
+
     # Remove disk, cpu and meta devices, and cast to proper torch.device
     accelerator_device_map = {
         param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
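A minimal sketch of how a factor like this can drive the warm-up, assuming the per-device byte counting described in the docstring. The helper name warmup_sketch, the use of get_parameter, and the fp16 allocation are illustrative assumptions, not the library's exact code: allocating byte_count // factor fp16 elements touches all of the counted bytes for factor 2 and half of them for factor 4.

from collections import defaultdict

import torch


def warmup_sketch(model, accelerator_device_map, factor):
    # Sum the bytes of every parameter destined for each accelerator device.
    total_byte_count = defaultdict(int)
    for param_name, device in accelerator_device_map.items():
        param = model.get_parameter(param_name)  # sketch only; ignores buffers
        total_byte_count[device] += param.numel() * param.element_size()

    # One large allocation per device: byte_count // factor fp16 elements equals byte_count bytes
    # for factor=2 and half of them for factor=4, so the caching allocator grows once up front
    # instead of growing repeatedly while shards are being copied in.
    for device, byte_count in total_byte_count.items():
        _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)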


@@ -252,6 +252,17 @@ class HfQuantizer(ABC):
 
         return model
 
+    def get_cuda_warm_up_factor(self):
+        """
+        The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
+        A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+        we allocate half the memory of the weights residing in the empty model, etc...
+        """
+        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+        # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+        # weight loading)
+        return 4
+
     def _dequantize(self, model):
         raise NotImplementedError(
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."


@@ -200,3 +200,7 @@ class FineGrainedFP8HfQuantizer(HfQuantizer):
     @property
     def is_trainable(self) -> bool:
         return False
+
+    def get_cuda_warm_up_factor(self):
+        # Pre-processing is done cleanly, so we can allocate everything here
+        return 2
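To make the fix concrete, a back-of-the-envelope comparison using the factor semantics from the base-class docstring (the 7B parameter count and single-GPU placement are made-up numbers): because the FP8 quantizer swaps the empty model's weights to fp8 before loading, the per-device byte count already reflects 1 byte per element, so the previous blanket factor of 4 for quantized models warmed up only about half of what loading actually needs, while factor 2 covers it fully.

# Illustrative numbers only: a 7B-parameter model whose empty weights are already fp8, on one GPU.
n_params = 7_000_000_000
byte_count = n_params * 1                 # fp8 storage: ~7.0 GB to load

old_warmup = byte_count // 4 * 2          # blanket factor 4 -> ~3.5 GB pre-allocated (too little)
new_warmup = byte_count // 2 * 2          # fp8 quantizer's factor 2 -> ~7.0 GB pre-allocated

print(f"old: {old_warmup / 1e9:.1f} GB, new: {new_warmup / 1e9:.1f} GB")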