[LLaVa] Add past_key_values to _skip_keys_device_placement to fix multi-GPU dispatch (#28051)
Add past_key_values to _skip_keys_device_placement for LLaVa
parent deb72cb6d9
commit e2b6df7971
@@ -130,6 +130,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlavaVisionAttention"]
+    _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
@@ -137,6 +137,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["VipLlavaVisionAttention"]
+    _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
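
For context, _skip_keys_device_placement names forward inputs that accelerate's dispatch hooks should leave on whatever device they already occupy; without listing past_key_values, the hooks attempt to move the cached key/value tensors when the model is sharded across GPUs, which breaks generation. Below is a minimal sketch, not part of the commit, of the multi-GPU scenario the fix targets; it assumes a machine with at least two GPUs, the accelerate package installed, and the llava-hf/llava-1.5-7b-hf checkpoint.

    import requests
    import torch
    from PIL import Image
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    checkpoint = "llava-hf/llava-1.5-7b-hf"

    # device_map="auto" shards the model across all visible GPUs via accelerate;
    # dispatch hooks then move each submodule's inputs to the right device,
    # except for keys listed in _skip_keys_device_placement (past_key_values).
    model = LlavaForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(checkpoint)

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

    # Generation with the default use_cache=True passes past_key_values between
    # decoding steps, which is exactly the input the hooks must not relocate.
    output = model.generate(**inputs, max_new_tokens=30)
    print(processor.decode(output[0], skip_special_tokens=True))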