fix caching_allocator_warmup with tied weights (#39070)

* fix caching_allocator_warmup with tied weights

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Author: jiqing-feng
Date:   2025-07-01 17:32:20 +08:00 (committed by GitHub)
Commit: 06c4a4d499
Parent: e435574721


@@ -5843,7 +5843,12 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
         else None
     )
     total_byte_count = defaultdict(lambda: 0)
+    tied_param_names = _get_tied_weight_keys(model)
     for param_name, device in accelerator_device_map.items():
+        # Skip if the parameter has already been accounted for (tied weights)
+        if param_name in tied_param_names:
+            continue
         param = model.get_parameter_or_buffer(param_name)
         # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
         param_byte_count = param.numel() * param.element_size()
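
For context, here is a minimal sketch (not the transformers implementation) of the double counting this patch avoids. With tied weights, the same underlying tensor appears under several parameter names, so summing bytes per name overstates the memory the warm-up needs to reserve. `TinyLM` and the hard-coded `tied_names` set below are illustrative stand-ins for a real model and for what `_get_tied_weight_keys(model)` returns:

```python
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # tie input/output embeddings

model = TinyLM()
tied_names = {"lm_head.weight"}  # stand-in for _get_tied_weight_keys(model)

# remove_duplicate=False mirrors a device map, where every parameter name is
# listed even when two names point at the same underlying tensor.
params = dict(model.named_parameters(remove_duplicate=False))

naive = sum(p.numel() * p.element_size() for p in params.values())
fixed = sum(
    p.numel() * p.element_size()
    for name, p in params.items()
    if name not in tied_names
)
print(naive, fixed)  # 12800 6400: naive counts the shared 100x16 fp32 matrix twice
```

Without the skip, the warm-up would reserve bytes for the shared tensor once per alias, over-allocating relative to what loading the checkpoint actually needs.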