diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index 172a45544ba..05699ef15c7 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -224,7 +224,9 @@ class CodeGenAttention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
+            # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None
 
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 6b5607f235b..acdbb8c4921 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -249,7 +249,9 @@ class GPTJAttention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
-            present = (key, value)
+            # Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
+            # Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
+            present = (key.to(hidden_states.dtype), value)
         else:
             present = None
 
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index f735c471268..53667b6a82c 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -306,6 +306,7 @@ class MistralMLP(nn.Module):
         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
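
Why the cast is needed (a minimal sketch, not part of the patch): as the inline comments above note, both original codebases keep the rotary computation in fp32, so applying RoPE silently promotes a half-precision key to fp32, and the cached key then no longer matches the model dtype. The apply_rotary_fp32 helper below is hypothetical and only stands in for the real rotary embedding; it assumes standard PyTorch type promotion.

import torch

def apply_rotary_fp32(key: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for RoPE: the sin/cos tables are created in
    # fp32, so the elementwise ops promote a fp16/bf16 key to fp32.
    sin = torch.rand(key.shape, dtype=torch.float32)
    cos = torch.rand(key.shape, dtype=torch.float32)
    return key * cos + key * sin

hidden_states = torch.rand(1, 4, 8, dtype=torch.float16)
key = torch.rand(1, 4, 8, dtype=torch.float16)

key = apply_rotary_fp32(key)
print(key.dtype)  # torch.float32: without the fix, this dtype leaks into the KV cache

# The fix from the diff: cast the key back to the model dtype before caching.
present = (key.to(hidden_states.dtype), None)
print(present[0].dtype)  # torch.float16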