[Llava / Vip-Llava] Add SDPA into llava (#28107)

add SDPA into llava
This commit is contained in:
Younes Belkada 2023-12-18 13:46:30 +01:00 committed by GitHub
parent e6dcf8abd6
commit b8378b658e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 0 deletions

View File

@ -155,6 +155,14 @@ class LlavaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
LLAVA_INPUTS_DOCSTRING = r"""
Args:

View File

@ -162,6 +162,14 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
VIPLLAVA_INPUTS_DOCSTRING = r"""
Args: