Bug fixes and updates

yaswant19 2025-03-29 11:00:57 +05:30
parent 019210c9c7
commit a031299d13
3 changed files with 32 additions and 20 deletions
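In short, the diff below derives the attention scaling from the head dimension instead of a hard-coded 1.0, threads a new is_causal flag from the configs into the attention layer, and adds type annotations to the config constructors. Below is a minimal, self-contained sketch of the intended behaviour in plain PyTorch; ToyConfig and ToyAttention are hypothetical stand-ins, not the actual AIMv2 classes touched by this commit.

# Illustrative sketch only -- ToyConfig / ToyAttention are hypothetical stand-ins
# for the AIMv2 config and attention classes changed in this commit.
import torch
from torch import nn


class ToyConfig:
    def __init__(self, hidden_size: int = 64, num_attention_heads: int = 4, is_causal: bool = False):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        # New flag: False for the vision config, True for the text config.
        self.is_causal = is_causal


class ToyAttention(nn.Module):
    def __init__(self, config: ToyConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        # Was hard-coded to 1.0; now the standard 1/sqrt(head_dim).
        self.scaling = self.head_dim**-0.5
        # Was hard-coded to False; now read from the config.
        self.is_causal = config.is_causal
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        q, k, v = (
            t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        attn_weights = (q @ k.transpose(-2, -1)) * self.scaling
        if self.is_causal:
            # Mask out future positions when the config asks for causal attention.
            mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_states.device),
                diagonal=1,
            )
            attn_weights = attn_weights.masked_fill(mask, float("-inf"))
        out = attn_weights.softmax(dim=-1) @ v
        return out.transpose(1, 2).reshape(batch, seq_len, -1)


attn = ToyAttention(ToyConfig(is_causal=True))
print(attn(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])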

View File

@@ -87,9 +87,10 @@ class AIMv2VisionConfig(PretrainedConfig):
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
hidden_act="silu",
initializer_range=0.02,
use_head=True,
hidden_act: str = "silu",
initializer_range: float = 0.02,
use_head: bool = True,
is_causal: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -110,6 +111,7 @@ class AIMv2VisionConfig(PretrainedConfig):
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
self.is_causal = is_causal
class AIMv2TextConfig(PretrainedConfig):
@@ -175,12 +177,13 @@ class AIMv2TextConfig(PretrainedConfig):
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
hidden_act="silu",
pad_token_id=None,
bos_token_id=None,
hidden_act: str = "silu",
pad_token_id: int = None,
bos_token_id: int = None,
eos_token_id: int = 49407,
max_position_embeddings: int = 77,
initializer_range=0.02,
initializer_range: float = 0.02,
is_causal: bool = True,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -199,6 +202,7 @@ class AIMv2TextConfig(PretrainedConfig):
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
self.is_causal = is_causal
class AIMv2Config(PretrainedConfig):

View File

@@ -248,7 +248,7 @@ def eager_attention_forward(
class AIMv2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: AIMv2VisionConfig):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -262,7 +262,7 @@ class AIMv2Attention(nn.Module):
)
self.num_key_value_groups = 1
self.scaling = 1.0
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
@@ -270,6 +270,8 @@ class AIMv2Attention(nn.Module):
self.proj_out = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.proj_drop = nn.Dropout(config.projection_dropout)
self.is_causal = config.is_causal
def forward(
self,
hidden_states: torch.Tensor,
@@ -307,7 +309,7 @@ class AIMv2Attention(nn.Module):
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
is_causal=False,
is_causal=self.is_causal,
**kwargs,
)

View File

@@ -98,9 +98,10 @@ class AIMv2VisionConfig(SiglipVisionConfig):
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
hidden_act="silu",
initializer_range=0.02,
use_head=True,
hidden_act: str = "silu",
initializer_range: float = 0.02,
use_head: bool = True,
is_causal: bool = False,
**kwargs,
):
super().__init__(
@@ -123,6 +124,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
self.is_causal = is_causal
del self.layer_norm_eps
@@ -186,12 +188,13 @@ class AIMv2TextConfig(SiglipTextConfig):
projection_dropout: float = 0.0,
qkv_bias: bool = False,
use_bias: bool = False,
hidden_act="silu",
pad_token_id=None,
bos_token_id=None,
hidden_act: str = "silu",
pad_token_id: int = None,
bos_token_id: int = None,
eos_token_id: int = 49407,
max_position_embeddings: int = 77,
initializer_range=0.02,
initializer_range: float = 0.02,
is_causal: bool = True,
**kwargs,
):
super().__init__(
@@ -214,6 +217,7 @@ class AIMv2TextConfig(SiglipTextConfig):
self.qkv_bias = qkv_bias
self.rms_norm_eps = rms_norm_eps
self.projection_dropout = projection_dropout
self.is_causal = is_causal
del self.bos_token_id
del self.pad_token_id
@@ -390,7 +394,7 @@ class AIMv2TextEmbeddings(CLIPTextEmbeddings):
class AIMv2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: AIMv2VisionConfig):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -404,7 +408,7 @@ class AIMv2Attention(nn.Module):
)
self.num_key_value_groups = 1
self.scaling = 1.0
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
@@ -412,6 +416,8 @@ class AIMv2Attention(nn.Module):
self.proj_out = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
self.proj_drop = nn.Dropout(config.projection_dropout)
self.is_causal = config.is_causal
def forward(
self,
hidden_states: torch.Tensor,
@@ -449,7 +455,7 @@ class AIMv2Attention(nn.Module):
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
is_causal=False,
is_causal=self.is_causal,
**kwargs,
)