Fix autocast incompatibility in RecurrentGemma (#30832)

This commit is contained in:
Phillip Rust 2024-06-19 09:59:34 +02:00 committed by GitHub
parent b275a41005
commit 7c71b61dae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -254,8 +254,8 @@ class RecurrentGemmaSdpaAttention(nn.Module):
 k_out = k_out[:, :, indices]
 v_out = v_out[:, :, indices]
-k_out[:, :, cache_position] = key_states
-v_out[:, :, cache_position] = value_states
+k_out[:, :, cache_position] = key_states.to(k_out.dtype)
+v_out[:, :, cache_position] = value_states.to(v_out.dtype)
 self.key_states, self.value_states = k_out, v_out
 return k_out, v_out