mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
clean up
This commit is contained in:
parent
834b485b2e
commit
6b0da96b4b
@ -69,7 +69,7 @@ class InputFeatures(object):
|
|||||||
self.input_mask = input_mask
|
self.input_mask = input_mask
|
||||||
self.segment_ids = segment_ids
|
self.segment_ids = segment_ids
|
||||||
self.label_id = label_id
|
self.label_id = label_id
|
||||||
|
|
||||||
|
|
||||||
class DataProcessor(object):
|
class DataProcessor(object):
|
||||||
"""Base class for data converters for sequence classification data sets."""
|
"""Base class for data converters for sequence classification data sets."""
|
||||||
@ -95,8 +95,8 @@ class DataProcessor(object):
|
|||||||
for line in reader:
|
for line in reader:
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
|
|
||||||
class MrpcProcessor(DataProcessor):
|
class MrpcProcessor(DataProcessor):
|
||||||
"""Processor for the MRPC data set (GLUE version)."""
|
"""Processor for the MRPC data set (GLUE version)."""
|
||||||
|
|
||||||
@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor):
|
|||||||
examples.append(
|
examples.append(
|
||||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
||||||
tokenizer):
|
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
label_map = {}
|
label_map = {}
|
||||||
@ -380,7 +379,7 @@ def main():
|
|||||||
parser.add_argument("--do_lower_case",
|
parser.add_argument("--do_lower_case",
|
||||||
default=False,
|
default=False,
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="Whether to lower case the input text. Should be True for uncased models and False for cased models.")
|
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||||
parser.add_argument("--max_seq_length",
|
parser.add_argument("--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
@ -424,6 +423,10 @@ def main():
|
|||||||
default=False,
|
default=False,
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help="Whether not to use CUDA when available")
|
help="Whether not to use CUDA when available")
|
||||||
|
parser.add_argument("--accumulate_gradients",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of steps to accumulate gradient on (divide the single step batch_size)")
|
||||||
parser.add_argument("--local_rank",
|
parser.add_argument("--local_rank",
|
||||||
type=int,
|
type=int,
|
||||||
default=-1,
|
default=-1,
|
||||||
@ -448,12 +451,12 @@ def main():
|
|||||||
n_gpu = 1
|
n_gpu = 1
|
||||||
# print("Initializing the distributed backend: NCCL")
|
# print("Initializing the distributed backend: NCCL")
|
||||||
print("device", device, "n_gpu", n_gpu)
|
print("device", device, "n_gpu", n_gpu)
|
||||||
|
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
|
if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
if not args.do_train and not args.do_eval:
|
if not args.do_train and not args.do_eval:
|
||||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||||
|
|
||||||
|
@ -18,15 +18,15 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import six
|
|
||||||
import argparse
|
import argparse
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm, trange
|
import six
|
||||||
import random
|
import random
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
Loading…
Reference in New Issue
Block a user