LLaVa: add cache class attribute (#32278)

cache class flag
Raushan Turganbay 2024-08-01 09:48:03 +05:00 committed by GitHub
parent 14ee2326e5
commit 453e74884f
6 changed files with 6 additions and 0 deletions
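What the flag does: `_supports_cache_class = True` tells the generation utilities that the model accepts the newer `Cache` objects (e.g. `DynamicCache` from `transformers.cache_utils`) in place of the legacy tuple-of-tuples `past_key_values`. A minimal usage sketch of what this unlocks, assuming the `llava-hf/llava-1.5-7b-hf` checkpoint and a placeholder image/prompt (neither comes from this commit):

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.cache_utils import DynamicCache

model_id = "llava-hf/llava-1.5-7b-hf"  # assumed checkpoint, for illustration only
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

image = Image.new("RGB", (336, 336))  # placeholder image
prompt = "USER: <image>\nWhat is in the picture? ASSISTANT:"
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

# With `_supports_cache_class = True`, a Cache object can be passed as
# `past_key_values`; it is updated in place across decoding steps.
out = model.generate(**inputs, past_key_values=DynamicCache(), max_new_tokens=20)
print(processor.decode(out[0], skip_special_tokens=True))

The same one-line change is applied to all six models below; static (compilable) caches are gated by a separate `_supports_static_cache` flag, which this commit does not touch.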

src/transformers/models/llava/modeling_llava.py

@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of Llava isn't meant for training from scratch - only

src/transformers/models/llava_next/modeling_llava_next.py

@@ -232,6 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaNextVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of LlavaNext isn't meant for training from scratch - only

src/transformers/models/llava_next_video/modeling_llava_next_video.py

@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["LlavaNextVideoVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only

src/transformers/models/paligemma/modeling_paligemma.py

@@ -127,6 +127,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = False
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of PaliGemma isn't meant for training from scratch - only

src/transformers/models/video_llava/modeling_video_llava.py

@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VideoLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = (

src/transformers/models/vipllava/modeling_vipllava.py

@@ -135,6 +135,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["VipLlavaVisionAttention"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of VipLlava isn't meant for training from scratch - only