import argparse
import warnings
from logging import getLogger
from pathlib import Path
from typing import Dict, List

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


logger = getLogger(__name__)

try:
    from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
except ImportError:
    from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def eval_data_dir(
    data_dir,
    save_dir: str,
    model_name: str,
    bs: int = 8,
    max_source_length: int = 1024,
    type_path="val",
    n_obs=None,
    fp16=False,
    save_source=False,
    num_beams: int = 4,
    task="summarization",
    local_rank=None,
    **generate_kwargs,
) -> List[Dict]:
    """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json.

    Each process joins an NCCL process group and generates predictions for its
    shard of the dataset (the sortish sampler is built with ``distributed=True``).
    It then decodes predictions and references and writes them via ``save_json``.

    Args:
        data_dir: Passed to ``Seq2SeqDataset`` together with ``type_path``
            (presumably a directory of per-split source/target files -- confirm in utils).
        save_dir: Directory the per-rank JSON file is written into. It is not created here.
        model_name: Model name or path for ``from_pretrained`` (model and tokenizer).
        bs: Batch size for the sortish sampler and the DataLoader.
        max_source_length: Source length limit passed to the dataset. The target
            length is fixed at 1024.
        type_path: Data split forwarded to the dataset (e.g. "val", "test").
        n_obs: Forwarded to the dataset. Presumably caps the number of examples
            (None = all) -- confirm in utils.
        fp16: Cast the model to half precision before generating.
        save_source: Also store the decoded input text under the "source" key.
        num_beams: Number of beams for ``model.generate``.
        task: Key passed to ``use_task_specific_params`` to update the model config.
        local_rank: Required. Used as the process-group rank, the CUDA device index
            and the output filename suffix.
        **generate_kwargs: Forwarded unchanged to ``model.generate``.

    Returns:
        This rank's records: one dict per example with "pred" and "label" keys,
        plus "source" when ``save_source`` is set.
    """
    model_name = str(model_name)
    assert local_rank is not None
    # No init_method/world_size is given, so this presumably relies on the env vars
    # (MASTER_ADDR, WORLD_SIZE, ...) set by torch.distributed.launch -- confirm.
    # NOTE(review): local_rank doubles as the global rank. That is only correct on a
    # single node; multi-node runs would need the global rank here.
    torch.distributed.init_process_group(backend="nccl", rank=local_rank)

    save_dir = Path(save_dir)
    save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
    torch.cuda.set_device(local_rank)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
    if fp16:
        model = model.half()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
    use_task_specific_params(model, task)  # update config with task specific params

    ds = Seq2SeqDataset(
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length=1024,
        type_path=type_path,
        n_obs=n_obs,
        prefix=model.config.prefix,
    )
    sampler = ds.make_sortish_sampler(bs, distributed=True)
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)

    dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)  # tokenizer.decode
    results = []
    for batch in tqdm(data_loader):
        summaries = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            num_beams=num_beams,
            **generate_kwargs,
        )
        preds = tokenizer.batch_decode(summaries, **dec_kwargs)
        labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
        # NOTE(review): zip pairs predictions with labels positionally. If
        # num_return_sequences > 1 is passed via generate_kwargs, preds is longer
        # than labels and the pairing is off (same as the previous index loop).
        if save_source:
            docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
            results.extend(
                dict(pred=pred, label=label, source=source) for pred, label, source in zip(preds, labels, docs)
            )
        else:
            results.extend(dict(pred=pred, label=label) for pred, label in zip(preds, labels))
    save_json(results, save_path)
    return results


def run_generate():
    """Parse command-line arguments and run ``eval_data_dir`` for this process.

    Arguments not declared below are collected by ``parse_known_args``. They are
    converted by ``parse_numeric_cl_kwargs`` and forwarded as generation kwargs.
    """
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    # NOTE(review): this value is used as eval_data_dir's data_dir (a dataset root
    # combined with --prefix), so the help example may be misleading -- confirm.
    parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
    parser.add_argument(
        "--model_name",
        type=str,
        help="like facebook/bart-large-cnn,t5-base, etc.",
        default="sshleifer/distilbart-xsum-12-3",
    )
    parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
    parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
    parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
    parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
    )
    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--save_source", action="store_true")
    args, rest = parser.parse_known_args()
    parsed = parse_numeric_cl_kwargs(rest)
    if parsed:
        print(f"parsed the following generate kwargs: {parsed}")
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)
    # reference_path/score_path are only consulted for this warning; nothing below scores predictions.
    if args.reference_path is None and Path(args.score_path).exists():
        warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
    # --prefix selects the split (eval_data_dir's `type_path`) and --bs the batch size (`bs`).
    # Passing them as prefix=/batch_size= would leak them into **generate_kwargs and on to
    # model.generate, so the chosen split and batch size would never take effect.
    eval_data_dir(
        args.input_path,
        args.save_dir,
        args.model_name,
        type_path=args.prefix,
        bs=args.bs,
        fp16=args.fp16,
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        save_source=args.save_source,
        **parsed,
    )


if __name__ == "__main__":
    # Usage: one process per GPU, with torch.distributed.launch supplying --local_rank, e.g.
    #   python -m torch.distributed.launch --nproc_per_node=<num_gpus> <this_script>.py \
    #       --model_name <model> --input_path <data_dir> --save_dir <out_dir> --prefix test
    run_generate()