run_classifier WIP + added classifier head and initialization to the model

thomwolf 2018-11-02 00:27:50 +01:00
parent 4a0b59e980
commit f690f0e167
2 changed files with 128 additions and 103 deletions

modeling_pytorch.py

@@ -27,6 +27,7 @@ import six
import tensorflow as tf
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
def gelu(x):
    # GELU activation (https://arxiv.org/abs/1606.08415); 0.7071... = 1/sqrt(2)
    return x * 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))
@@ -394,3 +395,30 @@ class BertModel(nn.Module):
sequence_output = all_encoder_layers[-1]
pooled_output = self.pooler(sequence_output)
return all_encoder_layers, pooled_output
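pooled_output above is the final layer's first-token ([CLS]) hidden state run through the pooler; the classification head added below consumes it. A shape sketch (assuming hidden size H and (batch, seq_len) inputs):

# all_encoder_layers[i] : (batch, seq_len, H)  -- one tensor per transformer layer
# pooled_output         : (batch, H)           -- fed to dropout + the linear classifier below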
class BertForSequenceClassification(nn.Module):
    def __init__(self, config, num_labels):
        super(BertForSequenceClassification, self).__init__()
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

        def init_weights(m):
            if isinstance(m, (nn.Linear, nn.Embedding)):
                print("Initializing {}".format(m))
                # Slight difference here with the TF version which uses truncated_normal
                # cf https://github.com/pytorch/pytorch/pull/5617
                m.weight.data.normal_(mean=0.0, std=config.initializer_range)
        self.apply(init_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, logits
        else:
            return logits
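A minimal usage sketch of the new head. The BertConfig field names and values below are assumptions taken from the config defaults elsewhere in the file, and the snippet presumes the rest of BertModel (e.g. gelu) is functional; it is illustrative, not part of the commit:

config = BertConfig(vocab_size=32000)                        # other fields (hidden_size, ...) left at their assumed defaults
model = BertForSequenceClassification(config, num_labels=2)

input_ids = torch.randint(0, 32000, (8, 128))                # (batch, seq_len) of token ids
token_type_ids = torch.zeros(8, 128, dtype=torch.long)       # single-segment input
attention_mask = torch.ones(8, 128, dtype=torch.long)
labels = torch.randint(0, 2, (8,))

loss, logits = model(input_ids, token_type_ids, attention_mask, labels)  # training path: returns (loss, logits)
logits = model(input_ids, token_type_ids, attention_mask)                # inference path: returns logits only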

run_classifier_pytorch.py

@@ -20,20 +20,23 @@ from __future__ import print_function
import csv
import os
import logging
import argparse

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import tokenization_pytorch
from modeling_pytorch import BertConfig, BertForSequenceClassification
from optimization_pytorch import BERTAdam

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
## Required parameters
@@ -127,39 +130,6 @@ parser.add_argument("--local_rank",
default=-1,
help = "local_rank for distributed training on gpus")
### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
parser.add_argument("--use_tpu",
                    default=False,
                    type=bool,
                    help="Whether to use TPU or GPU/CPU.")
parser.add_argument("--tpu_name",
                    default=None,
                    type=str,
                    help="The Cloud TPU to use for training. This should be either the name "
                         "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
                         "url.")
parser.add_argument("--tpu_zone",
                    default=None,
                    type=str,
                    help="[Optional] GCE zone where the Cloud TPU is located in. If not "
                         "specified, we will attempt to automatically detect the GCE project from "
                         "metadata.")
parser.add_argument("--gcp_project",
                    default=None,
                    type=str,
                    help="[Optional] Project name for the Cloud TPU-enabled project. If not "
                         "specified, we will attempt to automatically detect the GCE project from "
                         "metadata.")
parser.add_argument("--master",
                    default=None,
                    type=str,
                    help="[Optional] TensorFlow master URL.")
parser.add_argument("--num_tpu_cores",
                    default=8,
                    type=int,
                    help="Only used if `use_tpu` is True. Total number of TPU cores to use.")
### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
args = parser.parse_args()
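device is used further down (model.to(device), input_ids.to(device)) but its definition falls outside the hunks shown here; a typical setup consistent with --local_rank would be something like the following hypothetical lines:

if args.local_rank == -1:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    # one process per GPU for distributed training
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    # a distributed run would also call torch.distributed.init_process_group(backend='nccl') here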
class InputExample(object):
@@ -429,44 +399,41 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
tokens_b.pop()
def input_fn_builder(features, seq_length, train_batch_size):
    # TODO: delete
    """Creates an `input_fn` closure to be passed to TPUEstimator."""  ### ATTENTION - To rewrite ###
    all_input_ids = [f.input_ids for f in features]
    all_input_mask = [f.input_mask for f in features]
    all_segment_ids = [f.segment_ids for f in features]
    all_label_ids = [f.label_id for f in features]

    input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long)
    input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.long)
    segment_tensor = torch.tensor(all_segment_ids, dtype=torch.long)
    label_tensor = torch.tensor(all_label_ids, dtype=torch.long)

    train_data = TensorDataset(input_ids_tensor, input_mask_tensor,
                               segment_tensor, label_tensor)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
    # Cf https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
    return train_dataloader


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels) / float(labels.size)


def main():
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
@@ -517,13 +484,13 @@ def main(_):
num_train_steps = int(
len(train_examples) / args.train_batch_size * args.num_train_epochs)
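As a quick worked example of the step count above (sizes are hypothetical, not from this commit): 8,000 training examples at batch size 32 for 3 epochs gives int(8000 / 32 * 3) = 750 optimizer steps.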
model = BertForSequenceClassification(bert_config, len(label_list))  # num_labels: assumes label_list = processor.get_labels() earlier in main()
if args.init_checkpoint is not None:
    model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)

# Exclude bias parameters from weight decay ('l2' is BERTAdam's per-group weight-decay key)
optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if 'bias' not in n], 'l2': 0.01},
                      {'params': [p for n, p in model.named_parameters() if 'bias' in n], 'l2': 0.}
                      ],
                     lr=args.learning_rate, schedule='warmup_linear',
                     warmup=args.warmup_proportion,
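The two parameter groups above apply weight decay to everything except bias parameters. The same grouping expressed with a stock PyTorch optimizer, purely for illustration (the script itself uses BERTAdam):

grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if 'bias' not in n], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if 'bias' in n], 'weight_decay': 0.0},
]
reference_optimizer = torch.optim.Adam(grouped_parameters, lr=args.learning_rate)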
@@ -536,18 +503,31 @@ def main(_):
logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
    train_sampler = RandomSampler(train_data)
else:
    train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

model.train()
global_step = 0
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
    # Tensor.to() is not in-place; keep the returned tensors
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)

    # forward returns (loss, logits) when labels are given
    loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()  # reset gradients before the next batch
    global_step += 1
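BERTAdam is built with schedule='warmup_linear' and warmup=args.warmup_proportion; the conventional BERT schedule behind that name ramps the learning rate up linearly during the warmup fraction of training and then decays it linearly toward zero. A standalone sketch of that multiplier (optimization_pytorch.py is not part of this diff, so this is the generic formulation, not necessarily its exact code):

def warmup_linear_multiplier(progress, warmup=0.1):
    # progress = global_step / num_train_steps, in [0, 1]
    if progress < warmup:
        return progress / warmup       # linear warm-up
    return max(1.0 - progress, 0.0)    # linear decay to zero

# e.g. halfway through training with a 10% warmup:
# effective_lr = args.learning_rate * warmup_linear_multiplier(0.5)  # 0.5 * base lr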
if args.do_eval:
eval_examples = processor.get_dev_examples(args.data_dir)
@@ -558,23 +538,40 @@ def main(_):
logger.info(" Num examples = %d", len(eval_examples))
logger.info(" Batch size = %d", args.eval_batch_size)
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.local_rank == -1:
    eval_sampler = SequentialSampler(eval_data)
else:
    eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

model.eval()
eval_loss = 0
eval_accuracy = 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)

    with torch.no_grad():
        tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)

    # accuracy() works on numpy arrays, so move logits and labels to CPU first
    logits = logits.detach().cpu().numpy()
    label_ids = label_ids.cpu().numpy()
    tmp_eval_accuracy = accuracy(logits, label_ids)

    eval_loss += tmp_eval_loss.item()
    eval_accuracy += tmp_eval_accuracy

eval_loss = eval_loss / len(eval_dataloader)
eval_accuracy = eval_accuracy / len(eval_dataloader)
result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy,
'global_step': global_step,
'loss': loss.item()}
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer: