
* added rag WIP
* path fix
* Formatting / renaming prior to actual work
* added rag WIP
* path fix
* Formatting / renaming prior to actual work
* added rag WIP
* path fix
* Formatting / renaming prior to actual work
* added rag WIP
* Formatting / renaming prior to actual work
* First commit
* improve comments
* Retrieval evaluation scripts
* refactor to include modeling outputs + MPI retriever
* Fix rag-token model + refactor
* Various fixes + finetuning logic
* use_bos fix
* Retrieval refactor
* Finetuning refactoring and cleanup
* Add documentation and cleanup
* Remove set_up_rag_env.sh file
* Fix retrieval with HF index
* Fix import errors
* Fix quality errors
* Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867
* fix quality
* Fix RAG Sequence generation
* minor cleanup plus initial tests
* fix test
* fix tests 2
* Comments fix
* post-merge fixes
* Improve readme + post-rebase refactor
* Extra dependencies for tests
* Fix tests
* Fix tests 2
* Refactor test requirements
* Fix tests 3
* Post-rebase refactor
* rename nlp->datasets
* RAG integration tests
* add tokenizer to slow integration test and allow retriever to run on cpu
* add tests; fix position ids warning
* change structure
* change structure
* add from encoder generator
* save working solution
* make all integration tests pass
* add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained
* don't save paths
* delete unnecessary imports
* pass config to AutoTokenizer.from_pretrained for Rag tokenizers
* init wiki_dpr only once
* hardcode legacy index and passages paths (todo: add the right urls)
* finalize config
* finalize retriever api and config api
* LegacyIndex index download refactor
* add dpr to autotokenizer
* make from pretrained more flexible
* fix ragfortokengeneration
* small name changes in tokenizer
* add labels to models
* change default index name
* add retrieval tests
* finish token generate
* align test with previous version and make all tests pass
* add tests
* finalize tests
* implement Thom's suggestions
* add first version of test
* make first tests work
* make retriever platform agnostic
* naming
* style
* add legacy index URL
* docstrings + simple retrieval test for distributed
* clean model api
* add doc_ids to retriever's outputs
* fix retrieval tests
* finish model outputs
* finalize model api
* fix generate problem for rag
* fix generate for other models
* fix some tests
* save intermediate
* set generate to default
* big refactor generate
* delete rag_api
* correct pip faiss install
* fix auto tokenization test
* fix faiss install
* fix test
* move the distributed logic to examples
* model page
* docs
* finish tests
* fix dependencies
* fix import in __init__
* Refactor eval_rag and finetune scripts
* start docstring
* add psutil to test
* fix tf test
* move require torch to top
* fix retrieval test
* align naming
* finish automodel
* fix repo consistency
* test ragtokenizer save/load
* add rag model output docs
* fix ragtokenizer save/load from pretrained
* fix tokenizer dir
* remove torch in retrieval
* fix docs
* fix finetune scripts
* finish model docs
* finish docs
* remove auto model for now
* add require torch
* remove solved todos
* integrate Sylvain's suggestions
* Sam's comments
* correct mistake on purpose
* improve README
* Add generation test cases
* fix rag token
* clean token generate
* fix test
* add note to test
* fix attention mask
* add t5 test for rag
* Fix handling prefix in finetune.py
* don't overwrite index_name

Co-authored-by: Patrick Lewis <plewis@fb.com>
Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
188 lines
6.6 KiB
Python
import linecache
import re
import string
from collections import Counter
from logging import getLogger
from pathlib import Path
from typing import Dict, List

import torch
from torch.utils.data import Dataset

from examples.seq2seq.utils import SortishSampler, trim_batch
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer


def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
    extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
    tokenizer.padding_side = padding_side
    return tokenizer(
        [line],
        max_length=max_length,
        padding="max_length" if pad_to_max_length else None,
        truncation=True,
        return_tensors=return_tensors,
        add_special_tokens=True,
        **extra_kw,
    )
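

# Example usage (a minimal sketch, not part of the original script). The
# "facebook/bart-base" checkpoint is an illustrative choice; any tokenizer
# accepted above works the same way.
#
#   from transformers import BartTokenizer
#   tok = BartTokenizer.from_pretrained("facebook/bart-base")
#   enc = encode_line(tok, "Who wrote Hamlet?", max_length=32, padding_side="right")
#   enc["input_ids"].shape       # torch.Size([1, 32]) -- padded up to max_length
#   enc["attention_mask"].sum()  # number of real (non-padding) tokens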


class Seq2SeqDataset(Dataset):
    def __init__(
        self,
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length,
        type_path="train",
        n_obs=None,
        src_lang=None,
        tgt_lang=None,
        prefix="",
    ):
        super().__init__()
        self.src_file = Path(data_dir).joinpath(type_path + ".source")
        self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
        self.src_lens = self.get_char_lens(self.src_file)
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
        self.tokenizer = tokenizer
        self.prefix = prefix
        if n_obs is not None:
            self.src_lens = self.src_lens[:n_obs]
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        return len(self.src_lens)

    def __getitem__(self, index) -> Dict[str, torch.Tensor]:
        index = index + 1  # linecache starts at 1
        source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"

        # Need to add eos token manually for T5
        if isinstance(self.tokenizer, T5Tokenizer):
            source_line += self.tokenizer.eos_token
            tgt_line += self.tokenizer.eos_token

        # Pad source and target to the right
        source_tokenizer = (
            self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
        )
        target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer

        source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
        target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")

        source_ids = source_inputs["input_ids"].squeeze()
        target_ids = target_inputs["input_ids"].squeeze()
        src_mask = source_inputs["attention_mask"].squeeze()
        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "decoder_input_ids": target_ids,
        }

    @staticmethod
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        input_ids = torch.stack([x["input_ids"] for x in batch])
        masks = torch.stack([x["attention_mask"] for x in batch])
        target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
        tgt_pad_token_id = (
            self.tokenizer.generator.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        src_pad_token_id = (
            self.tokenizer.question_encoder.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        y = trim_batch(target_ids, tgt_pad_token_id)
        source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
        batch = {
            "input_ids": source_ids,
            "attention_mask": source_mask,
            "decoder_input_ids": y,
        }
        return batch

    def make_sortish_sampler(self, batch_size):
        return SortishSampler(self.src_lens, batch_size)
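

# Example usage (a minimal sketch): building the dataset and a DataLoader over
# a directory holding line-aligned train.source / train.target files. The
# "data/" path and batch size are hypothetical; `tok` is any tokenizer as above.
#
#   from torch.utils.data import DataLoader
#   ds = Seq2SeqDataset(tok, "data/", max_source_length=128, max_target_length=32)
#   loader = DataLoader(ds, batch_size=8, collate_fn=ds.collate_fn,
#                       sampler=ds.make_sortish_sampler(batch_size=8))
#   batch = next(iter(loader))  # keys: input_ids, attention_mask, decoder_input_ids
#   batch["input_ids"].shape    # (8, <=128): collate_fn trims all-padding columns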


logger = getLogger(__name__)


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
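

# Example (illustrative): normalization lowercases, strips punctuation and the
# articles a/an/the, and collapses whitespace, so the metrics below compare
# answers by content rather than surface form.
#
#   normalize_answer("The Cat, sat!")  # -> "cat sat"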


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
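

# Worked example (illustrative numbers): for prediction "a cat sat" and
# reference "the cat sat here", the shared normalized tokens are {"cat", "sat"},
# so precision = 2/2 = 1.0, recall = 2/3, and F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8.
#
#   f1_score("a cat sat", "the cat sat here")  # -> 0.8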


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict:
    assert len(output_lns) == len(reference_lns)
    em = 0
    for hypo, ref in zip(output_lns, reference_lns):
        em += exact_match_score(hypo, ref)
    if len(output_lns) > 0:
        em /= len(output_lns)
    return {"em": em}


def is_rag_model(model_prefix):
    return model_prefix.startswith("rag")


def set_extra_model_params(extra_params, hparams, config):
    equivalent_param = {p: p for p in extra_params}
    # T5 models don't have a `dropout` param, they have `dropout_rate` instead
    equivalent_param["dropout"] = "dropout_rate"
    for p in extra_params:
        if getattr(hparams, p, None):
            if not hasattr(config, p) and not hasattr(config, equivalent_param[p]):
                logger.info("config doesn't have a `{}` attribute".format(p))
                delattr(hparams, p)
                continue
            set_p = p if hasattr(config, p) else equivalent_param[p]
            setattr(config, set_p, getattr(hparams, p))
            delattr(hparams, p)
    return hparams, config
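

# Example usage (a minimal sketch with hypothetical hyperparameters): copy
# attributes from an argparse-style namespace onto a model config, mapping
# `dropout` to T5's `dropout_rate`. Whether a given config exposes `dropout`
# directly is checked at runtime, as above.
#
#   from argparse import Namespace
#   from transformers import T5Config
#   hparams = Namespace(dropout=0.1, attention_dropout=None)
#   config = T5Config()
#   hparams, config = set_extra_model_params(("dropout", "attention_dropout"), hparams, config)
#   config.dropout_rate  # -> 0.1; `dropout` has been removed from hparams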