diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index 493ab08d9ce..d0f6cc029fb 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ -185,7 +185,7 @@ class PositionalEmbedding(nn.Module): self.register_buffer("inv_freq", inv_freq) def forward(self, pos_seq, bsz=None): - sinusoid_inp = torch.ger(pos_seq, self.inv_freq) + sinusoid_inp = torch.outer(pos_seq, self.inv_freq) pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) if bsz is not None: