Fixes to alternating SWA layers in Gemma2 (#31775)

* HybridCache: Flip order of alternating global-attn/sliding-attn layers

* HybridCache: Read sliding_window argument from cache_kwargs

* Gemma2Model: Flip order of alternating global-attn/sliding-attn layers

* Code formatting
This commit is contained in:
turboderp 2024-07-11 10:03:46 +02:00 committed by GitHub
parent d625294d79
commit a695c18649
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 4 deletions

View File

@ -1148,7 +1148,7 @@ class HybridCache(Cache):
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.is_sliding = torch.tensor(
[i % 2 for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
[not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
)
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
@ -1212,9 +1212,9 @@ class HybridCache(Cache):
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
) -> Tuple[torch.Tensor]:
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
k_out = self.key_cache[layer_idx]

View File

@ -216,7 +216,7 @@ class Gemma2Attention(nn.Module):
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
self.sliding_window = config.sliding_window if layer_idx % 2 else None
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
def forward(
self,
@ -616,7 +616,7 @@ class Gemma2DecoderLayer(nn.Module):
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.is_sliding = bool(layer_idx % 2)
self.is_sliding = not bool(layer_idx % 2)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window