diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index 00414afadb9..5b8f8acb5a6 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -670,3 +670,34 @@ def shard_and_distribute_module(
     setattr(module_to_tp, param_type, param)
     # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
     return param
+
+
+def verify_tp_plan(expected_keys: list[str], tp_plan: Optional[dict[str, str]]):
+    """
+    Verify the TP plan of the model, logging a warning for the layers that were not sharded and the rules that were not applied.
+    """
+
+    if tp_plan is None:
+        return
+
+    generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
+    unsharded_layers = set(generic_keys)
+    unused_rules = dict(tp_plan)  # copy so that popping matched rules does not modify the caller's plan
+
+    for key in generic_keys:
+        param_name = key.rsplit(".", 1)[0] if "." in key else key
+        generic_param_name = re.sub(r"\d+", "*", param_name)
+
+        if generic_param_name in tp_plan:
+            unused_rules.pop(generic_param_name)
+            unsharded_layers.discard(key)
+        elif "." in generic_param_name and (parent_param_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
+            unused_rules.pop(parent_param_name)
+            unsharded_layers.discard(key)
+        else:
+            pass  # we couldn't find the rule for this parameter, so it's not sharded
+
+    if len(unused_rules) > 0:
+        logger.warning(f"The following TP rules were not applied on any of the layers: {unused_rules}")
+    if len(unsharded_layers) > 0:
+        logger.warning(f"The following layers were not sharded: {', '.join(unsharded_layers)}")
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7f9751ab57b..459cd7aca55 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -64,6 +64,7 @@ from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.tensor_parallel import (
     SUPPORTED_TP_STYLES,
     shard_and_distribute_module,
+    verify_tp_plan,
 )
 from .loss.loss_utils import LOSS_MAPPING
 from .pytorch_utils import (  # noqa: F401
@@ -4974,6 +4975,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         if hf_quantizer is not None:
             expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
 
+        if logger.level >= logging.WARNING:
+            verify_tp_plan(expected_keys, getattr(model_to_load, "_tp_plan", None))
+
         # Warmup cuda to load the weights much faster on devices
         if device_map is not None and not is_hqq_or_quark:
             expanded_device_map = expand_device_map(device_map, expected_keys)
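
For context, a minimal usage sketch of the new helper (not part of the patch): with this change applied, `verify_tp_plan` takes the checkpoint's expected parameter keys and the model's `_tp_plan` and warns about plan rules that never matched and layers that remain unsharded. The keys and rules below are made-up examples, not taken from a real model.

```python
# Toy sanity check for verify_tp_plan; the keys and plan entries are hypothetical.
from transformers.integrations.tensor_parallel import verify_tp_plan

expected_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.embed_tokens.weight",
]
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",  # matches q_proj.weight -> consumed
    "model.layers.*.mlp.down_proj": "rowwise",     # matches nothing -> reported as unused
}

# Expected warnings: the down_proj rule was never applied, and
# mlp.gate_proj.weight / embed_tokens.weight were not sharded.
verify_tp_plan(expected_keys, tp_plan)
```

With the default `transformers` logging verbosity (WARNING), both messages are emitted.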