mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
fix caching_allocator_warmup with tie weights (#39070)
* fix caching_allocator_warmup with tie weights

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix comment

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
parent e435574721
commit 06c4a4d499
@@ -5843,7 +5843,12 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
         else None
     )
     total_byte_count = defaultdict(lambda: 0)
+    tied_param_names = _get_tied_weight_keys(model)
     for param_name, device in accelerator_device_map.items():
+        # Skip if the parameter has already been accounted for (tied weights)
+        if param_name in tied_param_names:
+            continue
+
         param = model.get_parameter_or_buffer(param_name)
         # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
         param_byte_count = param.numel() * param.element_size()