This commit is contained in:
thomwolf 2018-11-04 15:17:55 +01:00
parent 834b485b2e
commit 6b0da96b4b
2 changed files with 15 additions and 12 deletions

View File

@ -69,7 +69,7 @@ class InputFeatures(object):
self.input_mask = input_mask self.input_mask = input_mask
self.segment_ids = segment_ids self.segment_ids = segment_ids
self.label_id = label_id self.label_id = label_id
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for data converters for sequence classification data sets."""
@ -95,8 +95,8 @@ class DataProcessor(object):
for line in reader: for line in reader:
lines.append(line) lines.append(line)
return lines return lines
class MrpcProcessor(DataProcessor): class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version).""" """Processor for the MRPC data set (GLUE version)."""
@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor):
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples return examples
def convert_examples_to_features(examples, label_list, max_seq_length, def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
tokenizer):
"""Loads a data file into a list of `InputBatch`s.""" """Loads a data file into a list of `InputBatch`s."""
label_map = {} label_map = {}
@ -380,7 +379,7 @@ def main():
parser.add_argument("--do_lower_case", parser.add_argument("--do_lower_case",
default=False, default=False,
action='store_true', action='store_true',
help="Whether to lower case the input text. Should be True for uncased models and False for cased models.") help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--max_seq_length", parser.add_argument("--max_seq_length",
default=128, default=128,
type=int, type=int,
@ -424,6 +423,10 @@ def main():
default=False, default=False,
action='store_true', action='store_true',
help="Whether not to use CUDA when available") help="Whether not to use CUDA when available")
parser.add_argument("--accumulate_gradients",
type=int,
default=1,
help="Number of steps to accumulate gradient on (divide the single step batch_size)")
parser.add_argument("--local_rank", parser.add_argument("--local_rank",
type=int, type=int,
default=-1, default=-1,
@ -448,12 +451,12 @@ def main():
n_gpu = 1 n_gpu = 1
# print("Initializing the distributed backend: NCCL") # print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu) print("device", device, "n_gpu", n_gpu)
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
if n_gpu>0: torch.cuda.manual_seed_all(args.seed) if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval: if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.") raise ValueError("At least one of `do_train` or `do_eval` must be True.")

View File

@ -18,15 +18,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import six
import argparse import argparse
import collections import collections
import logging import logging
import json import json
import math import math
import os import os
from tqdm import tqdm, trange import six
import random import random
from tqdm import tqdm, trange
import numpy as np import numpy as np
import torch import torch