[LLaVa] Add past_key_values to _skip_keys_device_placement to fix multi-GPU dispatch (#28051)

Add past_key_values to _skip_keys_device_placement for LLaVa
This commit is contained in:
Adilzhan Ismailov 2023-12-15 14:05:20 +00:00 committed by GitHub
parent deb72cb6d9
commit e2b6df7971
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 0 deletions

View File

@@ -130,6 +130,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def _init_weights(self, module):

View File

@@ -137,6 +137,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["VipLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
def _init_weights(self, module):