import itertools
import json
import os
import pickle
from pathlib import Path
from typing import Callable, Dict, Iterable, List

import git
import numpy as np
import torch
from rouge_score import rouge_scorer, scoring
from sacrebleu import corpus_bleu
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm

from transformers import BartTokenizer


def encode_file(
    tokenizer,
    data_path,
    max_length,
    pad_to_max_length=True,
    return_tensors="pt",
    overwrite_cache=False,
    prefix="",
    tok_name="",
):
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
    cache_path = Path(f"{data_path}_{tok_name}{max_length}.pt")
    if not overwrite_cache and cache_path.exists():
        try:
            examples = torch.load(cache_path)
            assert isinstance(examples, list)
            return examples
        except Exception:
            print(f"failed to load from {cache_path}, retokenizing {data_path}")
    data_path = Path(data_path)
    lns = lmap(str.strip, data_path.open().readlines())
    lns = [prefix + text for text in lns]
    assert lns, f"found empty file at {data_path}"
    examples = []
    for text in tqdm(lns, desc=f"Tokenizing {data_path.name}"):
        tokenized = tokenizer(
            [text],
            max_length=max_length,
            padding="max_length" if pad_to_max_length else None,
            truncation=True,
            return_tensors=return_tensors,
            **extra_kw,
        )
        if pad_to_max_length:  # without max_length padding, sequences may be shorter than max_length
            assert tokenized.input_ids.shape[1] == max_length
        examples.append(tokenized)
    torch.save(lmap(dict, examples), cache_path)  # save plain dicts so the cache stays loadable across versions
    return examples
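

# A minimal usage sketch (the model name and data path below are hypothetical,
# not part of this module):
#
#     from transformers import AutoTokenizer
#     tok = AutoTokenizer.from_pretrained("facebook/bart-large")
#     examples = encode_file(tok, "cnn_dm/train.source", max_length=1024, tok_name="bart")
#     examples[0]["input_ids"].shape  # -> torch.Size([1, 1024])
#
# Each entry holds the tensors for one line of the file; the first call also writes
# a cache file next to the data (here, cnn_dm/train.source_bart1024.pt).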


def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))


def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
    """Uses sacrebleu's corpus_bleu implementation."""
    return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}


def trim_batch(
    input_ids, pad_token_id, attention_mask=None,
):
    """Remove columns that are populated exclusively by pad_token_id"""
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
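

# A worked example: with pad_token_id=0, the all-pad trailing columns are dropped
# while ragged rows keep their interior padding.
#
#     batch = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
#     trim_batch(batch, pad_token_id=0)
#     # -> tensor([[5, 6], [7, 0]])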


class SummarizationDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        type_path="train",
        max_source_length=1024,
        max_target_length=56,
        n_obs=None,
        overwrite_cache=False,
        prefix="",
        src_lang=None,
        tgt_lang=None,
    ):
        super().__init__()
        # rstrip("tokenizer") strips any trailing character from that set, so it eats too
        # much (e.g. "barttokenizer" -> "ba"); remove the suffix explicitly instead.
        tok_name = tokenizer.__class__.__name__.lower()
        if tok_name.endswith("tokenizer"):
            tok_name = tok_name[: -len("tokenizer")]
        if hasattr(tokenizer, "set_lang") and src_lang is not None:
            tokenizer.set_lang(src_lang)  # HACK: only applies to mbart
        self.source = encode_file(
            tokenizer,
            os.path.join(data_dir, type_path + ".source"),
            max_source_length,
            overwrite_cache=overwrite_cache,
            prefix=prefix,
            tok_name=tok_name,
        )
        tgt_path = os.path.join(data_dir, type_path + ".target")
        if hasattr(tokenizer, "set_lang"):
            assert tgt_lang is not None, "--tgt_lang must be passed to build a translation"
            tokenizer.set_lang(tgt_lang)  # HACK: only applies to mbart
        self.target = encode_file(
            tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
        )
        if n_obs is not None:
            self.source = self.source[:n_obs]
            self.target = self.target[:n_obs]
        self.pad_token_id = tokenizer.pad_token_id

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()
        src_mask = self.source[index]["attention_mask"].squeeze()
        return {"input_ids": source_ids, "attention_mask": src_mask, "decoder_input_ids": target_ids}

    @staticmethod
    def trim_seq2seq_batch(batch, pad_token_id):
        y = trim_batch(batch["decoder_input_ids"], pad_token_id)
        source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
        return source_ids, source_mask, y

    def collate_fn(self, batch) -> dict:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
        pad_token_id = self.pad_token_id
        y = trim_batch(target_ids, pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
        batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
        return batch

    @property
    def src_lens(self):  # Can delete?
        return lmap(len, self.source)

    @property
    def tgt_lens(self):
        return lmap(len, self.target)

    def make_sortish_sampler(self, batch_size):
        return SortishSampler(self.source, batch_size)
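

# A minimal sketch of how this dataset plugs into a DataLoader (tokenizer `tok` and
# data_dir are hypothetical):
#
#     ds = SummarizationDataset(tok, "cnn_dm", type_path="val", n_obs=100)
#     loader = torch.utils.data.DataLoader(
#         ds, batch_size=4, collate_fn=ds.collate_fn, sampler=ds.make_sortish_sampler(4)
#     )
#     batch = next(iter(loader))  # keys: input_ids, attention_mask, decoder_input_ids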


class SortishSampler(Sampler):
    "Go through the text data by order of src length with a bit of randomness. From fastai repo."

    def __init__(self, data, batch_size):
        self.data, self.bs = data, batch_size

    def key(self, i):
        return len(self.data[i])

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        idxs = np.random.permutation(len(self.data))
        # sort by length within mega-chunks of 50 batches, so ordering is length-ish but still random
        sz = self.bs * 50
        ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        sz = self.bs
        ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
        ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.
        # np.int was deprecated and later removed from numpy; use the concrete dtype
        sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int64)
        sort_idx = np.concatenate((ck_idx[0], sort_idx))
        return iter(sort_idx)
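

# Intuition check on toy data: items of similar length end up adjacent (so batches
# waste little padding), while the chunk shuffle keeps epochs non-deterministic.
#
#     data = ["a" * n for n in [1, 5, 3, 9, 2, 7]]
#     list(SortishSampler(data, batch_size=2))  # e.g. [3, 5, 1, 2, 4, 0] -- longest item first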


def use_task_specific_params(model, task):
    """Update config with task specific params, e.g. config.task_specific_params["summarization"]."""
    task_specific_params = model.config.task_specific_params
    if task_specific_params is not None:
        model.config.update(task_specific_params.get(task, {}))
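

# For example, facebook/bart-large-cnn ships generation defaults under the
# "summarization" key of task_specific_params (keys indicative of that model's config):
#
#     from transformers import AutoModelForSeq2SeqLM
#     model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
#     use_task_specific_params(model, "summarization")
#     # config.num_beams, config.max_length, etc. now hold the summarization settings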


def pickle_load(path):
    """pickle.load(path)"""
    with open(path, "rb") as f:
        return pickle.load(f)


def pickle_save(obj, path):
    """pickle.dump(obj, path)"""
    with open(path, "wb") as f:
        return pickle.dump(obj, f)


def flatten_list(summary_ids: List[List]) -> List:
    return list(itertools.chain.from_iterable(summary_ids))


def save_git_info(folder_path: str) -> None:
    """Save git information to output_dir/git_log.json"""
    repo_infos = get_git_info()
    save_json(repo_infos, os.path.join(folder_path, "git_log.json"))


def save_json(content, path):
    with open(path, "w") as f:
        json.dump(content, f, indent=4)


def load_json(path):
    with open(path) as f:
        return json.load(f)


def get_git_info():
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }
    return repo_infos
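

# Calling save_git_info("output") from inside a checkout writes a git_log.json
# shaped roughly like this (path, sha, and branch are illustrative):
#
#     {
#         "repo_id": "<git.Repo \"/home/user/transformers/.git\">",
#         "repo_sha": "0123abc...",
#         "repo_branch": "master"
#     }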


ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]


def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict:
    scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
    aggregator = scoring.BootstrapAggregator()

    # NB: zip silently drops trailing lines if the two lists differ in length
    for reference_ln, output_ln in zip(reference_lns, output_lns):
        scores = scorer.score(reference_ln, output_ln)
        aggregator.add_scores(scores)

    result = aggregator.aggregate()
    return {k: v.mid.fmeasure for k, v in result.items()}
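

# Quick sanity check (identical strings score 1.0 on all three metrics):
#
#     calculate_rouge(["the cat sat"], ["the cat sat"])
#     # -> {"rouge1": 1.0, "rouge2": 1.0, "rougeL": 1.0}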


def freeze_params(model: nn.Module):
    for par in model.parameters():
        par.requires_grad = False


def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))


def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad / npars:.1%} of {npars} weights require grad"
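

# These helpers back the frozen-parameter tests. A typical pattern (a sketch,
# assuming a seq2seq model exposing get_input_embeddings) is to freeze the
# embedding table, then assert that it is frozen while the rest still trains:
#
#     freeze_params(model.get_input_embeddings())
#     assert_all_frozen(model.get_input_embeddings())
#     assert_not_all_frozen(model)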


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"