# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Transformer-XL model evaluation script.

Adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
"""
import os
import sys
import functools
import argparse
import time
import math

import torch

from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus

def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')


def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
#                     help='location of the data corpus')
parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                    choices=['transfo-xl-wt103'], #, 'lm1b', 'enwik8', 'text8'],
                    help='pretrained model name')
parser.add_argument('--split', type=str, default='all',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
                    help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous hidden states (memory)')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
                    help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
                    help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
                    help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'

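# Example invocation (the script name, work_dir, and hyperparameter values are
# illustrative, not fixed by this file):
#   python run_transfo_xl.py --work_dir ./eval-wt103 --cuda --split test \
#       --tgt_len 128 --mem_len 1600 --clamp_len 1000 --same_length
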
device = torch.device("cuda" if args.cuda else "cpu")

# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
                     log_=not args.no_log)

# Load dataset
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab)

va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)

# Load the best saved model.
# with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
#     model = torch.load(f)
# model.backward_compatible()
model = TransfoXLModel.from_pretrained(args.model_name)
model = model.to(device)

logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

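# A note on the length settings below (following the Transformer-XL paper and
# reference code, not something defined in this file): tgt_len is the number of
# tokens predicted per segment, ext_len is extra context attended to but not
# predicted, mem_len is how many previous hidden states are cached, clamp_len
# clamps relative position indices when extrapolating to longer contexts, and
# same_length gives every token the same attention span.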
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
    model.clamp_len = args.clamp_len
if args.same_length:
    model.same_length = True

###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_loss = 0, 0.
    start_time = time.time()
    with torch.no_grad():
        # mems holds the Transformer-XL memory (cached hidden states from
        # previous segments); it starts empty and is carried across iterations.
        mems = tuple()
        for idx, (data, target, seq_len) in enumerate(eval_iter):
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            # loss is a per-token average, so weight it by seq_len to recover
            # a corpus-level average at the end.
            total_loss += seq_len * loss.item()
            total_len += seq_len
        total_time = time.time() - start_time
    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
            total_time, 1000 * total_time / (idx+1)))
    return total_loss / total_len

# Evaluate the requested split(s).
if args.split == 'all':
    test_loss = evaluate(te_iter)
    valid_loss = evaluate(va_iter)
elif args.split == 'valid':
    valid_loss = evaluate(va_iter)
    test_loss = None
elif args.split == 'test':
    test_loss = evaluate(te_iter)
    valid_loss = None

def format_log(loss, split):
    # Character-level corpora are reported in bits per character (bpc),
    # word-level corpora in perplexity (ppl).
    if args.model_name in ['enwik8', 'text8']:
        log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
            split, loss, loss / math.log(2))
    else:
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
    return log_str

log_str = ''
if valid_loss is not None:
    log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
    log_str += format_log(test_loss, 'test')

logging('=' * 100)
logging(log_str)
logging('=' * 100)