Merge pull request #388 from ananyahjha93/master

Added remaining GLUE tasks to 'run_classifier.py'
Thomas Wolf, 2019-03-28 09:06:53 +01:00, committed by GitHub
commit 694e2117f3
2 changed files with 430 additions and 34 deletions

README.md

@@ -927,11 +927,60 @@ Where `$THIS_MACHINE_INDEX` is a sequential index assigned to each of your machines
We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/):

- a *sequence-level classifier* on nine different GLUE tasks,
- a *token-level classifier* on the question answering dataset SQuAD,
- a *sequence-level multiple-choice classifier* on the SWAG classification corpus, and
- a *BERT language model* on another target corpus.

#### GLUE results on dev set
We get the following results on the dev set of the GLUE benchmark with an uncased BERT base
model. All experiments were run on a P100 GPU with a batch size of 32.

| Task | Metric | Result |
|-|-|-|
| CoLA | Matthew's corr. | 57.29 |
| SST-2 | accuracy | 93.00 |
| MRPC | F1/accuracy | 88.85/83.82 |
| STS-B | Pearson/Spearman corr. | 89.70/89.37 |
| QQP | accuracy/F1 | 90.72/87.41 |
| MNLI | matched acc./mismatched acc.| 83.95/84.39 |
| QNLI | accuracy | 89.04 |
| RTE | accuracy | 61.01 |
| WNLI | accuracy | 53.52 |

Some of these results are significantly different from the ones reported on the test set of
the GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.

Before running any of these GLUE tasks you should download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to some directory `$GLUE_DIR`.
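
For instance, assuming the gist has been saved locally as `download_glue_data.py` (the `--data_dir` and `--tasks` flags below are taken from that script and may change):

```shell
python download_glue_data.py --data_dir /path/to/glue --tasks all
```

You can then fine-tune a model on one of the tasks, for example: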
```shell
export GLUE_DIR=/path/to/glue
export TASK_NAME=MRPC
python run_classifier.py \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $GLUE_DIR/$TASK_NAME \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/$TASK_NAME/
```
where the task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI.

The dev set results will be written to the text file `eval_results.txt` in the specified `output_dir`. For MNLI, since there are two separate dev sets (matched and mismatched), there will be a separate output folder called `/tmp/MNLI-MM/` in addition to `/tmp/MNLI/`.
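
Each line of `eval_results.txt` is a `key = value` pair, one per task metric plus the eval loss, global step and training loss (see the `writer.write` calls in the diff below). For MNLI the file would look roughly like this, with hypothetical values:

```
acc = 0.8395
eval_loss = 0.4528
global_step = 36816
loss = 0.1573
```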

Half-precision training with apex has only been tested on MRPC, MNLI, CoLA and SST-2; the following section details how to run it for MRPC. That said, the remaining GLUE tasks should also work in half precision, since the data processor for each task inherits from the base class `DataProcessor`.
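
As a minimal sketch, half-precision fine-tuning should only require appending the script's `--fp16` flag to the command above (assuming apex is installed; `--loss_scale` can also be set explicitly if the default dynamic loss scaling is unstable):

```shell
python run_classifier.py \
  --task_name $TASK_NAME \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/$TASK_NAME \
  --bert_model bert-base-uncased \
  --max_seq_length 128 \
  --train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --fp16 \
  --output_dir /tmp/$TASK_NAME/
```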
#### MRPC

This example code fine-tunes BERT on the Microsoft Research Paraphrase

run_classifier.py

@@ -31,6 +31,10 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from torch.nn import CrossEntropyLoss, MSELoss
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -167,6 +171,16 @@ class MnliProcessor(DataProcessor):
        return examples


class MnliMismatchedProcessor(MnliProcessor):
    """Processor for the MultiNLI Mismatched data set (GLUE version)."""
    # Only the dev file differs; everything else is inherited from MnliProcessor.

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
            "dev_matched")


class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""
@@ -227,13 +241,181 @@ class Sst2Processor(DataProcessor):
        return examples


class StsbProcessor(DataProcessor):
    """Processor for the STS-B data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        # STS-B is a regression task, so there is no fixed label vocabulary.
        return [None]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[7]
            text_b = line[8]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QqpProcessor(DataProcessor):
    """Processor for the QQP data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            try:
                text_a = line[3]
                text_b = line[4]
                label = line[5]
            except IndexError:
                # Skip the occasional malformed line in the QQP tsv.
                continue
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class QnliProcessor(DataProcessor):
    """Processor for the QNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")),
            "dev_matched")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class RteProcessor(DataProcessor):
    """Processor for the RTE data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["entailment", "not_entailment"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


class WnliProcessor(DataProcessor):
    """Processor for the WNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, line[0])
            text_a = line[1]
            text_b = line[2]
            label = line[-1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples


def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer, output_mode):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {label : i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
@@ -289,7 +471,13 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if output_mode == "classification":
            label_id = label_map[example.label]
        elif output_mode == "regression":
            # Regression targets (the STS-B similarity scores) are parsed as floats.
            label_id = float(example.label)
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
@@ -325,9 +513,56 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
        else:
            tokens_b.pop()


def simple_accuracy(preds, labels):
    return (preds == labels).mean()


def acc_and_f1(preds, labels):
    acc = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds)
    return {
        "acc": acc,
        "f1": f1,
        "acc_and_f1": (acc + f1) / 2,
    }


def pearson_and_spearman(preds, labels):
    pearson_corr = pearsonr(preds, labels)[0]
    spearman_corr = spearmanr(preds, labels)[0]
    return {
        "pearson": pearson_corr,
        "spearmanr": spearman_corr,
        "corr": (pearson_corr + spearman_corr) / 2,
    }


def compute_metrics(task_name, preds, labels):
    assert len(preds) == len(labels)
    if task_name == "cola":
        return {"mcc": matthews_corrcoef(labels, preds)}
    elif task_name == "sst-2":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "mrpc":
        return acc_and_f1(preds, labels)
    elif task_name == "sts-b":
        return pearson_and_spearman(preds, labels)
    elif task_name == "qqp":
        return acc_and_f1(preds, labels)
    elif task_name == "mnli":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "mnli-mm":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "qnli":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "rte":
        return {"acc": simple_accuracy(preds, labels)}
    elif task_name == "wnli":
        return {"acc": simple_accuracy(preds, labels)}
    else:
        raise KeyError(task_name)
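
# Usage sketch (hypothetical values): for MRPC, which reports accuracy and F1,
#   compute_metrics("mrpc", np.array([1, 0, 1]), np.array([1, 1, 1]))
# returns roughly {"acc": 0.667, "f1": 0.8, "acc_and_f1": 0.733};
# an unknown task name raises KeyError.
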
def main():
    parser = argparse.ArgumentParser()
@@ -431,15 +666,26 @@ def main():
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor,
    }

    # STS-B is the only regression task; every other GLUE task is classification.
    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification",
    }

    if args.local_rank == -1 or args.no_cuda:

@@ -480,8 +726,10 @@ def main():
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

@@ -546,7 +794,7 @@ def main():
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
@@ -554,7 +802,12 @@ def main():
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        elif output_mode == "regression":
            # MSELoss expects float targets, so regression labels are stored as floats.
            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)

        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
@@ -569,7 +822,17 @@ def main():
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # define a new function to compute loss values for both output_modes
                logits = model(input_ids, segment_ids, input_mask, labels=None)

                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
@@ -613,22 +876,28 @@ def main():
    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
@@ -637,26 +906,36 @@ def main():
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            # create eval loss and other metric required by the task
            if output_mode == "classification":
                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
            elif output_mode == "regression":
                loss_fct = MSELoss()
                tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if len(preds) == 0:
                preds.append(logits.detach().cpu().numpy())
            else:
                preds[0] = np.append(
                    preds[0], logits.detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        preds = preds[0]
        if output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(task_name, preds, all_label_ids.numpy())
        loss = tr_loss/nb_tr_steps if args.do_train else None

        result['eval_loss'] = eval_loss
        result['global_step'] = global_step
        result['loss'] = loss

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
@@ -665,5 +944,73 @@ def main():
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        # hack for MNLI-MM: MNLI ships two dev sets (matched/mismatched), so
        # evaluation is run a second time and written to a separate '-MM' directory.
        if task_name == "mnli":
            task_name = "mnli-mm"
            processor = processors[task_name]()

            if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train:
                raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir + '-MM'))
            if not os.path.exists(args.output_dir + '-MM'):
                os.makedirs(args.output_dir + '-MM')

            eval_examples = processor.get_dev_examples(args.data_dir)
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
            logger.info("***** Running evaluation *****")
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)
            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

            model.eval()
            eval_loss = 0
            nb_eval_steps = 0
            preds = []

            for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(input_ids, segment_ids, input_mask, labels=None)

                # MNLI is a classification task, so only the cross-entropy branch is needed here.
                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1
                if len(preds) == 0:
                    preds.append(logits.detach().cpu().numpy())
                else:
                    preds[0] = np.append(
                        preds[0], logits.detach().cpu().numpy(), axis=0)

            eval_loss = eval_loss / nb_eval_steps
            preds = preds[0]
            preds = np.argmax(preds, axis=1)
            result = compute_metrics(task_name, preds, all_label_ids.numpy())
            loss = tr_loss/nb_tr_steps if args.do_train else None

            result['eval_loss'] = eval_loss
            result['global_step'] = global_step
            result['loss'] = loss

            output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))


if __name__ == "__main__":
    main()