run_classifier WIP + added classifier head and initialization to the model
This commit is contained in: parent 4a0b59e980, commit f690f0e167
@@ -27,6 +27,7 @@ import six
 import tensorflow as tf
 import torch
 import torch.nn as nn
+from torch.nn import CrossEntropyLoss
 
 def gelu(x):
     raise NotImplementedError
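The newly imported CrossEntropyLoss combines log-softmax and negative log-likelihood, so the classifier head added below can hand it raw logits and integer class labels directly. A minimal sketch (the shapes and label values here are illustrative assumptions, not part of the commit):

import torch
from torch.nn import CrossEntropyLoss

loss_fct = CrossEntropyLoss()
logits = torch.randn(8, 2)          # (batch_size, num_labels) raw scores, no softmax applied
labels = torch.randint(0, 2, (8,))  # integer class indices in [0, num_labels)
loss = loss_fct(logits, labels)     # scalar training loss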
@@ -394,3 +395,30 @@ class BertModel(nn.Module):
         sequence_output = all_encoder_layers[-1]
         pooled_output = self.pooler(sequence_output)
         return all_encoder_layers, pooled_output
+
+
+class BertForSequenceClassification(nn.Module):
+    def __init__(self, config, num_labels):
+        super(BertForSequenceClassification, self).__init__()
+        self.bert = BertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, num_labels)
+
+        def init_weights(m):
+            if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding:
+                print("Initializing {}".format(m))
+                # Slight difference here with the TF version which uses truncated_normal
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                m.weight.normal_(config.initializer_range)
+        self.apply(init_weights)
+
+    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
+        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits, labels)
+            return loss, logits
+        else:
+            return logits
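As written, the init_weights helper in this WIP hunk would raise a TypeError (isinstance takes two arguments) and passes config.initializer_range as the mean of normal_ rather than the standard deviation. A hedged sketch of what the initialization presumably intends, assuming initializer_range is the desired std (0.02 is only a placeholder default, and the bias zeroing is an extra assumption):

import torch.nn as nn

def init_weights(module, initializer_range=0.02):
    # Draw Linear and Embedding weights from N(0, initializer_range^2);
    # the TF version uses a truncated normal instead.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

# model.apply(init_weights) then visits every submodule recursively, as in the hunk above.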
@@ -20,20 +20,23 @@ from __future__ import print_function
 
 import csv
 import os
-from modeling_pytorch import BertConfig, BertModel
-from optimization_pytorch import BERTAdam
-# import optimization
-import tokenization_pytorch
-import torch
 
 import logging
+import argparse
+
+import numpy as np
+import torch
+from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+
+import tokenization_pytorch
+from modeling_pytorch import BertConfig, BertForSequenceClassification
+from optimization_pytorch import BERTAdam
+
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
                     level = logging.INFO)
 logger = logging.getLogger(__name__)
 
-import argparse
 
 parser = argparse.ArgumentParser()
 
 ## Required parameters
@@ -127,39 +130,6 @@ parser.add_argument("--local_rank",
                     default=-1,
                     help = "local_rank for distributed training on gpus")
 
-### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
-parser.add_argument("--use_tpu",
-                    default = False,
-                    type = bool,
-                    help = "Whether to use TPU or GPU/CPU.")
-parser.add_argument("--tpu_name",
-                    default = None,
-                    type = str,
-                    help = "The Cloud TPU to use for training. This should be either the name "
-                    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
-                    "url.")
-parser.add_argument("--tpu_zone",
-                    default = None,
-                    type = str,
-                    help = "[Optional] GCE zone where the Cloud TPU is located in. If not "
-                    "specified, we will attempt to automatically detect the GCE project from "
-                    "metadata.")
-parser.add_argument("--gcp_project",
-                    default = None,
-                    type = str,
-                    help = "[Optional] Project name for the Cloud TPU-enabled project. If not "
-                    "specified, we will attempt to automatically detect the GCE project from "
-                    "metadata.")
-parser.add_argument("--master",
-                    default = None,
-                    type = str,
-                    help = "[Optional] TensorFlow master URL.")
-parser.add_argument("--num_tpu_cores",
-                    default = 8,
-                    type = int,
-                    help = "Only used if `use_tpu` is True. Total number of TPU cores to use.")
-### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
 
 args = parser.parse_args()
 
 class InputExample(object):
@@ -429,44 +399,41 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
             tokens_b.pop()
 
 
-def input_fn_builder(features, seq_length, is_training, drop_remainder):
+def input_fn_builder(features, seq_length, train_batch_size):
+    # TODO: delete
     """Creates an `input_fn` closure to be passed to TPUEstimator."""  ### ATTENTION - To rewrite ###
 
-    all_input_ids = []
-    all_input_mask = []
-    all_segment_ids = []
-    all_label_ids = []
+    all_input_ids = [f.input_ids for feature in features]
+    all_input_mask = [f.input_mask for feature in features]
+    all_segment_ids = [f.segment_ids for feature in features]
+    all_label_ids = [f.label_id for feature in features]
 
-    for feature in features:
-        all_input_ids.append(feature.input_ids)
-        all_input_mask.append(feature.input_mask)
-        all_segment_ids.append(feature.segment_ids)
-        all_label_ids.append(feature.label_id)
+    # for feature in features:
+    #     all_input_ids.append(feature.input_ids)
+    #     all_input_mask.append(feature.input_mask)
+    #     all_segment_ids.append(feature.segment_ids)
+    #     all_label_ids.append(feature.label_id)
 
-    def input_fn(params):
-        """The actual input function."""
-        batch_size = params["batch_size"]
+    input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long)
+    input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long)
+    segment_tensor = torch.tensor(all_segment, dtype=torch.Long)
+    label_tensor = torch.tensor(all_label, dtype=torch.Long)
 
-        num_examples = len(features)
+    train_data = TensorDataset(input_ids_tensor, input_mask_tensor,
+                               segment_tensor, label_tensor)
+    if args.local_rank == -1:
+        train_sampler = RandomSampler(train_data)
+    else:
+        train_sampler = DistributedSampler(train_data)
+    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)
 
-        device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
-        d = torch.utils.data.TensorDataset({ ## BUG THIS IS NOT WORKING.... ###
-            "input_ids": torch.IntTensor(all_input_ids, device=device), #Requires_grad=False by default
-            "input_mask": torch.IntTensor(all_input_mask, device=device),
-            "segment_ids": torch.IntTensor(all_segment_ids, device=device),
-            "label_ids": torch.IntTensor(all_label_ids, device=device)
-            })
+    return train_dataloader
 
-        shuffle = True if is_training else False
-        d = torch.utils.data.DataLoader(dataset=d, batch_size=batch_size,
-                                        shuffle=shuffle,drop_last=drop_remainder)
-        # Cf https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
-        return d
+def accuracy(out, labels):
+    outputs = np.argmax(out, axis=1)
+    return np.sum(outputs==labels)/float(labels.size)
 
-    return input_fn
+def main():
 
 
-def main(_):
     processors = {
         "cola": ColaProcessor,
         "mnli": MnliProcessor,
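The rewritten input_fn_builder above still carries WIP bugs: the list comprehensions bind `feature` but read `f`, `torch.Long` is not a valid dtype (it should be `torch.long`), and `all_segment` / `all_label` are undefined names. A hedged sketch of what the helper presumably intends, assuming each feature object exposes input_ids, input_mask, segment_ids and label_id (the function name here is hypothetical):

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

def build_train_dataloader(features, train_batch_size, local_rank=-1):
    # Stack the per-example feature lists into long tensors and wrap them in a TensorDataset.
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    # Random shuffling for single-process training, a DistributedSampler otherwise.
    if local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    return DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)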
@@ -517,13 +484,13 @@ def main(_):
         num_train_steps = int(
             len(train_examples) / args.train_batch_size * args.num_train_epochs)
 
-    model = BertModel(bert_config)
+    model = BertForSequenceClassification(bert_config)
     if args.init_checkpoint is not None:
-        model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
+        model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
 
     optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
-                          {'params': [p for n, p in model.named_parameters() if n != 'bias']}
+                          {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
                           ],
                          lr=args.learning_rate, schedule='warmup_linear',
                          warmup=args.warmup_proportion,
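The parameter groups above intend to exempt biases from L2 regularization, but `n == 'bias'` only matches a parameter whose full name is literally "bias"; named_parameters() returns dotted paths such as "bert.pooler.dense.bias". A hedged sketch of grouping by substring instead, assuming `model` is the BertForSequenceClassification instance from this script and keeping the 'l2' keys exactly as this diff uses them:

# Hypothetical regrouping: exclude every parameter whose name contains "bias" from L2 decay.
no_decay = ['bias']
param_groups = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'l2': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'l2': 0.},
]
# param_groups would then be passed to BERTAdam(...) with the same lr/schedule/warmup arguments as above.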
@@ -536,18 +503,31 @@ def main(_):
         logger.info(" Num examples = %d", len(train_examples))
         logger.info(" Batch size = %d", args.train_batch_size)
         logger.info(" Num steps = %d", num_train_steps)
-        train_input = input_fn_builder(
-            features=train_features,
-            seq_length=args.max_seq_length,
-            is_training=True,
-            drop_remainder=True)
-        # estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
-        for batch_ix, batch in train_input:
-            output = model_fn(batch)
-            loss = output["loss"]
+
+        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long)
+        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long)
+        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long)
+        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long)
+
+        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
+        if args.local_rank == -1:
+            train_sampler = RandomSampler(train_data)
+        else:
+            train_sampler = DistributedSampler(train_data)
+        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+
+        model.train()
+        global_step = 0
+        for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
+            input_ids.to(device)
+            input_mask.to(device)
+            segment_ids.to(device)
+            label_ids.to(device)
+
+            loss = model(input_ids, segment_ids, input_mask, label_ids)
             loss.backward()
+            optimizer.step()
+            global_step += 1
 
     if args.do_eval:
         eval_examples = processor.get_dev_examples(args.data_dir)
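Two details in the new training loop are worth flagging: Tensor.to(device) is not in-place, so the results of the four .to(device) calls above are discarded, and gradients are never cleared between steps. A hedged sketch of how the loop presumably should read, assuming model, optimizer, device and train_dataloader are defined as in this script:

model.train()
global_step = 0
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
    # .to(device) returns a new tensor; the result has to be reassigned.
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)

    loss = model(input_ids, segment_ids, input_mask, label_ids)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()  # clear gradients so they do not accumulate across steps
    global_step += 1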
||||||
@ -558,23 +538,40 @@ def main(_):
|
|||||||
logger.info(" Num examples = %d", len(eval_examples))
|
logger.info(" Num examples = %d", len(eval_examples))
|
||||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||||
|
|
||||||
# This tells the estimator to run through the entire set.
|
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long)
|
||||||
eval_steps = None
|
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long)
|
||||||
# However, if running eval on the TPU, you will need to specify the
|
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long)
|
||||||
# number of steps.
|
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long)
|
||||||
if args.use_tpu:
|
|
||||||
# Eval will be slightly WRONG on the TPU because it will truncate
|
|
||||||
# the last batch.
|
|
||||||
eval_steps = int(len(eval_examples) / args.eval_batch_size)
|
|
||||||
|
|
||||||
eval_drop_remainder = True if args.use_tpu else False
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||||
eval_input_fn = input_fn_builder(
|
if args.local_rank == -1:
|
||||||
features=eval_features,
|
eval_sampler = SequentialSampler(eval_data)
|
||||||
seq_length=args.max_seq_length,
|
else:
|
||||||
is_training=False,
|
eval_sampler = DistributedSampler(eval_data)
|
||||||
drop_remainder=eval_drop_remainder)
|
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||||
|
|
||||||
result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
|
model.eval()
|
||||||
|
eval_loss = 0
|
||||||
|
eval_accuracy = 0
|
||||||
|
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
|
||||||
|
input_ids.to(device)
|
||||||
|
input_mask.to(device)
|
||||||
|
segment_ids.to(device)
|
||||||
|
label_ids.to(device)
|
||||||
|
|
||||||
|
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
|
||||||
|
tmp_eval_accuracy = accuracy(logits, label_ids)
|
||||||
|
|
||||||
|
eval_loss += tmp_eval_loss.item()
|
||||||
|
eval_accuracy += tmp_eval_accuracy
|
||||||
|
|
||||||
|
eval_loss = eval_loss / len(eval_dataloader)
|
||||||
|
eval_accuracy = eval_accuracy / len(eval_dataloader)
|
||||||
|
|
||||||
|
result = {'eval_loss': eval_loss,
|
||||||
|
'eval_accuracy': eval_accuracy,
|
||||||
|
'global_step': global_step,
|
||||||
|
'loss': loss.item()}
|
||||||
|
|
||||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||||
with open(output_eval_file, "w") as writer:
|
with open(output_eval_file, "w") as writer:
|
||||||
|
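The evaluation loop mirrors the training loop and feeds raw tensors to the accuracy helper, which works on NumPy arrays. A hedged sketch with the same device fix, gradient tracking disabled, and explicit conversion to NumPy (model, device, eval_dataloader and accuracy as defined above):

model.eval()
eval_loss, eval_accuracy = 0, 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)

    with torch.no_grad():  # no gradients are needed during evaluation
        tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)

    # accuracy() uses np.argmax, so move the tensors back to CPU NumPy arrays first.
    tmp_eval_accuracy = accuracy(logits.detach().cpu().numpy(), label_ids.cpu().numpy())

    eval_loss += tmp_eval_loss.item()
    eval_accuracy += tmp_eval_accuracy

eval_loss /= len(eval_dataloader)
eval_accuracy /= len(eval_dataloader)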