Update deprecated torch.ger (#24387)

This commit is contained in:
Sergii Dymchenko 2023-06-20 17:21:13 -07:00 committed by GitHub
parent eb849f6604
commit cb8f675510
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -185,7 +185,7 @@ class PositionalEmbedding(nn.Module):
self.register_buffer("inv_freq", inv_freq)
def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
sinusoid_inp = torch.outer(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
if bsz is not None: