From cb8f675510f35b34c935a9ad04f2fe7fe6a8a9e2 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Tue, 20 Jun 2023 17:21:13 -0700 Subject: [PATCH] Update deprecated torch.ger (#24387) --- src/transformers/models/transfo_xl/modeling_transfo_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index 493ab08d9ce..d0f6cc029fb 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ -185,7 +185,7 @@ class PositionalEmbedding(nn.Module): self.register_buffer("inv_freq", inv_freq) def forward(self, pos_seq, bsz=None): - sinusoid_inp = torch.ger(pos_seq, self.inv_freq) + sinusoid_inp = torch.outer(pos_seq, self.inv_freq) pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) if bsz is not None: