Fix cos_sin device issue in Falcon model (#26448)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-09-28 10:00:15 +02:00 committed by GitHub
parent a7e0ed829c
commit 375b4e0935
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -129,6 +129,11 @@ class FalconRotaryEmbedding(nn.Module):
total_length = seq_len + past_key_values_length
if total_length > self.seq_len_cached:
self._set_cos_sin_cache(total_length, device, dtype)
# the cached tensors need to update their devices (for example, after we change the model's device)
self.cos_cached = self.cos_cached.to(device)
self.sin_cached = self.sin_cached.to(device)
# Gather cos, sin at the designated position ids
cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]