mirror of https://github.com/huggingface/transformers.git

Commit a031299d13 (parent 019210c9c7): Bug fix and updates
@@ -87,9 +87,10 @@ class AIMv2VisionConfig(PretrainedConfig):
         projection_dropout: float = 0.0,
         qkv_bias: bool = False,
         use_bias: bool = False,
-        hidden_act="silu",
-        initializer_range=0.02,
-        use_head=True,
+        hidden_act: str = "silu",
+        initializer_range: float = 0.02,
+        use_head: bool = True,
+        is_causal: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -110,6 +111,7 @@ class AIMv2VisionConfig(PretrainedConfig):
         self.qkv_bias = qkv_bias
         self.rms_norm_eps = rms_norm_eps
         self.projection_dropout = projection_dropout
+        self.is_causal = is_causal


 class AIMv2TextConfig(PretrainedConfig):
@@ -175,12 +177,13 @@ class AIMv2TextConfig(PretrainedConfig):
         projection_dropout: float = 0.0,
         qkv_bias: bool = False,
         use_bias: bool = False,
-        hidden_act="silu",
-        pad_token_id=None,
-        bos_token_id=None,
+        hidden_act: str = "silu",
+        pad_token_id: int = None,
+        bos_token_id: int = None,
         eos_token_id: int = 49407,
         max_position_embeddings: int = 77,
-        initializer_range=0.02,
+        initializer_range: float = 0.02,
+        is_causal: bool = True,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -199,6 +202,7 @@ class AIMv2TextConfig(PretrainedConfig):
         self.qkv_bias = qkv_bias
         self.rms_norm_eps = rms_norm_eps
         self.projection_dropout = projection_dropout
+        self.is_causal = is_causal


 class AIMv2Config(PretrainedConfig):
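Taken together, the configuration hunks above give both towers an explicit `is_causal` field: the vision encoder defaults to bidirectional attention (False), the text encoder to causal attention (True). A minimal usage sketch, assuming the classes live at the usual `transformers.models.aimv2.configuration_aimv2` path (the import location is an assumption, not shown in this diff):

from transformers.models.aimv2.configuration_aimv2 import (  # assumed import path
    AIMv2TextConfig,
    AIMv2VisionConfig,
)

vision_config = AIMv2VisionConfig()
text_config = AIMv2TextConfig()

print(vision_config.is_causal)  # False: image patches attend bidirectionally
print(text_config.is_causal)    # True: text tokens attend causally

# Like any other config field, the flag can be overridden at construction time.
print(AIMv2VisionConfig(is_causal=True).is_causal)  # True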
@@ -248,7 +248,7 @@ def eager_attention_forward(
 class AIMv2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: AIMv2VisionConfig):
+    def __init__(self, config):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -262,7 +262,7 @@ class AIMv2Attention(nn.Module):
         )

         self.num_key_value_groups = 1
-        self.scaling = 1.0
+        self.scaling = self.head_dim**-0.5

         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
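The `self.scaling` change above replaces a no-op factor of 1.0 with the standard 1/sqrt(head_dim) scaling from scaled dot-product attention. A self-contained PyTorch sketch (plain tensors, not the transformers code path) of why the factor matters:

import torch

torch.manual_seed(0)
head_dim = 64
query = torch.randn(1, 8, head_dim)  # (batch, seq_len, head_dim)
key = torch.randn(1, 8, head_dim)

logits_old = query @ key.transpose(-2, -1)                     # scaling = 1.0 (before)
logits_new = (query @ key.transpose(-2, -1)) * head_dim**-0.5  # scaling = head_dim ** -0.5 (after)

# Unscaled logits have a standard deviation near sqrt(head_dim) (about 8 here),
# which pushes softmax toward saturation; the scaled logits stay near unit variance.
print(logits_old.std(), logits_new.std())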
@@ -270,6 +270,8 @@ class AIMv2Attention(nn.Module):
         self.proj_out = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
         self.proj_drop = nn.Dropout(config.projection_dropout)

+        self.is_causal = config.is_causal
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -307,7 +309,7 @@ class AIMv2Attention(nn.Module):
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            is_causal=False,
+            is_causal=self.is_causal,
             **kwargs,
         )

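The forward hunk stops hard-coding `is_causal=False` and forwards the flag stored on the module, so the text tower can attend causally while the vision tower stays bidirectional. An illustrative sketch with plain PyTorch SDPA (not the exact attention-interface dispatch used in transformers) of what the flag changes:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = k = v = torch.randn(1, 1, 4, 8)  # (batch, heads, seq_len, head_dim)

bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The last query position may attend to every key even under the causal mask, so
# its output matches the bidirectional result; earlier positions differ because
# they can no longer attend to later keys.
print(torch.allclose(bidirectional[..., -1, :], causal[..., -1, :], atol=1e-6))  # True
print(torch.allclose(bidirectional, causal))                                     # False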
@@ -98,9 +98,10 @@ class AIMv2VisionConfig(SiglipVisionConfig):
         projection_dropout: float = 0.0,
         qkv_bias: bool = False,
         use_bias: bool = False,
-        hidden_act="silu",
-        initializer_range=0.02,
-        use_head=True,
+        hidden_act: str = "silu",
+        initializer_range: float = 0.02,
+        use_head: bool = True,
+        is_causal: bool = False,
         **kwargs,
     ):
         super().__init__(
@@ -123,6 +124,7 @@ class AIMv2VisionConfig(SiglipVisionConfig):
         self.qkv_bias = qkv_bias
         self.rms_norm_eps = rms_norm_eps
         self.projection_dropout = projection_dropout
+        self.is_causal = is_causal

         del self.layer_norm_eps

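The modular definition applies the same flag; the surrounding context also shows the `del self.<attr>` idiom used to prune fields inherited from the Siglip parent config that AIMv2 does not use. A generic Python sketch (standalone toy classes, not the transformers ones) of that idiom:

class ParentConfig:
    def __init__(self):
        self.layer_norm_eps = 1e-5
        self.rms_norm_eps = 1e-5


class ChildConfig(ParentConfig):
    def __init__(self):
        super().__init__()
        del self.layer_norm_eps  # drop the inherited field this subclass does not use


config = ChildConfig()
print(hasattr(config, "layer_norm_eps"))  # False
print(config.rms_norm_eps)                # 1e-05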
@@ -186,12 +188,13 @@ class AIMv2TextConfig(SiglipTextConfig):
         projection_dropout: float = 0.0,
         qkv_bias: bool = False,
         use_bias: bool = False,
-        hidden_act="silu",
-        pad_token_id=None,
-        bos_token_id=None,
+        hidden_act: str = "silu",
+        pad_token_id: int = None,
+        bos_token_id: int = None,
         eos_token_id: int = 49407,
         max_position_embeddings: int = 77,
-        initializer_range=0.02,
+        initializer_range: float = 0.02,
+        is_causal: bool = True,
         **kwargs,
     ):
         super().__init__(
@@ -214,6 +217,7 @@ class AIMv2TextConfig(SiglipTextConfig):
         self.qkv_bias = qkv_bias
         self.rms_norm_eps = rms_norm_eps
         self.projection_dropout = projection_dropout
+        self.is_causal = is_causal

         del self.bos_token_id
         del self.pad_token_id
@@ -390,7 +394,7 @@ class AIMv2TextEmbeddings(CLIPTextEmbeddings):
 class AIMv2Attention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: AIMv2VisionConfig):
+    def __init__(self, config):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -404,7 +408,7 @@ class AIMv2Attention(nn.Module):
         )

         self.num_key_value_groups = 1
-        self.scaling = 1.0
+        self.scaling = self.head_dim**-0.5

         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
@@ -412,6 +416,8 @@ class AIMv2Attention(nn.Module):
         self.proj_out = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
         self.proj_drop = nn.Dropout(config.projection_dropout)

+        self.is_causal = config.is_causal
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -449,7 +455,7 @@ class AIMv2Attention(nn.Module):
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
-            is_causal=False,
+            is_causal=self.is_causal,
             **kwargs,
         )
