mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Replace false parameter by a buffer (#18259)
This commit is contained in:
parent
2844c5de10
commit
c8ed1b8b59
@ -131,9 +131,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
|
||||
# in forward put the weights on the correct dtype and device of the param
|
||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||
|
||||
self.weights = nn.Parameter(emb_weights)
|
||||
self.weights.requires_grad = False
|
||||
self.weights.detach_()
|
||||
self.register_buffer("weights", emb_weights)
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
|
@ -173,9 +173,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
|
||||
# in forward put the weights on the correct dtype and device of the param
|
||||
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
|
||||
|
||||
self.weights = nn.Parameter(emb_weights)
|
||||
self.weights.requires_grad = False
|
||||
self.weights.detach_()
|
||||
self.register_buffer("weights", emb_weights)
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||
|
Loading…
Reference in New Issue
Block a user