mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
feat: add TP plan for granite (#35573)
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
This commit is contained in:
parent
633da1b10e
commit
320512df46
@ -112,6 +112,16 @@ class GraniteConfig(PretrainedConfig):
|
||||
|
||||
model_type = "granite"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `GraniteModel`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user