mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
[flax/mistral] support sliding_window: null in config (#37402)
flax/mistral: Allow sliding_window to be set to None
This commit is contained in:
parent
1a25fd2f6d
commit
cceab972ba
@@ -240,8 +240,8 @@ class FlaxMistralAttention(nn.Module):
|
||||
self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype)
|
||||
self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype)
|
||||
casual_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
|
||||
self.causal_mask = jnp.triu(casual_mask, k=-config.sliding_window)
|
||||
self.rotary_emb = FlaxMistralRotaryEmbedding(config, dtype=self.dtype)
|
||||
self.causal_mask = jnp.triu(casual_mask, k=-(config.sliding_window or 0))
|
||||
self.rotary_emb = FlaxMistralRotaryEmbedding(self.config, dtype=self.dtype)
|
||||
|
||||
def _split_heads(self, hidden_states, num_heads):
|
||||
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
|
||||
|
Loading…
Reference in New Issue
Block a user