Feat: add warnings for unused keys and rules in tensor parallel (#37893)
Feat: tensor parallel plan verification
This commit is contained in:
parent 120935234f
commit 7b5e327c6e
@@ -670,3 +670,34 @@ def shard_and_distribute_module(
     setattr(module_to_tp, param_type, param)
     # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
     return param
+
+
+def verify_tp_plan(expected_keys: list[str], tp_plan: Optional[dict[str, str]]):
+    """
+    Verify the TP plan of the model, logging a warning for the layers that were not sharded and the rules that were not applied.
+    """
+
+    if tp_plan is None:
+        return
+
+    generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
+    unsharded_layers = set(generic_keys)
+    unused_rules = dict(tp_plan)
+
+    for key in generic_keys:
+        param_name = key.rsplit(".", 1)[0] if "." in key else key
+        generic_param_name = re.sub(r"\d+", "*", param_name)
+
+        if generic_param_name in tp_plan:
+            unused_rules.pop(generic_param_name, None)
+            unsharded_layers.discard(key)
+        elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
+            unused_rules.pop(parent_param_name, None)
+            unsharded_layers.discard(key)
+        else:
+            pass  # we couldn't find the rule for this parameter, so it's not sharded
+
+    if len(unused_rules) > 0:
+        logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
+    if len(unsharded_layers) > 0:
+        logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
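For context, the matching above works by collapsing numeric layer indices into `*` so that concrete parameter names can be compared against the generic patterns of a TP plan. Below is a minimal, self-contained sketch of that idea; the parameter names and the plan are made up for illustration and are not taken from any real model.

import re

# Hypothetical checkpoint keys and TP plan, for illustration only.
expected_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "lm_head.weight",
]
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.mlp.down_proj": "rowwise",  # matches nothing below
}

# Collapse layer indices so "layers.0" and "layers.1" map to the same pattern.
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}

unsharded_layers = set(generic_keys)
unused_rules = dict(tp_plan)
for key in generic_keys:
    module_name = key.rsplit(".", 1)[0] if "." in key else key  # drop ".weight" / ".bias"
    if module_name in unused_rules:
        unused_rules.pop(module_name)
        unsharded_layers.discard(key)

print(unused_rules)      # {'model.layers.*.mlp.down_proj': 'rowwise'}
print(unsharded_layers)  # the gate_proj pattern and lm_head.weight remain unsharded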
@@ -64,6 +64,7 @@ from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.tensor_parallel import (
     SUPPORTED_TP_STYLES,
     shard_and_distribute_module,
+    verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
 from .pytorch_utils import (  # noqa: F401
@@ -4974,6 +4975,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         if hf_quantizer is not None:
             expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)

+        if logger.level >= logging.WARNING:
+            verify_tp_plan(expected_keys, getattr(model_to_load, "_tp_plan", None))
+
         # Warmup cuda to load the weights much faster on devices
         if device_map is not None and not is_hqq_or_quark:
             expanded_device_map = expand_device_map(device_map, expected_keys)
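As a usage note, the verification runs while the checkpoint weights are being loaded, so a mismatched plan surfaces as warnings at load time rather than as a silent lack of sharding later. A hedged sketch of how one might exercise it, assuming a multi-GPU host launched with torchrun and a checkpoint that defines a TP plan (the model id below is only an example):

# Launch with e.g.: torchrun --nproc-per-node 4 load_tp.py
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # example id; any model that ships a TP plan works
    torch_dtype=torch.bfloat16,
    tp_plan="auto",  # request tensor parallelism from the model's built-in plan
)
# If some rules match no parameters, or some parameters match no rule, the new
# verify_tp_plan call logs warnings such as
# "The following TP rules were not applied on any of the layers: ..."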