[qwen-omni] fix sliding window (#38525)

fix
Raushan Turganbay 2025-06-05 10:11:58 +02:00 committed by GitHub
parent 1fed6166c0
commit 0d69fa6dcd
2 changed files with 28 additions and 2 deletions

View File

@@ -658,6 +658,8 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         spatial_merge_size (`int`, *optional*, defaults to 2):
             The size used for merging spatial dimensions.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
 
     Example:
@@ -726,6 +728,7 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
         audio_end_token_id=151648,
         initializer_range=0.02,
         spatial_merge_size=2,
+        layer_types=None,
         **kwargs,
     ):
         self.audio_token_index = audio_token_index
@@ -753,7 +756,7 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.use_sliding_window = use_sliding_window
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if self.use_sliding_window else None
         self.max_window_layers = max_window_layers
 
         # for backward compatibility
@@ -775,6 +778,16 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.spatial_merge_size = spatial_merge_size
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention"
+                if self.sliding_window is not None and i >= self.max_window_layers
+                else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types)
 
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
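
Below is a minimal, self-contained sketch of the default that the new code computes for layer_types; the layer count, window size, and max_window_layers values are made-up examples, not the model's real defaults.

# Standalone illustration of the layer_types default added above.
# num_hidden_layers, max_window_layers and sliding_window are hypothetical values
# chosen only to show the resulting per-layer attention pattern.
use_sliding_window = True
num_hidden_layers = 8
max_window_layers = 4
sliding_window = 4096 if use_sliding_window else None

layer_types = [
    "sliding_attention"
    if sliding_window is not None and i >= max_window_layers
    else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['full_attention', 'full_attention', 'full_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention', 'sliding_attention', 'sliding_attention']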

View File

@@ -697,6 +697,8 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         spatial_merge_size (`int`, *optional*, defaults to 2):
             The size used for merging spatial dimensions.
+        layer_types (`list`, *optional*):
+            Attention pattern for each layer.
 
     Example:
@@ -765,6 +767,7 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
         audio_end_token_id=151648,
         initializer_range=0.02,
         spatial_merge_size=2,
+        layer_types=None,
         **kwargs,
     ):
         self.audio_token_index = audio_token_index
@@ -792,7 +795,7 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.use_sliding_window = use_sliding_window
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if self.use_sliding_window else None
         self.max_window_layers = max_window_layers
 
         # for backward compatibility
@@ -814,6 +817,16 @@ class Qwen2_5OmniTalkerConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.spatial_merge_size = spatial_merge_size
+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention"
+                if self.sliding_window is not None and i >= self.max_window_layers
+                else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types)
 
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
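
A hedged end-to-end sketch of the new behaviour follows. It assumes a transformers build that already contains this commit and that Qwen2_5OmniTalkerConfig is importable from the top-level transformers package; the constructor arguments are illustrative only, not values taken from this diff.

# Sketch only: import path and argument values are assumptions.
from transformers import Qwen2_5OmniTalkerConfig

# With sliding windows disabled, sliding_window is forced to None and every
# layer defaults to full attention.
cfg = Qwen2_5OmniTalkerConfig(use_sliding_window=False, sliding_window=32768)
assert cfg.sliding_window is None
assert all(t == "full_attention" for t in cfg.layer_types)

# With sliding windows enabled, layers with index >= max_window_layers default
# to sliding attention.
cfg = Qwen2_5OmniTalkerConfig(
    use_sliding_window=True,
    sliding_window=32768,
    num_hidden_layers=6,
    max_window_layers=3,
)
assert cfg.layer_types == ["full_attention"] * 3 + ["sliding_attention"] * 3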