From 7edb51f3a516ca533797fb2bb2f2b7ce86e0df70 Mon Sep 17 00:00:00 2001
From: Julien Chaumond
Date: Tue, 3 Dec 2019 22:07:25 +0000
Subject: [PATCH] [pplm] split classif head into its own file

---
 examples/pplm/pplm_classification_head.py | 18 ++++++++++++++++++
 examples/pplm/run_pplm.py                 |  2 +-
 examples/pplm/run_pplm_discrim_train.py   | 17 +----------------
 3 files changed, 20 insertions(+), 17 deletions(-)
 create mode 100644 examples/pplm/pplm_classification_head.py

diff --git a/examples/pplm/pplm_classification_head.py b/examples/pplm/pplm_classification_head.py
new file mode 100644
index 00000000000..9aae0f17e9c
--- /dev/null
+++ b/examples/pplm/pplm_classification_head.py
@@ -0,0 +1,18 @@
+import torch
+
+class ClassificationHead(torch.nn.Module):
+    """Classification Head for transformer encoders"""
+
+    def __init__(self, class_size, embed_size):
+        super(ClassificationHead, self).__init__()
+        self.class_size = class_size
+        self.embed_size = embed_size
+        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
+        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
+        self.mlp = torch.nn.Linear(embed_size, class_size)
+
+    def forward(self, hidden_state):
+        # hidden_state = F.relu(self.mlp1(hidden_state))
+        # hidden_state = self.mlp2(hidden_state)
+        logits = self.mlp(hidden_state)
+        return logits
diff --git a/examples/pplm/run_pplm.py b/examples/pplm/run_pplm.py
index f626a43f4fd..dda5d85ae72 100644
--- a/examples/pplm/run_pplm.py
+++ b/examples/pplm/run_pplm.py
@@ -33,10 +33,10 @@ import torch.nn.functional as F
 from torch.autograd import Variable
 from tqdm import trange
 
-from examples.run_pplm_discrim_train import ClassificationHead
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
+from pplm_classification_head import ClassificationHead
 
 PPLM_BOW = 1
 PPLM_DISCRIM = 2
diff --git a/examples/pplm/run_pplm_discrim_train.py b/examples/pplm/run_pplm_discrim_train.py
index db081e1a17c..9d36b79bc44 100644
--- a/examples/pplm/run_pplm_discrim_train.py
+++ b/examples/pplm/run_pplm_discrim_train.py
@@ -21,6 +21,7 @@ from torchtext import datasets
 from tqdm import tqdm, trange
 
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from pplm_classification_head import ClassificationHead
 
 torch.manual_seed(0)
 np.random.seed(0)
@@ -29,22 +30,6 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
 max_length_seq = 100
 
 
-class ClassificationHead(torch.nn.Module):
-    """Classification Head for transformer encoders"""
-
-    def __init__(self, class_size, embed_size):
-        super(ClassificationHead, self).__init__()
-        self.class_size = class_size
-        self.embed_size = embed_size
-        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
-        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
-        self.mlp = torch.nn.Linear(embed_size, class_size)
-
-    def forward(self, hidden_state):
-        # hidden_state = F.relu(self.mlp1(hidden_state))
-        # hidden_state = self.mlp2(hidden_state)
-        logits = self.mlp(hidden_state)
-        return logits
 
 
 class Discriminator(torch.nn.Module):
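
Usage sketch (illustrative, not part of the commit): after this patch, the head lives in
examples/pplm/pplm_classification_head.py and can be imported directly when running from that
directory. The batch/sequence shapes, the mean pooling, and class_size=5 below are assumptions
for demonstration only; embed_size=1024 matches GPT-2 medium's hidden size, the model the PPLM
example defaults to.

    import torch

    from pplm_classification_head import ClassificationHead

    # Assumed shapes: a batch of 2 sequences of 12 hidden states from a
    # 1024-dimensional transformer; random values stand in for real activations.
    hidden_states = torch.randn(2, 12, 1024)   # (batch, seq_len, embed_size)
    pooled = hidden_states.mean(dim=1)         # average over the sequence dimension

    head = ClassificationHead(class_size=5, embed_size=1024)
    logits = head(pooled)                      # shape (2, 5): one score per class

Run from examples/pplm/ (or with that directory on PYTHONPATH) so the new module is importable.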