Signed-off-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
Roger Wang 2025-06-28 13:49:51 -07:00
parent ccf2ca162e
commit 1d85c39140
3 changed files with 10 additions and 10 deletions

View File

@@ -575,7 +575,7 @@ class Gemma3nConfig(PretrainedConfig):
Custom vision config or dict.
audio_config (`Union[AutoConfig, dict]`, *optional*):
Custom audio config or dict.
audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
audio_soft_tokens_per_audio (`int`, *optional*, defaults to 188):
The number of soft tokens per audio clip.
vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
The number of soft tokens per image.
@@ -631,7 +631,7 @@ class Gemma3nConfig(PretrainedConfig):
text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
audio_soft_tokens_per_image: int = 188,
audio_soft_tokens_per_audio: int = 188,
vision_soft_tokens_per_image: int = 256,
boi_token_id: int = 255_999,
eoi_token_id: int = 262_144,
@@ -666,7 +666,7 @@ class Gemma3nConfig(PretrainedConfig):
self.vision_config = vision_config
self.audio_config = audio_config
self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
self.audio_soft_tokens_per_audio = audio_soft_tokens_per_audio
self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
self.boi_token_id = boi_token_id
self.eoi_token_id = eoi_token_id

View File

@@ -937,7 +937,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
Returns:
audio_encodings: a torch.Tensor of shape
`[batch_size, self.config.audio_soft_tokens_per_image,
`[batch_size, self.config.audio_soft_tokens_per_audio,
self.config.audio_config.hidden_size]`
audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
"""
@@ -2114,7 +2114,7 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
extra_padding_tokens = self.config.audio_soft_tokens_per_audio - audio_seq_len
extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)

View File

@@ -538,7 +538,7 @@ class Gemma3nConfig(PretrainedConfig):
Custom vision config or dict.
audio_config (`Union[AutoConfig, dict]`, *optional*):
Custom audio config or dict.
audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
audio_soft_tokens_per_audio (`int`, *optional*, defaults to 188):
The number of soft tokens per audio clip.
vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
The number of soft tokens per image.
@@ -594,7 +594,7 @@ class Gemma3nConfig(PretrainedConfig):
text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
audio_soft_tokens_per_image: int = 188,
audio_soft_tokens_per_audio: int = 188,
vision_soft_tokens_per_image: int = 256,
boi_token_id: int = 255_999,
eoi_token_id: int = 262_144,
@@ -629,7 +629,7 @@ class Gemma3nConfig(PretrainedConfig):
self.vision_config = vision_config
self.audio_config = audio_config
self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
self.audio_soft_tokens_per_audio = audio_soft_tokens_per_audio
self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
self.boi_token_id = boi_token_id
self.eoi_token_id = eoi_token_id
@@ -1499,7 +1499,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
Returns:
audio_encodings: a torch.Tensor of shape
`[batch_size, self.config.audio_soft_tokens_per_image,
`[batch_size, self.config.audio_soft_tokens_per_audio,
self.config.audio_config.hidden_size]`
audio_mel_mask: a torch.BoolTensor of shape [batch, num_frames].
"""
@@ -2383,7 +2383,7 @@ class Gemma3nModel(PaliGemmaModel):
audio_features = torch.where(audio_mask.unsqueeze(-1), audio_padding_embs, audio_features)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len
extra_padding_tokens = self.config.audio_soft_tokens_per_audio - audio_seq_len
extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)