fix(llama4): define _no_split_modules so it is picked up for FSDP v1 and v2 sharding (#37462)

Define _no_split_modules on Llama4PreTrainedModel so that FSDP v1 and FSDP v2 sharding pick up the model's no-split layer classes.

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Mehant Kammakomati 2025-04-14 14:14:32 +05:30 committed by GitHub
parent 953196a43d
commit 78cea3e22c

@@ -476,6 +476,7 @@ class Llama4PreTrainedModel(PreTrainedModel):
     _supports_quantized_cache = True
     _supports_static_cache = True
     _supports_attention_backend = True
+    _no_split_modules = ["Llama4TextDecoderLayer", "Llama4VisionEncoderLayer"]
 
     def _init_weights(self, module):
         std = (
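
For context (not part of the commit): a minimal sketch of how the newly declared _no_split_modules attribute can be consumed for FSDP v1 sharding. The checkpoint id is illustrative, and the snippet assumes torch.distributed has already been initialized. The class names are resolved to the module classes present in the instantiated model and handed to PyTorch's transformer auto-wrap policy, so each decoder/encoder layer becomes its own FSDP unit.

import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM

# Illustrative checkpoint id; any Llama4 checkpoint would do.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-4-Scout-17B-16E")

# Resolve the class names declared in _no_split_modules to the module
# classes actually present in the instantiated model.
no_split = set(model._no_split_modules or [])
layer_cls = {type(m) for m in model.modules() if type(m).__name__ in no_split}

# Each Llama4TextDecoderLayer / Llama4VisionEncoderLayer becomes its own
# FSDP unit, so a single layer is never split across a wrap boundary.
policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_cls)
fsdp_model = FSDP(model, auto_wrap_policy=policy)

Under FSDP v2 (the composable fully_shard API in recent PyTorch releases), the same attribute can drive per-layer sharding; again a sketch under the same assumptions:

# torch >= 2.6 exports fully_shard from torch.distributed.fsdp;
# earlier releases expose it under torch.distributed._composable.fsdp.
from torch.distributed.fsdp import fully_shard

# Shard each no-split layer as its own unit, then the root model.
for module in model.modules():
    if type(module).__name__ in no_split:
        fully_shard(module)
fully_shard(model)

Without the attribute, code paths that read model._no_split_modules to decide which layer classes to wrap (for example, Trainer's FSDP integration when no explicit transformer layer class is configured) had nothing to pick up for Llama4, which is what this commit fixes.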