[gptj] support older pytorch version (#22325)

* [gptj] support older pytorch version

* contributor

* contributor

* make copies

---------

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Stas Bekman 2023-03-22 18:35:04 -07:00, committed by GitHub
parent 80e3b36361, commit 61f79b2986
2 changed files with 3 additions and 2 deletions
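
Background for the fix (a sketch, not part of the commit itself): `torch.concat` only exists from PyTorch 1.10 onward, as an alias of `torch.cat`, and on some older releases the `torch.fx` submodule is not exposed by a bare `import torch`. A minimal probe illustrating both points:

# Sketch only, not part of this commit. Illustrates why the patch switches
# to torch.cat and adds an explicit `import torch.fx`.
import torch
import torch.fx  # explicit import; a bare `import torch` may not expose torch.fx on older releases

x = torch.ones(2, 3)
y = torch.zeros(2, 3)

# torch.cat is available on all PyTorch versions this code targets.
joined = torch.cat((x, y), dim=1)

# torch.concat is only an alias of torch.cat, added in PyTorch 1.10.
if hasattr(torch, "concat"):
    assert torch.equal(torch.concat((x, y), dim=1), joined)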

src/transformers/models/codegen/modeling_codegen.py

@@ -55,7 +55,7 @@ CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
     inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
     sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
-    return torch.concat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


 # Copied from transformers.models.gptj.modeling_gptj.rotate_every_two

src/transformers/models/gptj/modeling_gptj.py

@@ -18,6 +18,7 @@ import warnings
 from typing import Optional, Tuple, Union

 import torch
+import torch.fx
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -57,7 +58,7 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
     inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
     sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
-    return torch.concat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


 @torch.fx.wrap
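
As a quick sanity check (a standalone sketch copying the patched helper from the diff above), the function returns a (num_pos, dim) table with the sine half and cosine half concatenated along the last dimension:

import torch

def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    # Copied from the patched code above: sinusoidal (rotary) position table.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float()
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

pos = create_sinusoidal_positions(num_pos=4, dim=8)
print(pos.shape)  # torch.Size([4, 8]) -- sin half then cos half along dim=1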