[pplm] split classif head into its own file

Julien Chaumond 2019-12-03 22:07:25 +00:00
parent 8101924a68
commit 7edb51f3a5
3 changed files with 20 additions and 17 deletions

examples/pplm_classification_head.py (new file)

@@ -0,0 +1,18 @@
+import torch
+
+
+class ClassificationHead(torch.nn.Module):
+    """Classification Head for transformer encoders"""
+
+    def __init__(self, class_size, embed_size):
+        super(ClassificationHead, self).__init__()
+        self.class_size = class_size
+        self.embed_size = embed_size
+        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
+        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
+        self.mlp = torch.nn.Linear(embed_size, class_size)
+
+    def forward(self, hidden_state):
+        # hidden_state = F.relu(self.mlp1(hidden_state))
+        # hidden_state = self.mlp2(hidden_state)
+        logits = self.mlp(hidden_state)
+        return logits

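For reference, a minimal usage sketch of the extracted head (not part of this commit). It assumes pplm_classification_head.py is on the Python path; the class/batch sizes are illustrative assumptions, and embed_size=1024 matches GPT-2 medium's hidden size:

import torch

from pplm_classification_head import ClassificationHead

# Illustrative, assumed sizes: 5 target classes, 1024-dim features.
head = ClassificationHead(class_size=5, embed_size=1024)

# A batch of 8 pooled hidden states (e.g. mean-pooled over the sequence).
pooled = torch.randn(8, 1024)

logits = head(pooled)  # shape: (8, 5)
print(torch.softmax(logits, dim=-1).shape)  # torch.Size([8, 5])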
examples/run_pplm.py

@@ -33,10 +33,10 @@ import torch.nn.functional as F
 from torch.autograd import Variable
 from tqdm import trange
-from examples.run_pplm_discrim_train import ClassificationHead
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
+from pplm_classification_head import ClassificationHead


 PPLM_BOW = 1
 PPLM_DISCRIM = 2

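With this change, run_pplm.py imports the head from the shared module instead of reaching into the training script. A hedged sketch of restoring a trained head for inference; the checkpoint path and sizes are placeholders, not values from this diff:

import torch

from pplm_classification_head import ClassificationHead

# "discriminator_head.pt" is a placeholder path; the real script resolves its
# checkpoints via transformers.file_utils.cached_path (imported above).
classifier = ClassificationHead(class_size=5, embed_size=1024)  # assumed sizes
classifier.load_state_dict(torch.load("discriminator_head.pt", map_location="cpu"))
classifier.eval()

# At generation time the head only scores hidden states, so no_grad() is fine.
with torch.no_grad():
    scores = classifier(torch.randn(1, 1024))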
examples/run_pplm_discrim_train.py

@@ -21,6 +21,7 @@ from torchtext import datasets
 from tqdm import tqdm, trange
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from pplm_classification_head import ClassificationHead


 torch.manual_seed(0)
 np.random.seed(0)
@@ -29,22 +30,6 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
 max_length_seq = 100

-class ClassificationHead(torch.nn.Module):
-    """Classification Head for transformer encoders"""
-
-    def __init__(self, class_size, embed_size):
-        super(ClassificationHead, self).__init__()
-        self.class_size = class_size
-        self.embed_size = embed_size
-        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
-        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
-        self.mlp = torch.nn.Linear(embed_size, class_size)
-
-    def forward(self, hidden_state):
-        # hidden_state = F.relu(self.mlp1(hidden_state))
-        # hidden_state = self.mlp2(hidden_state)
-        logits = self.mlp(hidden_state)
-        return logits
-
 class Discriminator(torch.nn.Module):
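The Discriminator shown truncated above is where the shared head is consumed during training. For orientation, a rough sketch of such a discriminator, assuming (beyond what this diff shows) GPT-2 as the encoder and mean pooling over its hidden states; the pooling and method details are assumptions, not code from this commit:

import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from pplm_classification_head import ClassificationHead


class Discriminator(torch.nn.Module):
    """Sketch only: a transformer encoder feeding ClassificationHead."""

    def __init__(self, class_size, pretrained_model="gpt2-medium"):
        super(Discriminator, self).__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        # n_embd is GPT-2's hidden size (1024 for gpt2-medium).
        self.classifier_head = ClassificationHead(
            class_size=class_size, embed_size=self.encoder.config.n_embd
        )

    def forward(self, input_ids):
        # Take the last hidden states from the underlying transformer and
        # mean-pool over the sequence dimension (assumed; the real script's
        # pooling may differ).
        hidden = self.encoder.transformer(input_ids)[0]
        pooled = hidden.mean(dim=1)
        return self.classifier_head(pooled)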