Replace false parameter by a buffer (#18259)

This commit is contained in:
Sylvain Gugger 2022-07-26 13:02:58 +02:00 committed by GitHub
parent 2844c5de10
commit c8ed1b8b59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 6 deletions

View File

@ -131,9 +131,7 @@ class M2M100SinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
self.register_buffer("weights", emb_weights)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):

View File

@ -173,9 +173,7 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
# in forward put the weights on the correct dtype and device of the param
emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()
self.register_buffer("weights", emb_weights)
@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):