mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix key dtype in GPTJ and CodeGen (#26836)
* fix key dtype in gptj and codegen * delay the key cast to a later point * fix
This commit is contained in:
parent
32f799db0d
commit
ede051f1b8
@ -224,7 +224,9 @@ class CodeGenAttention(nn.Module):
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key, value)
|
||||
# Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
|
||||
# Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
|
||||
present = (key.to(hidden_states.dtype), value)
|
||||
else:
|
||||
present = None
|
||||
|
||||
|
@ -249,7 +249,9 @@ class GPTJAttention(nn.Module):
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key, value)
|
||||
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
|
||||
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
|
||||
present = (key.to(hidden_states.dtype), value)
|
||||
else:
|
||||
present = None
|
||||
|
||||
|
@ -306,6 +306,7 @@ class MistralMLP(nn.Module):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
|
Loading…
Reference in New Issue
Block a user