mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Fix cos_sin
device issue in Falcon model (#26448)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
a7e0ed829c
commit
375b4e0935
@ -129,6 +129,11 @@ class FalconRotaryEmbedding(nn.Module):
|
|||||||
total_length = seq_len + past_key_values_length
|
total_length = seq_len + past_key_values_length
|
||||||
if total_length > self.seq_len_cached:
|
if total_length > self.seq_len_cached:
|
||||||
self._set_cos_sin_cache(total_length, device, dtype)
|
self._set_cos_sin_cache(total_length, device, dtype)
|
||||||
|
|
||||||
|
# the cached tensors need to update their devices (for example, after we change the model's device)
|
||||||
|
self.cos_cached = self.cos_cached.to(device)
|
||||||
|
self.sin_cached = self.sin_cached.to(device)
|
||||||
|
|
||||||
# Gather cos, sin at the designated position ids
|
# Gather cos, sin at the designated position ids
|
||||||
cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
||||||
sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
||||||
|
Loading…
Reference in New Issue
Block a user