[flax/mistral] support sliding_window: null in config (#37402)

flax/mistral: Allow sliding_window to be set to none
This commit is contained in:
Yiding Jia 2025-06-02 07:45:02 -07:00 committed by GitHub
parent 1a25fd2f6d
commit cceab972ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -240,8 +240,8 @@ class FlaxMistralAttention(nn.Module):
self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)
casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
self.causal_mask = jnp.triu(casual_mask, k=-config.sliding_window)
self.rotary_emb = FlaxMistralRotaryEmbedding(config, dtype=self.dtype)
self.causal_mask = jnp.triu(casual_mask, k=-(config.sliding_window or 0))
self.rotary_emb = FlaxMistralRotaryEmbedding(self.config, dtype=self.dtype)
def _split_heads(self, hidden_states, num_heads):
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))