feat: add TP plan for granite (#35573)

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
This commit is contained in:
Mehant Kammakomati 2025-01-09 19:55:55 +05:30 committed by GitHub
parent 633da1b10e
commit 320512df46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -112,6 +112,16 @@ class GraniteConfig(PretrainedConfig):
# Model type identifier string for this configuration class.
model_type = "granite"
# NOTE(review): presumably output keys to drop when running inference — confirm against the base config class.
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `GraniteModel`.
# Maps a submodule name pattern to a partition style string ("colwise" or "rowwise").
# NOTE(review): the `*` in each key presumably stands for the layer index, and the
# style strings are presumably resolved by the tensor-parallel loading code — confirm.
base_model_tp_plan = {
# Self-attention projections: q/k/v are "colwise", the output projection is "rowwise".
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
# MLP projections: gate/up are "colwise", the down projection is "rowwise".
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
def __init__(
self,