import itertools
import json
import linecache
import os
import pickle
import warnings
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List

import git
import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler

from transformers import BartTokenizer


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Compute a label-smoothed negative log-likelihood loss (adapted from fairseq).

    Args:
        lprobs: tensor of log-probabilities; its last dimension is the one indexed by ``target``.
        target: tensor of target indices, either with the same number of dimensions as ``lprobs``
            or with one fewer (a trailing dimension is then added).
        epsilon: smoothing weight; ``epsilon / lprobs.size(-1)`` is applied to the smoothing term.
        ignore_index: target value whose positions are excluded from the loss. Pass ``None`` to
            keep every position.

    Returns:
        A ``(loss, nll_loss)`` tuple. With ``ignore_index`` set, both are divided by the number
        of non-ignored positions; otherwise by ``lprobs.shape[0]``.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)

    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        # ``gather`` rejects out-of-range indices such as -100, so point ignored positions at
        # index 0; their contribution is zeroed out below anyway.
        gather_index = target.masked_fill(pad_mask, 0)
    else:
        pad_mask = None
        gather_index = target

    nll_loss = -lprobs.gather(dim=-1, index=gather_index)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)

    if pad_mask is not None:
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        # Normalise by the positions that actually contribute (not by the ignored ones), and
        # never divide by zero when every position is ignored.
        bs = target.ne(ignore_index).long().sum().clamp(min=1)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        bs = lprobs.shape[0]

    nll_loss = nll_loss.sum()  # mean()? Scared to break other math.
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss / bs, nll_loss / bs


def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
    """Tokenize a single line as a batch of one, truncated to ``max_length``.

    ``add_prefix_space=True`` is passed only when ``tokenizer`` is a ``BartTokenizer``.
    When ``pad_to_max_length`` is false, padding is disabled explicitly with ``False``
    (``None`` is not one of the tokenizer's documented padding values).
    """
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else False,
        truncation=True,
        return_tensors=return_tensors,
        **extra_kw,
    )


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation.

    ``refs_lns`` is passed as a single reference stream; extra keyword arguments are forwarded
    to ``corpus_bleu``. Returns ``{"bleu": <score>}``.
    """
    return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}


def trim_batch(
    input_ids,
    pad_token_id,
    attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id

    Returns the trimmed ``input_ids`` alone, or an ``(input_ids, attention_mask)`` tuple with the
    same columns kept when ``attention_mask`` is given.
    """
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])


class Seq2SeqDataset(Dataset):
    """Line-aligned source/target text dataset.

    Reads ``<type_path>.source`` and ``<type_path>.target`` from ``data_dir``. Lines are fetched
    lazily through ``linecache`` and tokenized one example at a time in ``__getitem__``.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # NOTE(review): the lengths include each line's trailing newline, so a blank line that
        # ends with "\n" has length 1 and is not caught by this check.
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix
        if n_obs is not None:
            # Keep only the first n_obs examples.
            self.src_lens = self.src_lens[:n_obs]
        self.pad_token_id = self.tokenizer.pad_token_id
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        return len(self.src_lens)

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        """Return the tokenized source ids, source attention mask and target ids for ``index``."""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
        target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "decoder_input_ids": target_ids,
        }

    @staticmethod
    def get_char_lens(data_file):
        """Return the character length of every line in ``data_file`` (newline included)."""
        # Use a context manager so the file handle is closed deterministically.
        with Path(data_file).open() as f:
            return [len(x) for x in f]

    @staticmethod
    def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
        """Trim all-padding columns and return ``(source_ids, source_mask, target_ids)``."""
        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
        return source_ids, source_mask, y

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Stack the per-example tensors and drop columns that are padding in every row."""
        stacked = {
            key: torch.stack([x[key] for x in batch])
            for key in ("input_ids", "attention_mask", "decoder_input_ids")
        }
        # Reuse the shared trimming helper rather than duplicating its logic here.
        source_ids, source_mask, y = self.trim_seq2seq_batch(stacked, self.pad_token_id)
        return {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "decoder_input_ids": y,
        }

    def make_sortish_sampler(self, batch_size):
        """Build a ``SortishSampler`` keyed on the source line lengths."""
        return SortishSampler(self.src_lens, batch_size)


class MBartDataset(Seq2SeqDataset):
    """Variant that yields raw text and defers tokenization to ``collate_fn``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.max_source_length != self.max_target_length:
            warnings.warn(
                f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
            )

    def __getitem__(self, index) -> Dict[str, str]:
        """Return the untokenized source (with prefix) and target lines for ``index``."""
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {
            "tgt_texts": tgt_line,
            "src_texts": source_line,
        }

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Tokenize the whole batch with the tokenizer's ``prepare_translation_batch``."""
        batch_encoding = self.tokenizer.prepare_translation_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
            tgt_lang=self.tgt_lang,
            max_length=self.max_source_length,
        )
        return batch_encoding.data


class SortishSampler(Sampler):
    """Go through the text data by order of src length with a bit of randomness. From fastai repo."""

    def __init__(self, data, batch_size):
        self.data, self.bs = data, batch_size

    def key(self, i):
        return self.data[i]

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        if len(self.data) == 0:
            # Nothing to sample; np.concatenate / np.argmax below would raise on empty input.
            return iter([])

        idxs = np.random.permutation(len(self.data))
        # Sort by key (descending) inside large shuffled chunks of 50 batches each.
        sz = self.bs * 50
        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        # Re-split into batch-sized chunks.
        sz = self.bs
        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.

        rest = ck_idx[1:]
        if rest:
            # Shuffle chunk positions rather than the chunk list itself: the final chunk may be
            # shorter than the others, and a ragged list cannot be converted to an ndarray.
            order = np.random.permutation(len(rest))
            tail = np.concatenate([rest[i] for i in order])
        else:
            # ``np.int`` was removed from NumPy; the builtin ``int`` is the supported spelling.
            tail = np.array([], dtype=int)
        sort_idx = np.concatenate((ck_idx[0], tail))
        return iter(sort_idx)


logger = getLogger(__name__)


def use_task_specific_params(model, task):
    """Update the model config with the parameters registered for ``task``, if any."""
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        pars = task_specific_params.get(task, {})
        logger.info(f"using task specific params for {task}: {pars}")
        model.config.update(pars)


def pickle_load(path):
    """pickle.load(path)

    NOTE: unpickling can execute arbitrary code; only load files from trusted sources.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]):
    """Flatten one level of nesting into a single list."""
    return list(itertools.chain.from_iterable(summary_ids))


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path):
    """Write ``content`` to ``path`` as JSON indented by 4 spaces."""
    with open(path, "w") as f:
        json.dump(content, f, indent=4)


def load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path) as f:
        return json.load(f)


def get_git_info():
    """Return the id, head commit sha and active branch of the enclosing git repository."""
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }
    return repo_infos


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    """Score outputs against references pairwise and return the aggregated mid F-measure per ROUGE key."""
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: v.mid.fmeasure for k, v in result.items()}


def freeze_params(model: nn.Module):
    """Set ``requires_grad = False`` on every parameter of ``model``."""
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    """Lazily yield ``requires_grad`` for each parameter of ``model``."""
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    """Return True if at least one parameter of ``model`` requires grad."""
    return any(grad_status(model))


def assert_all_frozen(model):
    """Raise AssertionError if any parameter of ``model`` requires grad."""
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    """Raise AssertionError if no parameter of ``model`` requires grad."""
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"