Fix key dtype in GPTJ and CodeGen (#26836)

* fix key dtype in gptj and codegen

* delay the key cast to a later point

* fix
fxmarty 2023-10-24 09:55:14 +02:00 committed by GitHub
parent 32f799db0d
commit ede051f1b8
3 changed files with 7 additions and 2 deletions


@@ -224,7 +224,9 @@ class CodeGenAttention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
+            # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None
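
For context (not part of the diff): CodeGen and GPT-J build their rotary sin/cos tables in float32, so type promotion leaves the rotated key in float32 even when hidden_states is in half precision; casting before RoPE would change the numerics relative to the referenced upstream code, hence the delayed cast when building present. Below is a minimal, self-contained sketch of that promotion, using a simplified RoPE construction rather than the actual modeling code; the same reasoning applies to the GPT-J hunk that follows.

import torch

def rotate_every_two(x):
    # Interleaved rotation in the GPT-J / CodeGen RoPE style (schematic).
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

head_dim = 8
key = torch.randn(1, 4, 3, head_dim, dtype=torch.float16)          # key in the model (hidden_states) dtype
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = torch.einsum("i,j->ij", torch.arange(3, dtype=torch.float32), inv_freq)
sin = torch.sin(angles).repeat_interleave(2, dim=-1)                # float32 tables, as in the referenced codebases
cos = torch.cos(angles).repeat_interleave(2, dim=-1)

k_rot = key * cos + rotate_every_two(key) * sin                     # fp16 * fp32 promotes to fp32
print(k_rot.dtype)                                                  # torch.float32
print(k_rot.to(torch.float16).dtype)                                # the delayed cast applied when caching the key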


@@ -249,7 +249,9 @@ class GPTJAttention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None
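
One practical reason to keep the cached key in the same dtype as the rest of the tensors (my reading of the change, not stated in the commit message): downstream consumers such as torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0) reject mixed-dtype query/key/value, so a float32 key stored next to half-precision tensors would fail there. A small illustration with made-up shapes:

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)   # query in the model dtype
v = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)   # value in the model dtype
k = torch.randn(1, 4, 8, 16, dtype=torch.float32)    # key as it would be cached without this fix

try:
    F.scaled_dot_product_attention(q, k, v)
except RuntimeError as err:
    print("mixed-dtype attention inputs rejected:", err)

out = F.scaled_dot_product_attention(q, k.to(q.dtype), v)   # cast back, as the patched code does
print(out.dtype)                                            # torch.bfloat16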


@@ -306,6 +306,7 @@ class MistralMLP(nn.Module):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
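
The docstring above (truncated by the hunk) states an equivalence with torch.repeat_interleave; below is a quick standalone check of that claim with illustrative shapes, plus the expand/reshape formulation commonly used for this kind of key/value-head repetition (a sketch, not necessarily the body of this function):

import torch

batch, num_key_value_heads, seq_len, head_dim = 2, 4, 5, 16
n_rep = 3                                   # attention heads per key/value head
x = torch.randn(batch, num_key_value_heads, seq_len, head_dim)

via_interleave = torch.repeat_interleave(x, repeats=n_rep, dim=1)       # (2, 12, 5, 16)
via_expand = (
    x[:, :, None, :, :]
    .expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
    .reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)
)

print(via_interleave.shape)                     # torch.Size([2, 12, 5, 16])
print(torch.equal(via_interleave, via_expand))  # True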