Fix ESM models buffers (#24576)

* Fix ESM models buffers

* Remove modifs

* Tied weights keys are needed silly

* quality
This commit is contained in:
Sylvain Gugger 2023-06-29 10:55:21 -04:00 committed by GitHub
parent b324557aac
commit 8c4471d1fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -96,7 +96,7 @@ class RotaryEmbedding(torch.nn.Module):
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq
-self.register_buffer("inv_freq", inv_freq, persistent=False)
+self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = None
self._cos_cached = None