# coding=utf-8 # Copyright 2019 The HuggingFace Inc. team. # Copyright (c) 2019 The HuggingFace Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Finetuning seq2seq models for sequence generation.""" import argparse import functools import logging import os import random import sys import numpy as np from tqdm import tqdm, trange import torch from torch.optim import Adam from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers import ( AutoTokenizer, BertForMaskedLM, BertConfig, PreTrainedEncoderDecoder, Model2Model, ) from utils_summarization import ( CNNDailyMailDataset, encode_for_summarization, fit_to_block_size, build_lm_labels, build_mask, compute_token_type_ids, ) logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, level=logging.INFO) def set_seed(args): random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) # ------------ # Load dataset # ------------ def load_and_cache_examples(args, tokenizer): dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir) return dataset def collate(data, tokenizer, block_size): """ List of tuple as an input. """ # remove the files with empty an story/summary, encode and fit to block data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data) data = [ encode_for_summarization(story, summary, tokenizer) for story, summary in data ] data = [ ( fit_to_block_size(story, block_size, tokenizer.pad_token_id), fit_to_block_size(summary, block_size, tokenizer.pad_token_id), ) for story, summary in data ] stories = torch.tensor([story for story, summary in data]) summaries = torch.tensor([summary for story, summary in data]) encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id) encoder_mask = build_mask(stories, tokenizer.pad_token_id) decoder_mask = build_mask(summaries, tokenizer.pad_token_id) lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id) return ( stories, summaries, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels, ) # ---------- # Optimizers # ---------- class BertSumOptimizer(object): """ Specific optimizer for BertSum. As described in [1], the authors fine-tune BertSum for abstractive summarization using two Adam Optimizers with different warm-up steps and learning rate. They also use a custom learning rate scheduler. [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." arXiv preprint arXiv:1908.08345 (2019). """ def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8): self.encoder = model.encoder self.decoder = model.decoder self.lr = lr self.warmup_steps = warmup_steps self.optimizers = { "encoder": Adam( model.encoder.parameters(), lr=lr["encoder"], betas=(beta_1, beta_2), eps=eps, ), "decoder": Adam( model.decoder.parameters(), lr=lr["decoder"], betas=(beta_1, beta_2), eps=eps, ), } self._step = 0 def _update_rate(self, stack): return self.lr[stack] * min( self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-0.5) ) def zero_grad(self): self.optimizer_decoder.zero_grad() self.optimizer_encoder.zero_grad() def step(self): self._step += 1 for stack, optimizer in self.optimizers.items(): new_rate = self._update_rate(stack) for param_group in optimizer.param_groups: param_group["lr"] = new_rate optimizer.step() # ------------ # Train # ------------ def train(args, model, tokenizer): """ Fine-tune the pretrained model on the corpus. """ set_seed(args) # Load the data args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) train_dataset = load_and_cache_examples(args, tokenizer) train_sampler = RandomSampler(train_dataset) model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512) train_dataloader = DataLoader( train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=model_collate_fn, ) # Training schedule if args.max_steps > 0: t_total = args.max_steps args.num_train_epochs = t_total // ( len(train_dataloader) // args.gradient_accumulation_steps + 1 ) else: t_total = ( len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs ) # Prepare the optimizer lr = {"encoder": 0.002, "decoder": 0.2} warmup_steps = {"encoder": 20000, "decoder": 10000} optimizer = BertSumOptimizer(model, lr, warmup_steps) # Train logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info( " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size ) logger.info( " Total train batch size (w. parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps # * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) model.zero_grad() train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True) global_step = 0 tr_loss = 0.0 for _ in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) for step, batch in enumerate(epoch_iterator): source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch source = source.to(args.device) target = target.to(args.device) encoder_token_type_ids = encoder_token_type_ids.to(args.device) encoder_mask = encoder_mask.to(args.device) decoder_mask = decoder_mask.to(args.device) lm_labels = lm_labels.to(args.device) model.train() outputs = model( source, target, encoder_token_type_ids=encoder_token_type_ids, encoder_attention_mask=encoder_mask, decoder_attention_mask=decoder_mask, decoder_lm_labels=lm_labels, ) loss = outputs[0] print(loss) if args.gradient_accumulation_steps > 1: loss /= args.gradient_accumulation_steps loss.backward() tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() model.zero_grad() global_step += 1 if args.max_steps > 0 and global_step > args.max_steps: epoch_iterator.close() break if args.max_steps > 0 and global_step > args.max_steps: train_iterator.close() break return global_step, tr_loss / global_step # ------------ # Train # ------------ def evaluate(args, model, tokenizer, prefix=""): set_seed(args) args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True) eval_sampler = SequentialSampler(eval_dataset) eval_dataloader = DataLoader( eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size ) # multi-gpu evaluate if args.n_gpu > 1: model = torch.nn.DataParallel(model) logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) logger.info(" Batch size = %d", args.eval_batch_size) eval_loss = 0.0 nb_eval_steps = 0 model.eval() for batch in tqdm(eval_dataloader, desc="Evaluating"): source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch source = source.to(args.device) target = target.to(args.device) encoder_token_type_ids = encoder_token_type_ids.to(args.device) encoder_mask = encoder_mask.to(args.device) decoder_mask = decoder_mask.to(args.device) lm_labels = lm_labels.to(args.device) with torch.no_grad(): outputs = model( source, target, encoder_token_type_ids=encoder_token_type_ids, encoder_attention_mask=encoder_mask, decoder_attention_mask=decoder_mask, decoder_lm_labels=lm_labels, ) lm_loss = outputs[0] eval_loss += lm_loss.mean().item() nb_eval_steps += 1 eval_loss = eval_loss / nb_eval_steps perplexity = torch.exp(torch.tensor(eval_loss)) result = {"perplexity": perplexity} # Save the evaluation's results output_eval_file = os.path.join(args.output_dir, "eval_results.txt") if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) with open(output_eval_file, "w") as writer: logger.info("***** Eval results {} *****".format(prefix)) for key in sorted(result.keys()): logger.info(" %s = %s", key, str(result[key])) writer.write("%s = %s\n" % (key, str(result[key]))) return result def main(): parser = argparse.ArgumentParser() # Required parameters parser.add_argument( "--data_dir", default=None, type=str, required=True, help="The input training data file (a text file).", ) parser.add_argument( "--output_dir", default=None, type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.", ) # Optional parameters parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--do_evaluate", type=bool, default=False, help="Run model evaluation on out-of-sample data.", ) parser.add_argument("--do_train", type=bool, default=False, help="Run training.") parser.add_argument( "--do_overwrite_output_dir", type=bool, default=False, help="Whether to overwrite the output dir.", ) parser.add_argument( "--model_name_or_path", default="bert-base-cased", type=str, help="The model checkpoint to initialize the encoder and decoder's weights with.", ) parser.add_argument( "--model_type", default="bert", type=str, help="The decoder architecture to be fine-tuned.", ) parser.add_argument( "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." ) parser.add_argument( "--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.", ) parser.add_argument( "--to_cpu", default=False, type=bool, help="Whether to force training on CPU." ) parser.add_argument( "--num_train_epochs", default=10, type=int, help="Total number of training epochs to perform.", ) parser.add_argument( "--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.", ) parser.add_argument("--seed", default=42, type=int) args = parser.parse_args() if ( os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.do_overwrite_output_dir ): raise ValueError( "Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format( args.output_dir ) ) # Set up training device if args.to_cpu or not torch.cuda.is_available(): args.device = torch.device("cpu") args.n_gpu = 0 else: args.device = torch.device("cuda") args.n_gpu = torch.cuda.device_count() # Load pretrained model and tokenizer. The decoder's weights are randomly initialized. tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) config = BertConfig.from_pretrained(args.model_name_or_path) decoder_model = BertForMaskedLM(config) model = Model2Model.from_pretrained( args.model_name_or_path, decoder_model=decoder_model ) # Setup logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.warning( "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 0, args.device, args.n_gpu, False, False, ) logger.info("Training/evaluation parameters %s", args) # Train the model model.to(args.device) if args.do_train: global_step, tr_loss = train(args, model, tokenizer) logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) logger.info("Saving model checkpoint to %s", args.output_dir) # Save a trained model, configuration and tokenizer using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` model_to_save = ( model.module if hasattr(model, "module") else model ) # Take care of distributed/parallel training model_to_save.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir) torch.save(args, os.path.join(args.output_dir, "training_arguments.bin")) # Evaluate the model results = {} if args.do_evaluate: checkpoints = [] logger.info("Evaluate the following checkpoints: %s", checkpoints) for checkpoint in checkpoints: encoder_checkpoint = os.path.join(checkpoint, "encoder") decoder_checkpoint = os.path.join(checkpoint, "decoder") model = PreTrainedEncoderDecoder.from_pretrained( encoder_checkpoint, decoder_checkpoint ) model.to(args.device) results = "placeholder" return results if __name__ == "__main__": main()