mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix cos_sin device issue in Falcon model (#26448)
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent a7e0ed829c
commit 375b4e0935
@@ -129,6 +129,11 @@ class FalconRotaryEmbedding(nn.Module):
|
||||
total_length = seq_len + past_key_values_length
|
||||
if total_length > self.seq_len_cached:
|
||||
self._set_cos_sin_cache(total_length, device, dtype)
|
||||
|
||||
# the cached tensors need to update their devices (for example, after we change the model's device)
|
||||
self.cos_cached = self.cos_cached.to(device)
|
||||
self.sin_cached = self.sin_cached.to(device)
|
||||
|
||||
# Gather cos, sin at the designated position ids
|
||||
cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
||||
sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
||||
|
Loading…
Reference in New Issue
Block a user