Feat: add warnings for unused keys and rules in tensor parallel (#37893)

Feat: tensor parallel plan verification
This commit is contained in:
Matej Sirovatka 2025-05-16 14:52:47 +02:00 committed by GitHub
parent 120935234f
commit 7b5e327c6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 0 deletions

View File

@ -670,3 +670,34 @@ def shard_and_distribute_module(
setattr(module_to_tp, param_type, param)
# module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
return param
def verify_tp_plan(expected_keys: list[str], tp_plan: Optional[dict[str, str]]):
"""
Verify the TP plan of the model, log a warning if the layers that were not sharded and the rules that were not applied.
"""
if tp_plan is None:
return
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
unsharded_layers = set(generic_keys)
unused_rules = tp_plan
for key in generic_keys:
param_name, _ = key.rsplit(".", 1) if "." in key else key
generic_param_name = re.sub(r"\d+", "*", param_name)
if generic_param_name in tp_plan:
unused_rules.pop(generic_param_name)
unsharded_layers.discard(key)
elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
unused_rules.pop(parent_param_name)
unsharded_layers.discard(key)
else:
pass # we couldn't find the rule for this parameter, so it's not sharded
if len(unused_rules) > 0:
logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
if len(unsharded_layers) > 0:
logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")

View File

@ -64,6 +64,7 @@ from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
shard_and_distribute_module,
verify_tp_plan,
)
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
@ -4974,6 +4975,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
if logger.level >= logging.WARNING:
verify_tp_plan(expected_keys, getattr(model_to_load, "_tp_plan", None))
# Warmup cuda to load the weights much faster on devices
if device_map is not None and not is_hqq_or_quark:
expanded_device_map = expand_device_map(device_map, expected_keys)