diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 1ae05846922..3766c3479b1 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -131,10 +131,11 @@ def rotate_half(x): def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] - gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) - cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) - sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze((0, 1)) # [seq_len, dim] + sin = sin.squeeze((0, 1)) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed