mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-22 14:00:33 +06:00
55 lines
2.0 KiB
Python
55 lines
2.0 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import logging
|
|
from tqdm import trange
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer
|
|
|
|
# Configure the root logger once at import time: timestamp, level, logger name
# and message on every record, with INFO as the default threshold.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)
|
|
|
|
def run_model():
    """Parse command-line options, select a device and load a BERT classifier.

    Command-line options:
        --model_name_or_path: pretrained model name or path to a local
            checkpoint (default ``bert-base-uncased``).
        --seed: integer seed applied to NumPy and torch RNGs (default 42).
        --local_rank: rank for distributed runs; ``-1`` (default) means
            non-distributed.
        --no_cuda: flag that forces CPU even when CUDA is available.

    The function seeds the RNGs, stores the chosen device on ``args.device``,
    initialises the NCCL process group when ``--local_rank`` is set (and
    ``--no_cuda`` is not), logs the resulting configuration, then loads the
    tokenizer and the sequence-classification model, moves the model to the
    device and puts it in eval mode. It returns ``None``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='bert-base-uncased', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
    args = parser.parse_args()

    # Seed every RNG the script touches so runs are repeatable.
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.local_rank == -1 or args.no_cuda:
        # Single-process run: use CUDA only if present and not disabled.
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        # Distributed run: pin this process to the GPU given by its local rank.
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    # Only the main process (rank -1 or 0) logs at INFO; other ranks log at WARN.
    # basicConfig() is a no-op once the root logger has handlers (the module
    # already configured it at import time), so the level is also set directly.
    log_level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN
    logging.basicConfig(level=log_level)
    logging.getLogger().setLevel(log_level)

    # The parser defines no --fp16 option, so reading args.fp16 directly raised
    # AttributeError; fall back to False to keep the log line intact.
    fp16 = getattr(args, "fp16", False)
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        args.device, n_gpu, bool(args.local_rank != -1), fp16))

    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    model = BertForSequenceClassification.from_pretrained(args.model_name_or_path)
    model.to(args.device)
    # Inference mode: model.eval() switches the module out of training mode.
    model.eval()
|
|
|
|
|
|
|
|
# Run the script's entry point only when executed directly, not on import.
if __name__ == '__main__':
    run_model()
|
|
|
|
|