mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
improve(llama): Faster apply_rotary_pos_emb (#22785)
This commit is contained in:
parent
abbc96a214
commit
626c1b8af1
@ -131,10 +131,11 @@ def rotate_half(x):
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
||||
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
||||
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze((0, 1)) # [seq_len, dim]
|
||||
sin = sin.squeeze((0, 1)) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
Loading…
Reference in New Issue
Block a user