Enable device map (#30870)

* added _no_split_modules

* added LlavaNextVisionAttention to _no_split_modules
This commit is contained in:
Darshana S 2024-05-17 17:20:24 +05:30 committed by GitHub
parent 57c965a8f1
commit 3802e786ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -124,6 +124,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_no_split_modules = ["VideoLlavaVisionAttention"]
def _init_weights(self, module):
# important: this ported version of VideoLlava isn't meant for training from scratch - only