RAG-2nd2end-revamp (#11893)

* initial

* code quality test

* code quality

* added test functions in test_modeling_rag.py and test_retrieval_rag.py to test the end2end retriever

* minor change in test_modeling_rag

* fixed tests

* Update examples/research_projects/rag-end2end-retriever/README.md

typo corrected as suggested by lhoestq

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>

* Update examples/research_projects/rag-end2end-retriever/finetune_rag.py

type change suggested by lhoestq

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>

* Update src/transformers/models/rag/retrieval_rag.py

Adding this change as mentioned by lhoestq.

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>

* completed the minor changes suggested by the reviewers

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Shamane Siri 2021-06-01 18:32:26 +12:00 committed by GitHub
parent ad25fd62bd
commit 9ec0f01b6c
22 changed files with 2808 additions and 25 deletions

View File

@@ -0,0 +1,47 @@
# End-to-End finetuning of RAG (including DPR retriever) for Question Answering.
This finetuning script is actively maintained by [Shamane Siri](https://github.com/shamanez). Feel free to ask questions on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose) and tag @shamanez.
Others that helped out: Patrick von Platen (@patrickvonplaten), Quentin Lhoest (@lhoestq), and Rivindu Weerasekera (@rivinduw)
The original RAG implementation is able to train the question encoder and generator end-to-end.
This extension enables complete end-to-end training of RAG including the context encoder in the retriever component.
Please read the [accompanying blog post](https://shamanesiri.medium.com/how-to-finetune-the-entire-rag-architecture-including-dpr-retriever-4b4385322552) for details on this implementation.
The original RAG code has also been modified to work with the latest versions of PyTorch Lightning (version 1.2.10) and Ray (version 1.3.0). All other implementation details remain the same as the [original RAG code](https://github.com/huggingface/transformers/tree/master/examples/research_projects/rag).
Read more about RAG at https://arxiv.org/abs/2005.11401.
This code can be modified to experiment with other research on retrieval-augmented models that include training of the retriever (e.g. [REALM](https://arxiv.org/abs/2002.08909) and [MARGE](https://arxiv.org/abs/2006.15020)).
To start training, use the bash script (finetune_rag_ray_end2end.sh) in this folder. It also includes descriptions of each command-line argument used.
# Testing
The following two bash scripts can be used to quickly test the implementation.
1. sh ./test_run/test_rag_new_features.sh
- Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to RAG modeling.
- This is sufficient to check that the model can use the new setter functions correctly (see the sketch after this list).
2. sh ./test_run/test_finetune.sh
- Tests the full end-to-end fine-tuning ability with a dummy knowledge-base and dummy training dataset (check the test_dir directory).
- Users can replace the dummy dataset and knowledge-base with their own to run their own fine-tuning.
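The sketch below is not part of the test scripts; it is a minimal illustration of how the new setters are wired together, mirroring the calls made in finetune_rag.py in this project (the method names and checkpoints are the ones that script uses).

```python
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    RagRetriever,
    RagTokenForGeneration,
)

# Base RAG model with a dummy index, just to exercise the new setter functions.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)

# Attach a trainable DPR context encoder and its tokenizer (the end2end additions).
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
retriever.set_ctx_encoder_tokenizer(ctx_tokenizer)
model.set_context_encoder_for_training(ctx_encoder)
```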
# Comparison of end2end RAG (including DPR finetuning) VS original-RAG
We conducted a simple experiment to investigate the effectiveness of this end2end training extension using the SQuAD dataset. Please execute the following steps to reproduce the results.
- Create a knowledge-base using all the context passages in the SQuAD dataset with their respective titles (a sketch of one way to do this follows this list).
- Use the question-answer pairs as training data.
- Train the system for 10 epochs.
- Test the Exact Match (EM) score with the SQuAD dataset's validation set.
- The training dataset, the knowledge-base, and the hyperparameters used in the experiments can be accessed from [here](https://drive.google.com/drive/folders/1qyzV-PaEARWvaU_jjpnU_NUS3U_dSjtG?usp=sharing).
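One way to build such a knowledge-base csv is sketched below. This is only an illustration, assuming the tab-separated title/text layout that use_own_knowledge_dataset.py and kb_encode_utils.py read; the exact files used for the reported numbers are in the linked folder.

```python
from datasets import load_dataset

squad = load_dataset("squad", split="train")

seen = set()
with open("squad-kb.csv", "w", encoding="utf-8") as f:
    for example in squad:
        title, context = example["title"], example["context"]
        if (title, context) in seen:  # each passage appears once per question; keep it only once
            continue
        seen.add((title, context))
        # one tab-separated "title<TAB>text" line per passage
        f.write(title + "\t" + context.replace("\t", " ").replace("\n", " ") + "\n")
```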
# Results
- We train both models for 10 epochs.
| Model Type | EM-Score|
| --------------------| --------|
| RAG-original | 28.12 |
| RAG-end2end with DPR| 40.02 |
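For reference, these EM numbers are computed with eval_rag.py (added in this project). The snippet below is a condensed, simplified sketch of its qa-mode scoring: the real script normalizes answers (articles, punctuation, case) via utils_rag.exact_match_score, for which a plain lower-cased comparison stands in here, and the file names are placeholders.

```python
import ast

import pandas as pd


def exact_match(prediction, ground_truth):
    # simplified stand-in for utils_rag.exact_match_score
    return prediction.strip().lower() == ground_truth.strip().lower()


hypos = [line.strip() for line in open("predictions.txt")]
gold = pd.read_csv("gold.tsv", sep="\t", header=None)  # each line: question [tab] answer_list
answers = [ast.literal_eval(answer_list) for answer_list in gold[1]]

# a prediction counts as correct if it matches any of the acceptable answers
em = sum(max(exact_match(pred, gt) for gt in gts) for pred, gts in zip(hypos, answers))
print(f"EM: {100.0 * em / len(answers):.2f}")
```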

View File

@@ -0,0 +1,119 @@
import logging
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from utils_rag import save_json
def count_trainable_parameters(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
return params
logger = logging.getLogger(__name__)
def get_checkpoint_callback(output_dir, metric):
"""Saves the best model by validation EM score."""
if metric == "rouge2":
exp = "{val_avg_rouge2:.4f}-{step_count}"
elif metric == "bleu":
exp = "{val_avg_bleu:.4f}-{step_count}"
elif metric == "em":
exp = "{val_avg_em:.4f}-{step_count}"
elif metric == "loss":
exp = "{val_avg_loss:.4f}-{step_count}"
else:
raise NotImplementedError(
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
)
checkpoint_callback = ModelCheckpoint(
dirpath=output_dir,
filename=exp,
monitor=f"val_{metric}",
mode="max",
save_top_k=1,
every_n_val_epochs=1, # works only with PL > 1.3
)
return checkpoint_callback
def get_early_stopping_callback(metric, patience):
return EarlyStopping(
monitor=f"val_{metric}", # does this need avg?
mode="min" if "loss" in metric else "max",
patience=patience,
verbose=True,
)
class Seq2SeqLoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
pl_module.logger.log_metrics(lrs)
@rank_zero_only
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
) -> None:
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
metrics = trainer.callback_metrics
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
# Log results
od = Path(pl_module.hparams.output_dir)
if type_path == "test":
results_file = od / "test_results.txt"
generations_file = od / "test_generations.txt"
else:
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
# If people want this it will be easy enough to add back.
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
results_file.parent.mkdir(exist_ok=True)
generations_file.parent.mkdir(exist_ok=True)
with open(results_file, "a+") as writer:
for key in sorted(metrics):
if key in ["log", "progress_bar", "preds"]:
continue
val = metrics[key]
if isinstance(val, torch.Tensor):
val = val.item()
msg = f"{key}: {val:.6f}\n"
writer.write(msg)
if not save_generations:
return
if "preds" in metrics:
content = "\n".join(metrics["preds"])
generations_file.open("w+").write(content)
@rank_zero_only
def on_train_start(self, trainer, pl_module):
try:
npars = pl_module.model.model.num_parameters()
except AttributeError:
npars = pl_module.model.num_parameters()
n_trainable_pars = count_trainable_parameters(pl_module)
# mp stands for million parameters
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
save_json(pl_module.metrics, pl_module.metrics_save_path)
return self._write_logs(trainer, pl_module, "test")
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module):
save_json(pl_module.metrics, pl_module.metrics_save_path)
# Uncommenting this will save val generations
# return self._write_logs(trainer, pl_module, "valid")

View File

@@ -0,0 +1,185 @@
import logging
import random
import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.models.rag.retrieval_rag import CustomHFIndex
logger = logging.getLogger(__name__)
class RayRetriever:
def __init__(self):
self.initialized = False
def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
if not self.initialized:
self.retriever = RagRetriever(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.initialized = True
def init_retrieval(self):
self.retriever.index.init_index()
def clear_object(self):
# delete the old self.retriever object before assigning the new index
del self.retriever
self.initialized = False
def retrieve(self, question_hidden_states, n_docs):
doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
doc_dicts = self.retriever.index.get_doc_dicts(doc_ids)
return doc_ids, retrieved_doc_embeds, doc_dicts
class RagRayDistributedRetriever(RagRetriever):
"""
A distributed retriever built on top of the ``Ray`` API
(https://docs.ray.io/en/master/), a library for building distributed
applications. During training, all training workers initialize their own
instance of a `RagRayDistributedRetriever`, and each instance of
this distributed retriever shares a common set of Retrieval Ray
Actors (https://docs.ray.io/en/master/walkthrough.html#remote-classes-actors)
that load the index on separate processes. Ray
handles the communication between the `RagRayDistributedRetriever`
instances and the remote Ray actors. If training is done in a
non-distributed setup, the index will simply be loaded in the same
process as the training worker and Ray will not be used.
Args:
config (:class:`~transformers.RagConfig`):
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer that was used to tokenize the question.
It is used to decode the question and then use the generator_tokenizer.
generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer used for the generator part of the RagModel.
retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
These actor classes run on remote processes and are responsible for performing the index lookup.
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
If specified, use this index instead of the one built using the configuration
"""
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
raise ValueError(
"When using Ray for distributed fine-tuning, "
"you'll need to provide the paths instead, "
"as the dataset and the index are loaded "
"separately. More info in examples/rag/use_own_knowledge_dataset.py "
)
super().__init__(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
init_retrieval=False,
)
self.retrieval_workers = retrieval_workers
self.question_encoder_tokenizer = question_encoder_tokenizer
self.generator_tokenizer = generator_tokenizer
if len(self.retrieval_workers) > 0:
ray.get(
[
worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
for worker in self.retrieval_workers
]
)
def init_retrieval(self):
"""
Retriever initialization function, needs to be called from the
training process. This function triggers retrieval initialization
for all retrieval actors if using distributed setting, or loads
index into current process if training is not distributed.
"""
logger.info("initializing retrieval")
if len(self.retrieval_workers) > 0:
ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
else:
# Non-distributed training. Load index into this same process.
self.index.init_index()
def retrieve(self, question_hidden_states, n_docs):
"""
Retrieves documents for specified ``question_hidden_states``. If
running training with multiple workers, a random retrieval actor is
selected to perform the index lookup and return the result.
Args:
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
A batch of query vectors to retrieve with.
n_docs (:obj:`int`):
The number of docs retrieved per query.
Output:
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
The retrieval embeddings of the retrieved docs per query.
doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
The ids of the documents in the index.
doc_dicts (:obj:`List[dict]`):
The retrieved_doc_embeds examples per query.
"""
if len(self.retrieval_workers) > 0:
# Select a random retrieval actor.
random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
doc_ids, retrieved_doc_embeds, doc_dicts = ray.get(
random_worker.retrieve.remote(question_hidden_states, n_docs)
)
else:
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
doc_dicts = self.index.get_doc_dicts(doc_ids)
return retrieved_doc_embeds, doc_ids, doc_dicts
@classmethod
def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)
@classmethod
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
question_encoder_tokenizer = rag_tokenizer.question_encoder
generator_tokenizer = rag_tokenizer.generator
if indexed_dataset is not None:
config.index_name = "custom"
index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
else:
index = cls._build_index(config)
return cls(
config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
retrieval_workers=actor_handles,
index=index,
)
def re_load(self):
logger.info("re-loading the new dataset with embeddings")
# access from the training loop
ray.get([worker.clear_object.remote() for worker in self.retrieval_workers])
# build the index object again
index = self._build_index(self.config)
ray.get(
[
worker.create_rag_retriever.remote(
self.config, self.question_encoder_tokenizer, self.generator_tokenizer, index
)
for worker in self.retrieval_workers
]
)

View File

@@ -0,0 +1,312 @@
""" Evaluation script for RAG models."""
import argparse
import ast
import logging
import os
import sys
import pandas as pd
import torch
from tqdm import tqdm
from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration
from transformers import logging as transformers_logging
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
from utils_rag import exact_match_score, f1_score # noqa: E402 # isort:skip
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
transformers_logging.set_verbosity_info()
def infer_model_type(model_name_or_path):
if "token" in model_name_or_path:
return "rag_token"
if "sequence" in model_name_or_path:
return "rag_sequence"
if "bart" in model_name_or_path:
return "bart"
return None
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(metric_fn(prediction, gt) for gt in ground_truths)
def get_scores(args, preds_path, gold_data_path):
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
answers = []
if args.gold_data_mode == "qa":
data = pd.read_csv(gold_data_path, sep="\t", header=None)
for answer_list in data[1]:
ground_truths = ast.literal_eval(answer_list)
answers.append(ground_truths)
else:
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
answers = [[reference] for reference in references]
f1 = em = total = 0
for prediction, ground_truths in zip(hypos, answers):
total += 1
em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
em = 100.0 * em / total
f1 = 100.0 * f1 / total
logger.info(f"F1: {f1:.2f}")
logger.info(f"EM: {em:.2f}")
def get_precision_at_k(args, preds_path, gold_data_path):
k = args.k
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
em = total = 0
for hypo, reference in zip(hypos, references):
hypo_provenance = set(hypo.split("\t")[:k])
ref_provenance = set(reference.split("\t"))
total += 1
em += len(hypo_provenance & ref_provenance) / k
em = 100.0 * em / total
logger.info(f"Precision@{k}: {em: .2f}")
def evaluate_batch_retrieval(args, rag_model, questions):
def strip_title(title):
if title.startswith('"'):
title = title[1:]
if title.endswith('"'):
title = title[:-1]
return title
retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
)["input_ids"].to(args.device)
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
question_enc_pool_output = question_enc_outputs[0]
result = rag_model.retriever(
retriever_input_ids,
question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
prefix=rag_model.rag.generator.config.prefix,
n_docs=rag_model.config.n_docs,
return_tensors="pt",
)
all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
provenance_strings = []
for docs in all_docs:
provenance = [strip_title(title) for title in docs["title"]]
provenance_strings.append("\t".join(provenance))
return provenance_strings
def evaluate_batch_e2e(args, rag_model, questions):
with torch.no_grad():
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
questions, return_tensors="pt", padding=True, truncation=True
)
input_ids = inputs_dict.input_ids.to(args.device)
attention_mask = inputs_dict.attention_mask.to(args.device)
outputs = rag_model.generate( # rag_model overwrites generate
input_ids,
attention_mask=attention_mask,
num_beams=args.num_beams,
min_length=args.min_length,
max_length=args.max_length,
early_stopping=False,
num_return_sequences=1,
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, don't allow it to generate more than one
)
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
if args.print_predictions:
for q, a in zip(questions, answers):
logger.info("Q: {} - A: {}".format(q, a))
return answers
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_type",
choices=["rag_sequence", "rag_token", "bart"],
type=str,
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
)
parser.add_argument(
"--index_name",
default=None,
choices=["exact", "compressed", "legacy"],
type=str,
help="RAG model retriever type",
)
parser.add_argument(
"--index_path",
default=None,
type=str,
help="Path to the retrieval index",
)
parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
)
parser.add_argument(
"--eval_mode",
choices=["e2e", "retrieval"],
default="e2e",
type=str,
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
)
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
parser.add_argument(
"--evaluation_set",
default=None,
type=str,
required=True,
help="Path to a file containing evaluation samples",
)
parser.add_argument(
"--gold_data_path",
default=None,
type=str,
required=True,
help="Path to a tab-separated file with gold samples",
)
parser.add_argument(
"--gold_data_mode",
default="qa",
type=str,
choices=["qa", "ans"],
help="Format of the gold data file"
"qa - a single line in the following format: question [tab] answer_list"
"ans - a single line of the gold file contains the expected answer string",
)
parser.add_argument(
"--predictions_path",
type=str,
default="predictions.txt",
help="Name of the predictions file, to be stored in the checkpoints directory",
)
parser.add_argument(
"--eval_all_checkpoints",
action="store_true",
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
)
parser.add_argument(
"--eval_batch_size",
default=8,
type=int,
help="Batch size per GPU/CPU for evaluation.",
)
parser.add_argument(
"--recalculate",
help="Recalculate predictions even if the prediction file exists",
action="store_true",
)
parser.add_argument(
"--num_beams",
default=4,
type=int,
help="Number of beams to be used when generating answers",
)
parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")
parser.add_argument(
"--print_predictions",
action="store_true",
help="If True, prints predictions while evaluating.",
)
parser.add_argument(
"--print_docs",
action="store_true",
help="If True, prints docs retried while generating.",
)
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
return args
def main(args):
model_kwargs = {}
if args.model_type is None:
args.model_type = infer_model_type(args.model_name_or_path)
assert args.model_type is not None
if args.model_type.startswith("rag"):
model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
model_kwargs["n_docs"] = args.n_docs
if args.index_name is not None:
model_kwargs["index_name"] = args.index_name
if args.index_path is not None:
model_kwargs["index_path"] = args.index_path
else:
model_class = BartForConditionalGeneration
checkpoints = (
[f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
if args.eval_all_checkpoints
else [args.model_name_or_path]
)
logger.info("Evaluate the following checkpoints: %s", checkpoints)
score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval
for checkpoint in checkpoints:
if os.path.exists(args.predictions_path) and (not args.recalculate):
logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
score_fn(args, args.predictions_path, args.gold_data_path)
continue
logger.info("***** Running evaluation for {} *****".format(checkpoint))
logger.info(" Batch size = %d", args.eval_batch_size)
logger.info(" Predictions will be stored under {}".format(args.predictions_path))
if args.model_type.startswith("rag"):
retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
model.retriever.init_retrieval()
else:
model = model_class.from_pretrained(checkpoint, **model_kwargs)
model.to(args.device)
with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
questions = []
for line in tqdm(eval_file):
questions.append(line.strip())
if len(questions) == args.eval_batch_size:
answers = evaluate_batch_fn(args, model, questions)
preds_file.write("\n".join(answers) + "\n")
preds_file.flush()
questions = []
if len(questions) > 0:
answers = evaluate_batch_fn(args, model, questions)
preds_file.write("\n".join(answers))
preds_file.flush()
score_fn(args, args.predictions_path, args.gold_data_path)
if __name__ == "__main__":
args = get_args()
main(args)

View File

@@ -0,0 +1,789 @@
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
import argparse
import copy
import json
import logging
import multiprocessing
import os
import random
import shutil
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
from datasets import concatenate_datasets, load_from_disk
from torch.utils.data import DataLoader
from transformers import (
AutoConfig,
AutoTokenizer,
BartForConditionalGeneration,
BatchEncoding,
DPRConfig,
DPRContextEncoder,
DPRContextEncoderTokenizerFast,
RagConfig,
RagSequenceForGeneration,
RagTokenForGeneration,
RagTokenizer,
T5ForConditionalGeneration,
)
from transformers import logging as transformers_logging
from transformers.integrations import is_ray_available
if is_ray_available():
import ray
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
from glob import glob
from callbacks_rag import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from kb_encode_utils import add_index, embed_update
from lightning_base import BaseTransformer, add_generic_args, generic_train
from pynvml import nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit
from utils_rag import (
Seq2SeqDataset,
calculate_exact_match,
get_git_info,
is_rag_model,
lmap,
pickle_save,
save_git_info,
save_json,
set_extra_model_params,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
transformers_logging.set_verbosity_info()
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
isEmUpdateBusy = False
isAddIndexBusy = False
processes = []
threadHandle_index = None
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
class GenerativeQAModule(BaseTransformer):
mode = "generative_qa"
loss_names = ["loss"]
metric_names = ["em"]
val_metric = "em"
def __init__(self, hparams, **kwargs):
# when loading from a pytorch lightning checkpoint, hparams are passed as dict
if isinstance(hparams, dict):
hparams = AttrDict(hparams)
if hparams.model_type == "rag_sequence":
self.model_class = RagSequenceForGeneration
elif hparams.model_type == "rag_token":
self.model_class = RagTokenForGeneration
elif hparams.model_type == "bart":
self.model_class = BartForConditionalGeneration
else:
self.model_class = T5ForConditionalGeneration
self.is_rag_model = is_rag_model(hparams.model_type)
config_class = RagConfig if self.is_rag_model else AutoConfig
config = config_class.from_pretrained(hparams.model_name_or_path)
# set retriever parameters
config.index_name = hparams.index_name or config.index_name
config.passages_path = hparams.passages_path or config.passages_path
config.index_path = hparams.index_path or config.index_path
config.use_dummy_dataset = hparams.use_dummy_dataset
# set extra_model_params for generator configs and load_model
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
if self.is_rag_model:
if hparams.prefix is not None:
config.generator.prefix = hparams.prefix
config.label_smoothing = hparams.label_smoothing
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
if hparams.distributed_retriever == "ray":
# The Ray retriever needs the handles to the retriever actors.
retriever = RagRayDistributedRetriever.from_pretrained(
hparams.model_name_or_path, hparams.actor_handles, config=config
)
if hparams.end2end:
ctx_encoder_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
"facebook/dpr-ctx_encoder-multiset-base"
)
retriever.set_ctx_encoder_tokenizer(ctx_encoder_tokenizer)
else:
logger.info("please use RAY as the distributed retrieval method")
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
if hparams.end2end:
ctx_encoder = DPRContextEncoder.from_pretrained(hparams.context_encoder_name)
model.set_context_encoder_for_training(ctx_encoder)
prefix = config.question_encoder.prefix
else:
if hparams.prefix is not None:
config.prefix = hparams.prefix
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
prefix = config.prefix
tokenizer = (
RagTokenizer.from_pretrained(hparams.model_name_or_path)
if self.is_rag_model
else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
)
self.config_dpr = DPRConfig.from_pretrained(hparams.context_encoder_name)
self.custom_config = hparams
self.context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(hparams.context_encoder_name)
super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)
save_git_info(self.hparams.output_dir)
self.output_dir = Path(self.hparams.output_dir)
self.dpr_ctx_check_dir = str(Path(self.hparams.output_dir)) + "/dpr_ctx_checkpoint"
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
pickle_save(self.hparams, self.hparams_save_path)
self.step_count = 0
self.metrics = defaultdict(list)
self.dataset_kwargs: dict = dict(
data_dir=self.hparams.data_dir,
max_source_length=self.hparams.max_source_length,
prefix=prefix or "",
)
n_observations_per_split = {
"train": self.hparams.n_train,
"val": self.hparams.n_val,
"test": self.hparams.n_test,
}
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
self.target_lens = {
"train": self.hparams.max_target_length,
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.distributed_port = self.hparams.distributed_port
# For single GPU training, init_ddp_connection is not called.
# So we need to initialize the retrievers here.
if hparams.gpus <= 1:
if hparams.distributed_retriever == "ray":
self.model.retriever.init_retrieval()
else:
logger.info("please use RAY as the distributed retrieval method")
self.distributed_retriever = hparams.distributed_retriever
def forward(self, input_ids, **kwargs):
return self.model(input_ids, **kwargs)
def ids_to_clean_text(self, generated_ids: List[int]):
gen_text = self.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return lmap(str.strip, gen_text)
def _step(self, batch: dict) -> Tuple:
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
rag_kwargs = {}
if isinstance(self.model, T5ForConditionalGeneration):
decoder_input_ids = self.model._shift_right(target_ids)
lm_labels = target_ids
elif isinstance(self.model, BartForConditionalGeneration):
decoder_input_ids = target_ids[:, :-1].contiguous()
lm_labels = target_ids[:, 1:].clone()
else:
assert self.is_rag_model
generator = self.model.rag.generator
if isinstance(generator, T5ForConditionalGeneration):
decoder_start_token_id = generator.config.decoder_start_token_id
decoder_input_ids = (
torch.cat(
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
dim=1,
)
if target_ids.shape[0] < self.target_lens["train"]
else generator._shift_right(target_ids)
)
elif isinstance(generator, BartForConditionalGeneration):
decoder_input_ids = target_ids
lm_labels = decoder_input_ids
rag_kwargs["reduce_loss"] = True
assert decoder_input_ids is not None
outputs = self(
source_ids,
attention_mask=source_mask,
decoder_input_ids=decoder_input_ids,
use_cache=False,
labels=lm_labels,
**rag_kwargs,
)
loss = outputs["loss"]
return (loss,)
@property
def pad(self) -> int:
raise NotImplementedError("pad not implemented")
def training_step(self, batch, batch_idx) -> Dict:
global isEmUpdateBusy # use to check whether the entire embedding update process is finished or not
global isAddIndexBusy # use to check whether the entire indexing process is finished or not
global processes # use to keep threads embedding update processes
global threadHandle_index # use to keep thread in embedding indexing processes
if (self.trainer.global_rank == 0) and (self.custom_config.end2end):
if (not batch_idx == 0) and (batch_idx % self.custom_config.indexing_freq == 0):
free_gpu_list = []
nvmlInit()
deviceCount = nvmlDeviceGetCount()
my_list = json.loads(self.custom_config.gpu_order)
for i in range(deviceCount):
handle = nvmlDeviceGetHandleByIndex(i)
info = nvmlDeviceGetMemoryInfo(handle)
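# consider a GPU free if it is currently using less than ~15 MB of memory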
if info.used / 1e6 < 15:
position = my_list.index(i)
free_gpu_list.append("cuda:" + str(position))
if len(free_gpu_list) >= self.custom_config.index_gpus:
has_free_gpus = True
else:
has_free_gpus = False
if (not isEmUpdateBusy) and has_free_gpus:
model_copy = type(self.model.rag.ctx_encoder)(
self.config_dpr
) # get a new instance; this will be loaded on the CPU
model_copy.load_state_dict(self.model.rag.ctx_encoder.state_dict()) # copy weights
processes = []
if len(free_gpu_list) > self.custom_config.index_gpus:
cuda_devices = random.sample(free_gpu_list, self.custom_config.index_gpus)
else:
cuda_devices = free_gpu_list
num_processes = len(cuda_devices)
for rank in range(num_processes):
logger.info("Iniitializing embedding calculation process rank{}".format(rank))
device = cuda_devices[rank]
p = multiprocessing.Process(
target=embed_update,
args=(
copy.deepcopy(model_copy),
num_processes,
device,
rank,
self.custom_config.shard_dir,
self.custom_config.csv_path,
),
)
processes.append(p)
for p in processes:
p.start()
isEmUpdateBusy = True
if isEmUpdateBusy and (not isAddIndexBusy):
index_process_list = [processes[k].is_alive() for k in range(self.custom_config.index_gpus)]
if (
sum(index_process_list) == 0
): # if the entire list is False, all embedding calculation processes have finished
logger.info("Start adding the index")
threadHandle_index = multiprocessing.Process(
target=add_index,
args=(
self.custom_config.shard_dir,
self.config.index_path,
),
)
threadHandle_index.start()
isAddIndexBusy = True
# check when index building has started
if isAddIndexBusy:
# check whether the index-building process is still running
if not threadHandle_index.is_alive():
logger.info("Merging the dataset shards")
saved_dataset_shards = []
for address in glob(str(self.custom_config.shard_dir) + "/*/"):
saved_dataset_shards.append(load_from_disk(address))
concat = concatenate_datasets(saved_dataset_shards)
concat.save_to_disk(self.config.passages_path) # here we update the main passage file on the disk
logger.info("done updating the dataset")
# if you load the index from the disk make sure to update the index file here, otherwise it is ok to update the index file from the worker.
# logger.info("then updating the index")
# shutil.copy(self.custom_config.temp_index, self.config.index_path)
logger.info("Loading new passages and iniitalzing new index")
self.trainer.model.module.module.model.rag.retriever.re_load()
self.trainer.model.module.module.model.rag.retriever.init_retrieval()
isEmUpdateBusy = False
isAddIndexBusy = False
self.trainer.accelerator_connector.accelerator.barrier(
"barrier"
) # wait until the index and KB are re-initialized.
loss_tensors = self._step(batch)
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
# tokens per batch
tgt_pad_token_id = (
self.tokenizer.generator.pad_token_id
if isinstance(self.tokenizer, RagTokenizer)
else self.tokenizer.pad_token_id
)
src_pad_token_id = (
self.tokenizer.question_encoder.pad_token_id
if isinstance(self.tokenizer, RagTokenizer)
else self.tokenizer.pad_token_id
)
logs["tpb"] = (
batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
)
self.log("loss", loss_tensors[0])
return loss_tensors[0]
def validation_step(self, batch, batch_idx) -> Dict:
return self._generative_step(batch)
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
gen_metrics = {
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
}
metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
gen_metrics.update({k: v.item() for k, v in losses.items()})
# fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
if dist.is_initialized():
dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
metrics_tensor = metrics_tensor / dist.get_world_size()
gen_metrics.update({self.val_metric: metrics_tensor.item()})
losses.update(gen_metrics)
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
metrics["step_count"] = self.step_count
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
log_dict = {
"val_avg_em": metrics["val_avg_em"],
"step_count": metrics["step_count"],
"val_avg_loss": metrics["val_avg_loss"],
"val_loss": loss,
"val_em": metrics_tensor,
}
self.log_dict(log_dict)
def save_metrics(self, latest_metrics, type_path) -> None:
self.metrics[type_path].append(latest_metrics)
save_json(self.metrics, self.metrics_save_path)
def calc_generative_metrics(self, preds, target) -> Dict:
return calculate_exact_match(preds, target)
def _generative_step(self, batch: dict) -> dict:
start_time = time.time()
batch = BatchEncoding(batch).to(device=self.model.device)
generated_ids = self.model.generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
do_deduplication=False, # rag specific parameter
use_cache=True,
min_length=1,
max_length=self.target_lens["val"],
)
gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids)
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
# print(preds,target)
loss_tensors = self._step(batch)
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
gen_metrics: Dict = self.calc_generative_metrics(preds, target)
summ_len = np.mean(lmap(len, generated_ids))
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
return base_metrics
def test_step(self, batch, batch_idx):
return self._generative_step(batch)
def test_epoch_end(self, outputs):
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> Seq2SeqDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = Seq2SeqDataset(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
max_target_length=max_target_length,
**self.dataset_kwargs,
)
return dataset
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
dataset = self.get_dataset(type_path)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=dataset.collate_fn,
shuffle=shuffle,
num_workers=self.num_workers,
)
return dataloader
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
return dataloader
def val_dataloader(self) -> DataLoader:
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
def test_dataloader(self) -> DataLoader:
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
@pl.utilities.rank_zero_only
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
self.model.config.save_step = self.step_count
# self.model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)
if self.custom_config.end2end:
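# drop the context-encoder weights from the saved RAG checkpoint; the ctx_encoder is saved separately below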
modified_state_dict = self.model.state_dict()
for key in self.model.state_dict().keys():
if key.split(".")[1] == "ctx_encoder":
del modified_state_dict[key]
self.model.save_pretrained(save_directory=save_path, state_dict=modified_state_dict)
save_path_dpr = os.path.join(self.dpr_ctx_check_dir, "checkpoint{}".format(self.step_count))
self.model.rag.ctx_encoder.save_pretrained(save_path_dpr)
self.context_tokenizer.save_pretrained(save_path_dpr)
@staticmethod
def add_model_specific_args(parser, root_dir):
BaseTransformer.add_model_specific_args(parser, root_dir)
add_generic_args(parser, root_dir)
parser.add_argument(
"--max_source_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--val_max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
"--test_max_target_length",
default=25,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument(
"--prefix",
type=str,
default=None,
help="Prefix added at the beginning of each text, typically used with T5-based models.",
)
parser.add_argument(
"--early_stopping_patience",
type=int,
default=-1,
required=False,
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
)
parser.add_argument(
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
)
parser.add_argument(
"--model_type",
choices=["rag_sequence", "rag_token", "bart", "t5"],
type=str,
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
)
parser.add_argument(
"--context_encoder_name",
default="facebook/dpr-ctx_encoder-multiset-base",
type=str,
help="Name of the pre-trained context encoder checkpoint from the DPR",
)
parser.add_argument(
"--csv_path",
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset.csv"),
type=str,
help="path of the raw KB csv",
)
parser.add_argument("--end2end", action="store_true", help="whether to train the system end2end or not")
parser.add_argument("--index_gpus", type=int, help="how many GPUs used in re-encoding process")
parser.add_argument(
"--shard_dir",
type=str,
default=str(Path(__file__).parent / "test_run" / "kb-shards"),
help="directory used to keep temporary shards during the re-encode process",
)
parser.add_argument(
"--gpu_order",
type=str,
help="order of the GPU used during the fine-tuning. Used to finding free GPUs during the re-encode process. I do not have many GPUs :)",
)
parser.add_argument("--indexing_freq", type=int, help="frequency of re-encode process")
return parser
@staticmethod
def add_retriever_specific_args(parser):
parser.add_argument(
"--index_name",
type=str,
default=None,
help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
)
parser.add_argument(
"--passages_path",
type=str,
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset"),
help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
parser.add_argument(
"--index_path",
type=str,
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset_hnsw_index.faiss"),
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
parser.add_argument(
"--distributed_retriever",
choices=["ray", "pytorch"],
type=str,
default="ray",
help="What implementation to use for distributed retriever? If "
"pytorch is selected, the index is loaded on training "
"worker 0, and torch.distributed is used to handle "
"communication between training worker 0, and the other "
"training workers. If ray is selected, the Ray library is "
"used to create load the index on separate processes, "
"and Ray handles the communication between the training "
"workers and the retrieval actors.",
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
return parser
@staticmethod
def add_ray_specific_args(parser):
# Ray cluster address.
parser.add_argument(
"--ray-address",
default="auto",
type=str,
help="The address of the Ray cluster to connect to. If not "
"specified, Ray will attempt to automatically detect the "
"cluster. Has no effect if pytorch is used as the distributed "
"retriever.",
)
parser.add_argument(
"--num_retrieval_workers",
type=int,
default=1,
help="The number of retrieval actors to use when Ray is selected"
"for the distributed retriever. Has no effect when "
"distributed_retriever is set to pytorch.",
)
return parser
def main(args=None, model=None) -> GenerativeQAModule:
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)
args = args or parser.parse_args()
Path(args.output_dir).mkdir(exist_ok=True)
Path(args.output_dir + "/dpr_ctx_checkpoint").mkdir(
exist_ok=True
) # save the DPR context encoder separately for future use
print(args.shard_dir)
if os.path.exists(args.shard_dir): # we do not need previous kb shards used in dataset re-encoding and re-indexing
shutil.rmtree(args.shard_dir)
Path(args.shard_dir).mkdir(exist_ok=True)
if os.path.exists(
args.cache_dir
): # we do not need previous cache files used in dataset re-encoding and re-indexing
shutil.rmtree(args.cache_dir)
Path(args.cache_dir).mkdir(exist_ok=True)
named_actors = []
if args.distributed_retriever == "ray" and args.gpus > 1:
if not is_ray_available():
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
# Connect to an existing Ray cluster.
try:
ray.init(address=args.ray_address)
except (ConnectionError, ValueError):
logger.warning(
"Connection to Ray cluster failed. Make sure a Ray"
"cluster is running by either using Ray's cluster "
"launcher (`ray up`) or by manually starting Ray on "
"each node via `ray start --head` for the head node "
"and `ray start --address='<ip address>:6379'` for "
"additional nodes. See "
"https://docs.ray.io/en/master/cluster/index.html "
"for more info."
)
raise
# Create Ray actors only for rank 0.
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and (
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0
):
remote_cls = ray.remote(RayRetriever)
named_actors = [
remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
for i in range(args.num_retrieval_workers)
]
else:
logger.info(
"Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(
os.environ["NODE_RANK"], os.environ["LOCAL_RANK"]
)
)
named_actors = [ray.get_actor("retrieval_worker_{}".format(i)) for i in range(args.num_retrieval_workers)]
args.actor_handles = named_actors
assert args.actor_handles == named_actors
if model is None:
model: GenerativeQAModule = GenerativeQAModule(args)
dataset = Path(args.data_dir).name
if (
args.logger_name == "default"
or args.fast_dev_run
or str(args.output_dir).startswith("/tmp")
or str(args.output_dir).startswith("/var")
):
training_logger = True # don't pollute wandb logs unnecessarily
elif args.logger_name == "wandb":
from pytorch_lightning.loggers import WandbLogger
project = os.environ.get("WANDB_PROJECT", dataset)
training_logger = WandbLogger(name=model.output_dir.name, project=project)
elif args.logger_name == "wandb_shared":
from pytorch_lightning.loggers import WandbLogger
training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
es_callback = (
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
if args.early_stopping_patience >= 0
else False
)
trainer: pl.Trainer = generic_train(
model,
args,
logging_callback=Seq2SeqLoggingCallback(),
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=training_logger,
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
if not args.do_predict:
return model
# test() without a model tests using the best checkpoint automatically
trainer.test()
return model
if __name__ == "__main__":
multiprocessing.set_start_method("spawn")
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)
parser = GenerativeQAModule.add_ray_specific_args(parser)
# Pytorch Lightning Profiler
parser.add_argument(
"--profile",
action="store_true",
help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.",
)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,68 @@
# Sample script to finetune RAG using Ray for distributed retrieval.
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
# Create the custom knowledge base
python use_own_knowledge_dataset.py \
--csv_path /DIR/SQUAD-KB/squad-kb.csv \
--output_dir /DIR/SQUAD-KB
# Start a single-node Ray cluster.
ray start --head
# A sample fine-tuning run; you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options
python finetune_rag.py \
--data_dir /DIR/squad-training-data \
--output_dir /DIR/model_checkpoints \
--model_name_or_path facebook/rag-token-base \
--model_type rag_token \
--fp16 \
--gpus 2 \
--profile \
--do_train \
--end2end \
--do_predict \
--n_val -1 \
--train_batch_size 4 \
--eval_batch_size 1 \
--max_source_length 128 \
--max_target_length 25 \
--val_max_target_length 25 \
--test_max_target_length 25 \
--label_smoothing 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
--weight_decay 0.001 \
--adam_epsilon 1e-08 \
--max_grad_norm 0.1 \
--lr_scheduler polynomial \
--learning_rate 3e-05 \
--num_train_epochs 10 \
--warmup_steps 500 \
--gradient_accumulation_steps 8 \
--distributed_retriever ray \
--num_retrieval_workers 4 \
--passages_path /DIR/SQUAD-KB/my_knowledge_dataset \
--index_path /DIR/SQUAD-KB/my_knowledge_dataset_hnsw_index.faiss \
--index_name custom \
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
--csv_path /DIR/SQUAD-KB/squad-kb.csv \
--index_gpus 1 \
--gpu_order [5,6,7,8,9,0,1,2,3,4] \
--shard_dir ./test_dir/kb-shards \
--indexing_freq 500
# Stop the Ray cluster.
ray stop
# This script was used to test with the SQuAD data.
# Change the dir parameters according to your preference.
# Please use the same device order when running: CUDA_VISIBLE_DEVICES=5,6,7,8,9,0,1,2,3,4 sh finetune_rag_ray_end2end.sh

View File

@@ -0,0 +1,81 @@
import os
from functools import partial
from glob import glob
from datasets import Features, Sequence, Value, concatenate_datasets, load_dataset, load_from_disk
import faiss
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
def split_text(text, n=100, character=" "):
"""Split the text every ``n``-th occurrence of ``character``"""
text = text.split(character)
return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
def split_documents(documents):
"""Split documents into passages"""
titles, texts = [], []
for title, text in zip(documents["title"], documents["text"]):
if text is not None:
for passage in split_text(text):
titles.append(title if title is not None else "")
texts.append(passage)
return {"title": titles, "text": texts}
def embed_update(ctx_encoder, total_processes, device, process_num, shard_dir, csv_path):
kb_dataset = load_dataset(
"csv", data_files=[csv_path], split="train", delimiter="\t", column_names=["title", "text"]
)
kb_dataset = kb_dataset.map(
split_documents, batched=True, num_proc=1
) # if you want, you can load an already split csv.
kb_list = [kb_dataset.shard(total_processes, i, contiguous=True) for i in range(total_processes)]
data_shard = kb_list[process_num]
arrow_folder = "data_" + str(process_num)
passages_path = os.path.join(shard_dir, arrow_folder)
context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_encoder = ctx_encoder.to(device=device)
def embed(
documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast, device
) -> dict:
"""Compute the DPR embeddings of document passages"""
input_ids = ctx_tokenizer(
documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
)["input_ids"]
embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
return {"embeddings": embeddings.detach().cpu().numpy()}
new_features = Features(
{"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
) # optional, save as float32 instead of float64 to save space
dataset = data_shard.map(
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=context_tokenizer, device=device),
batched=True,
batch_size=16,
features=new_features,
)
dataset.save_to_disk(passages_path)
def add_index(shard_dir, index_path):
data_shard_list = []
for shard_address in glob(str(shard_dir) + "/*/"):
data_shard_list.append(load_from_disk(shard_address))
concat = concatenate_datasets(data_shard_list)
faiss.omp_set_num_threads(96)
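# 768 is the DPR embedding dimension; 128 is the HNSW M parameter (links per node in the graph)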
index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
concat.add_faiss_index("embeddings", custom_index=index)
concat.get_index("embeddings").save(
index_path
) # since we load the index into memory, we can directly update the index on disk

View File

@@ -0,0 +1,415 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Any, Dict
import pytorch_lightning as pl
from pytorch_lightning.plugins.training_type import DDPPlugin
from pytorch_lightning.utilities import rank_zero_info
from transformers import (
AdamW,
AutoConfig,
AutoModel,
AutoModelForPreTraining,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelWithLMHead,
AutoTokenizer,
PretrainedConfig,
PreTrainedTokenizer,
)
from transformers.optimization import (
Adafactor,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
)
from transformers.utils.versions import require_version_examples
logger = logging.getLogger(__name__)
require_version_examples("pytorch_lightning>=1.0.4")
MODEL_MODES = {
"base": AutoModel,
"sequence-classification": AutoModelForSequenceClassification,
"question-answering": AutoModelForQuestionAnswering,
"pretraining": AutoModelForPreTraining,
"token-classification": AutoModelForTokenClassification,
"language-modeling": AutoModelWithLMHead,
"summarization": AutoModelForSeq2SeqLM,
"translation": AutoModelForSeq2SeqLM,
}
# update this and the import above to support new schedulers from transformers.optimization
arg_to_scheduler = {
"linear": get_linear_schedule_with_warmup,
"cosine": get_cosine_schedule_with_warmup,
"cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
"polynomial": get_polynomial_decay_schedule_with_warmup,
# '': get_constant_schedule, # not supported for now
# '': get_constant_schedule_with_warmup, # not supported for now
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
class BaseTransformer(pl.LightningModule):
def __init__(
self,
hparams: argparse.Namespace,
num_labels=None,
mode="base",
config=None,
tokenizer=None,
model=None,
**config_kwargs
):
"""Initialize a model, tokenizer and config."""
super().__init__()
# TODO: move to self.save_hyperparameters()
# self.save_hyperparameters()
# can also expand arguments into trainer signature for easier reading
self.save_hyperparameters(hparams)
self.step_count = 0
self.output_dir = Path(self.hparams.output_dir)
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
if config is None:
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
)
else:
self.config: PretrainedConfig = config
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
for p in extra_model_params:
if getattr(self.hparams, p, None):
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
setattr(self.config, p, getattr(self.hparams, p))
if tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
cache_dir=cache_dir,
)
else:
self.tokenizer: PreTrainedTokenizer = tokenizer
self.model_type = MODEL_MODES[mode]
if model is None:
self.model = self.model_type.from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
config=self.config,
cache_dir=cache_dir,
)
else:
self.model = model
def load_hf_checkpoint(self, *args, **kwargs):
self.model = self.model_type.from_pretrained(*args, **kwargs)
def get_lr_scheduler(self):
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
scheduler = get_schedule_func(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps()
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return scheduler
def configure_optimizers(self):
"""Prepare optimizer and schedule (linear warmup and decay)"""
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)
], # check these named parameters
"weight_decay": self.hparams.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
if self.hparams.adafactor:
optimizer = Adafactor(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
)
else:
optimizer = AdamW(
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
)
self.opt = optimizer
scheduler = self.get_lr_scheduler()
return [optimizer], [scheduler]
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
def total_steps(self) -> int:
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, stage):
if stage == "test":
self.dataset_size = len(self.test_dataloader().dataset)
else:
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
self.dataset_size = len(self.train_dataloader().dataset)
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
raise NotImplementedError("You must implement this for your task")
def train_dataloader(self):
return self.train_loader
def val_dataloader(self):
return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)
def test_dataloader(self):
return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)
def _feature_file(self, mode):
return os.path.join(
self.hparams.data_dir,
"cached_{}_{}_{}".format(
mode,
list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
str(self.hparams.max_seq_length),
),
)
@pl.utilities.rank_zero_only
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
save_path = self.output_dir.joinpath("best_tfmr")
self.model.config.save_step = self.step_count
self.model.save_pretrained(save_path)
self.tokenizer.save_pretrained(save_path)
@staticmethod
def add_model_specific_args(parser, root_dir):
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
)
parser.add_argument(
"--tokenizer_name",
default=None,
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--cache_dir",
default=str(Path(__file__).parent / "test_run" / "cache"),
type=str,
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
)
parser.add_argument(
"--encoder_layerdrop",
type=float,
help="Encoder layer dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--decoder_layerdrop",
type=float,
help="Decoder layer dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--dropout",
type=float,
help="Dropout probability (Optional). Goes into model.config",
)
parser.add_argument(
"--attention_dropout",
type=float,
help="Attention dropout probability (Optional). Goes into model.config",
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument(
"--lr_scheduler",
default="linear",
choices=arg_to_scheduler_choices,
metavar=arg_to_scheduler_metavar,
type=str,
help="Learning rate scheduler",
)
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
parser.add_argument("--train_batch_size", default=32, type=int)
parser.add_argument("--eval_batch_size", default=32, type=int)
parser.add_argument("--adafactor", action="store_true")
class InitCallback(pl.Callback):
# this process can also be done with the PL DDP plugin.
# But it is still experimental (see the original RAG example, which was updated to use the plugin (shamanez))
def on_sanity_check_start(self, trainer, pl_module):
if (
trainer.is_global_zero and trainer.global_rank == 0
): # we initialize the retriever only on the master worker with RAY. In newer pytorch-lightning versions, accelerators are removed.
pl_module.model.rag.retriever.init_retrieval() # better to use hook functions.
class CheckParamCallback(pl.Callback):
# check whether newly added model parameters are differentiable
def on_after_backward(self, trainer, pl_module):
# print(pl_module.model.rag)
for name, param in pl_module.model.rag.named_parameters():
if param.grad is None:
print(name)
class LoggingCallback(pl.Callback):
def on_batch_end(self, trainer, pl_module):
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
pl_module.logger.log_metrics(lrs)
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
rank_zero_info("***** Validation results *****")
metrics = trainer.callback_metrics
# Log results
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
rank_zero_info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
def add_generic_args(parser, root_dir) -> None:
# To allow all pl args uncomment the following line
# parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument(
"--output_dir",
default=str(Path(__file__).parent / "test_run" / "model_checkpoints"),
type=str,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O2",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
parser.add_argument(
"--gradient_accumulation_steps",
dest="accumulate_grad_batches",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument(
"--data_dir",
default=str(Path(__file__).parent / "test_run" / "dummy-train-data"),
type=str,
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
)
def generic_train(
model: BaseTransformer,
args: argparse.Namespace,
early_stopping_callback=None,
logger=True, # can pass WandbLogger() here
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
**extra_train_kwargs
):
pl.seed_everything(args.seed)
# init model
odir = Path(model.hparams.output_dir)
odir.mkdir(exist_ok=True)
# add custom checkpoints
if checkpoint_callback is None:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
)
if early_stopping_callback:
extra_callbacks.append(early_stopping_callback)
if logging_callback is None:
logging_callback = LoggingCallback()
train_params = {}
# TODO: remove with PyTorch 1.6 since pl uses native amp
if args.fp16:
train_params["precision"] = 16
train_params["amp_level"] = args.fp16_opt_level
if args.gpus > 1:
train_params["accelerator"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
# train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None)
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
callbacks=[logging_callback] + extra_callbacks + [InitCallback()] + [checkpoint_callback],
logger=logger,
plugins=[DDPPlugin(find_unused_parameters=True)], # this is needed in the new pytorch-lightning version
val_check_interval=1,
num_sanity_val_steps=2,
**train_params,
)
if args.do_train:
trainer.fit(model)
# else:
# print("RAG modeling tests with new set functions successfully executed!")
return trainer

View File

@ -0,0 +1,7 @@
faiss-cpu >= 1.7.0
datasets >= 1.6.2
psutil >= 5.7.0
torch >= 1.4.0
pytorch-lightning == 1.3.1
nvidia-ml-py3 == 7.352.0
ray >= 1.3.0

View File

@ -0,0 +1,2 @@
Aaron Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to the Book of Exodus, Aaron first functioned as Moses' assistant. Because Moses complained that he could not speak well, God appointed Aaron as Moses' "prophet" (Exodus 4:10-17; 7:1). At the command of Moses, he let his rod turn into a snake. Then he stretched out his rod in order to bring on the first three plagues. After that, Moses tended to act and speak for himself. During the journey in the wilderness, Aaron was not always prominent or active. At the battle with Amalek, he was chosen with Hur to support the hand of Moses that held the "rod of God". When the revelation was given to Moses at biblical Mount Sinai, he headed the elders of Israel who accompanied Moses on the way to the summit.
"Pokémon" Pokémon , also known as in Japan, is a media franchise managed by The Pokémon Company, a Japanese consortium between Nintendo, Game Freak, and Creatures. The franchise copyright is shared by all three companies, but Nintendo is the sole owner of the trademark. The franchise was created by Satoshi Tajiri in 1995, and is centered on fictional creatures called "Pokémon", which humans, known as Pokémon Trainers, catch and train to battle each other for sport. The English slogan for the franchise is "Gotta Catch 'Em All". Works within the franchise are set in the Pokémon universe. The franchise began as "Pokémon Red" and "Green" (released outside of Japan as "Pokémon Red" and "Blue"), a pair of video games for the original Game Boy that were developed by Game Freak and published by Nintendo in February 1996. "Pokémon" has since gone on to become the highest-grossing media franchise of all time, with over in revenue up until March 2017. The original video game series is the second best-selling video game franchise (behind Nintendo's "Mario" franchise) with more than 300million copies sold and over 800million mobile downloads. In addition, the "Pokémon" franchise includes the world's top-selling toy brand, the top-selling trading card game with over 25.7billion cards sold, an anime television series that has become the most successful video game adaptation with over 20 seasons and 1,000 episodes in 124 countries, as well as an anime film series, a , books, manga comics, music, and merchandise. The franchise is also represented in other Nintendo media, such as the "Super Smash Bros." series. In November 2005, 4Kids Entertainment, which had managed the non-game related licensing of "Pokémon", announced that it had agreed not to renew the "Pokémon" representation agreement. The Pokémon Company International oversees all "Pokémon" licensing outside Asia.

View File

@ -0,0 +1,48 @@
What does Moses' rod turn into ?
Who is Aron?
Where did Moses grow up ?
What happens at the command of the Moses ?
Who manages the Pokémon ?
Who owned the Pokémon trademark ?
What else include in Pokémon franchise ?
How many seasons in Pokémon animme series ?
What does Moses' rod turn into ?
Who is Aron?
Where did Moses grow up ?
What happens at the command of the Moses ?
Who manages the Pokémon ?
Who owned the Pokémon trademark ?
What else include in Pokémon franchise ?
How many seasons in Pokémon animme series ?
What does Moses' rod turn into ?
Who is Aron?
Where did Moses grow up ?
What happens at the command of the Moses ?
Who manages the Pokémon ?
Who owned the Pokémon trademark ?
What else include in Pokémon franchise ?
How many seasons in Pokémon animme series ?
What does Moses' rod turn into ?
Who is Aron?
Where did Moses grow up ?
What happens at the command of the Moses ?
Who manages the Pokémon ?
Who owned the Pokémon trademark ?
What else include in Pokémon franchise ?
How many seasons in Pokémon animme series ?
What does Moses' rod turn into ?
Who is Aron?
Where did Moses grow up ?
What happens at the command of the Moses ?
Who manages the Pokémon ?
Who owned the Pokémon trademark ?
What else include in Pokémon franchise ?
How many seasons in Pokémon animme series ?
What does Moses' rod turn into ?
Who is Aron?
Where did Moses grow up ?
What happens at the command of the Moses ?
Who manages the Pokémon ?
Who owned the Pokémon trademark ?
What else include in Pokémon franchise ?
How many seasons in Pokémon animme series ?

View File

@ -0,0 +1,48 @@
to a snake
Moses' assistant
Egyptian royal court
let his rod turn in to a snake
The Pokémon Company
Nintendo
world's top-selling toy brand, the top-selling trading card game
over 20 seasons
to a snake
Moses' assistant
Egyptian royal court
let his rod turn in to a snake
The Pokémon Company
Nintendo
world's top-selling toy brand, the top-selling trading card game
over 20 seasons
to a snake
Moses' assistant
Egyptian royal court
let his rod turn in to a snake
The Pokémon Company
Nintendo
world's top-selling toy brand, the top-selling trading card game
over 20 seasons
to a snake
Moses' assistant
Egyptian royal court
let his rod turn in to a snake
The Pokémon Company
Nintendo
world's top-selling toy brand, the top-selling trading card game
over 20 seasons
to a snake
Moses' assistant
Egyptian royal court
let his rod turn in to a snake
The Pokémon Company
Nintendo
world's top-selling toy brand, the top-selling trading card game
over 20 seasons
to a snake
Moses' assistant
Egyptian royal court
let his rod turn in to a snake
The Pokémon Company
Nintendo
world's top-selling toy brand, the top-selling trading card game
over 20 seasons

View File

@ -0,0 +1,8 @@
What does Moses' rod turn into ?
Who is Aron?
Where did Moses grow up ?
What happens at the command of the Moses ?
Who manages the Pokémon ?
Who owned the Pokémon trademark ?
What else include in Pokémon franchise ?
How many seasons in Pokémon animme series ?

View File

@ -0,0 +1,8 @@
to a snake
Moses' assistant
Egyptian royal court
let his rod turn in to a snake
The Pokémon Company
Nintendo
world's top-selling toy brand, the top-selling trading card game
over 20 seasons

View File

@ -0,0 +1,54 @@
# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"
# creates the custom knowledge base
python use_own_knowledge_dataset.py
# Start a single-node Ray cluster.
ray start --head
# A sample finetuning run; you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options
python finetune_rag.py \
--model_name_or_path facebook/rag-token-base \
--model_type rag_token \
--fp16 \
--gpus 2 \
--profile \
--do_train \
--end2end \
--do_predict \
--n_val -1 \
--train_batch_size 1 \
--eval_batch_size 1 \
--max_source_length 128 \
--max_target_length 25 \
--val_max_target_length 25 \
--test_max_target_length 25 \
--label_smoothing 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
--weight_decay 0.001 \
--adam_epsilon 1e-08 \
--max_grad_norm 0.1 \
--lr_scheduler polynomial \
--learning_rate 3e-05 \
--num_train_epochs 10 \
--warmup_steps 500 \
--gradient_accumulation_steps 1 \
--distributed_retriever ray \
--num_retrieval_workers 4 \
--index_name custom \
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
--index_gpus 1 \
--gpu_order [6,7,8,9,0,1,2,3,5,4] \
--indexing_freq 5
# Stop the Ray cluster.
ray stop

View File

@ -0,0 +1,16 @@
export PYTHONPATH="../":"${PYTHONPATH}"
python use_own_knowledge_dataset.py
ray start --head
python finetune_rag.py \
--model_name_or_path facebook/rag-token-base \
--model_type rag_token \
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
--fp16 \
--gpus 1 \
--profile \
--end2end \
--index_name custom
ray stop

View File

@ -0,0 +1,171 @@
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional
import torch
from datasets import Features, Sequence, Value, load_dataset
import faiss
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast, HfArgumentParser
logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
def split_text(text: str, n=100, character=" ") -> List[str]:
"""Split the text every ``n``-th occurrence of ``character``"""
text = text.split(character)
return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
def split_documents(documents: dict) -> dict:
"""Split documents into passages"""
titles, texts = [], []
for title, text in zip(documents["title"], documents["text"]):
if text is not None:
for passage in split_text(text):
titles.append(title if title is not None else "")
texts.append(passage)
return {"title": titles, "text": texts}
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
"""Compute the DPR embeddings of document passages"""
input_ids = ctx_tokenizer(
documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
)["input_ids"]
embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
return {"embeddings": embeddings.detach().cpu().numpy()}
def main(
rag_example_args: "RagExampleArguments",
processing_args: "ProcessingArguments",
index_hnsw_args: "IndexHnswArguments",
):
######################################
logger.info("Step 1 - Create the dataset")
######################################
# The dataset needed for RAG must have three columns:
# - title (string): title of the document
# - text (string): text of a passage of the document
# - embeddings (array of dimension d): DPR representation of the passage
# Let's say you have documents in tab-separated csv files with columns "title" and "text"
assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"
# You can load a Dataset object this way
dataset = load_dataset(
"csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
)
# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
# Then split the documents into passages of 100 words
dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)
# And compute the embeddings
ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
new_features = Features(
{"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
) # optional, save as float32 instead of float64 to save space
dataset = dataset.map(
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
batched=True,
batch_size=processing_args.batch_size,
features=new_features,
)
# And finally save your dataset
passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
dataset.save_to_disk(passages_path)
# from datasets import load_from_disk
# dataset = load_from_disk(passages_path) # to reload the dataset
######################################
logger.info("Step 2 - Index the dataset")
######################################
# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)
# And save the index
index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
dataset.get_index("embeddings").save(index_path)
# dataset.load_faiss_index("embeddings", index_path) # to reload the index
@dataclass
class RagExampleArguments:
csv_path: str = field(
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset.csv"),
metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
)
question: Optional[str] = field(
default=None,
metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
)
rag_model_name: str = field(
default="facebook/rag-sequence-nq",
metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
)
dpr_ctx_encoder_model_name: str = field(
default="facebook/dpr-ctx_encoder-multiset-base",
metadata={
"help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
},
)
output_dir: Optional[str] = field(
default=str(Path(__file__).parent / "test_run" / "dummy-kb"),
metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
)
@dataclass
class ProcessingArguments:
num_proc: Optional[int] = field(
default=None,
metadata={
"help": "The number of processes to use to split the documents into passages. Default is single process."
},
)
batch_size: int = field(
default=16,
metadata={
"help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
},
)
@dataclass
class IndexHnswArguments:
d: int = field(
default=768,
metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
)
m: int = field(
default=128,
metadata={
"help": "The number of bi-directional links created for every new element during the HNSW index construction."
},
)
if __name__ == "__main__":
logging.basicConfig(level=logging.WARNING)
logger.setLevel(logging.INFO)
parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
with TemporaryDirectory() as tmp_dir:
rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
main(rag_example_args, processing_args, index_hnsw_args)
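Once the script has run, the saved passages and index can be reloaded and queried directly; the sketch below is illustrative only (the output directory mirrors the defaults above, and the question-encoder checkpoint is an assumption chosen to match the multiset context encoder).

```python
# Minimal sketch of reloading the knowledge base built above and querying it;
# paths mirror the defaults in RagExampleArguments, the question is illustrative.
import os

import torch
from datasets import load_from_disk
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast

output_dir = os.path.join("test_run", "dummy-kb")
dataset = load_from_disk(os.path.join(output_dir, "my_knowledge_dataset"))
dataset.load_faiss_index("embeddings", os.path.join(output_dir, "my_knowledge_dataset_hnsw_index.faiss"))

q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-multiset-base")
q_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained("facebook/dpr-question_encoder-multiset-base")

with torch.no_grad():
    question_embedding = q_encoder(**q_tokenizer("Who is Aaron?", return_tensors="pt")).pooler_output[0].numpy()

scores, passages = dataset.get_nearest_examples("embeddings", question_embedding, k=2)
print(scores, passages["title"])
```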

View File

@ -0,0 +1,244 @@
import itertools
import json
import linecache
import os
import pickle
import re
import socket
import string
from collections import Counter
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
import git
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
tokenizer.padding_side = padding_side
return tokenizer(
[line],
max_length=max_length,
padding="max_length" if pad_to_max_length else None,
truncation=True,
return_tensors=return_tensors,
add_special_tokens=True,
**extra_kw,
)
def trim_batch(
input_ids,
pad_token_id,
attention_mask=None,
):
"""Remove columns that are populated exclusively by pad_token_id"""
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
if attention_mask is None:
return input_ids[:, keep_column_mask]
else:
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
class Seq2SeqDataset(Dataset):
def __init__(
self,
tokenizer,
data_dir,
max_source_length,
max_target_length,
type_path="train",
n_obs=None,
src_lang=None,
tgt_lang=None,
prefix="",
):
super().__init__()
self.src_file = Path(data_dir).joinpath(type_path + ".source")
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
self.src_lens = self.get_char_lens(self.src_file)
self.max_source_length = max_source_length
self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
self.tokenizer = tokenizer
self.prefix = prefix
if n_obs is not None:
self.src_lens = self.src_lens[:n_obs]
self.src_lang = src_lang
self.tgt_lang = tgt_lang
def __len__(self):
return len(self.src_lens)
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
index = index + 1 # linecache starts at 1
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
# Need to add eos token manually for T5
if isinstance(self.tokenizer, T5Tokenizer):
source_line += self.tokenizer.eos_token
tgt_line += self.tokenizer.eos_token
# Pad source and target to the right
source_tokenizer = (
self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
)
target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")
source_ids = source_inputs["input_ids"].squeeze()
target_ids = target_inputs["input_ids"].squeeze()
src_mask = source_inputs["attention_mask"].squeeze()
return {
"input_ids": source_ids,
"attention_mask": src_mask,
"decoder_input_ids": target_ids,
}
@staticmethod
def get_char_lens(data_file):
return [len(x) for x in Path(data_file).open().readlines()]
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
tgt_pad_token_id = (
self.tokenizer.generator.pad_token_id
if isinstance(self.tokenizer, RagTokenizer)
else self.tokenizer.pad_token_id
)
src_pad_token_id = (
self.tokenizer.question_encoder.pad_token_id
if isinstance(self.tokenizer, RagTokenizer)
else self.tokenizer.pad_token_id
)
y = trim_batch(target_ids, tgt_pad_token_id)
source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
batch = {
"input_ids": source_ids,
"attention_mask": source_mask,
"decoder_input_ids": y,
}
return batch
logger = getLogger(__name__)
def flatten_list(summary_ids: List[List]):
return [x for x in itertools.chain.from_iterable(summary_ids)]
def save_git_info(folder_path: str) -> None:
"""Save git information to output_dir/git_log.json"""
repo_infos = get_git_info()
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=indent, **json_dump_kwargs)
def load_json(path):
with open(path) as f:
return json.load(f)
def get_git_info():
repo = git.Repo(search_parent_directories=True)
repo_infos = {
"repo_id": str(repo),
"repo_sha": str(repo.head.object.hexsha),
"repo_branch": str(repo.active_branch),
"hostname": str(socket.gethostname()),
}
return repo_infos
def lmap(f: Callable, x: Iterable) -> List:
"""list(map(f, x))"""
return list(map(f, x))
def pickle_save(obj, path):
"""pickle.dump(obj, path)"""
with open(path, "wb") as f:
return pickle.dump(obj, f)
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def exact_match_score(prediction, ground_truth):
return normalize_answer(prediction) == normalize_answer(ground_truth)
def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict:
assert len(output_lns) == len(reference_lns)
em = 0
for hypo, pred in zip(output_lns, reference_lns):
em += exact_match_score(hypo, pred)
if len(output_lns) > 0:
em /= len(output_lns)
return {"em": em}
def is_rag_model(model_prefix):
return model_prefix.startswith("rag")
def set_extra_model_params(extra_params, hparams, config):
equivalent_param = {p: p for p in extra_params}
# T5 models don't have `dropout` param, they have `dropout_rate` instead
equivalent_param["dropout"] = "dropout_rate"
for p in extra_params:
if getattr(hparams, p, None):
if not hasattr(config, p) and not hasattr(config, equivalent_param[p]):
logger.info("config doesn't have a `{}` attribute".format(p))
delattr(hparams, p)
continue
set_p = p if hasattr(config, p) else equivalent_param[p]
setattr(config, set_p, getattr(hparams, p))
delattr(hparams, p)
return hparams, config
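A quick illustration of how the normalization-based metrics defined above behave on toy strings (the predictions and references are made up for the example):

```python
# Toy example for the metric helpers above; values are made up.
predictions = ["The Pokémon Company.", "a snake"]
references = ["The Pokémon Company", "to a snake"]

print(calculate_exact_match(predictions, references))     # {'em': 0.5} (punctuation and articles are stripped before comparison)
print(round(f1_score(predictions[1], references[1]), 2))  # 0.67, token overlap after normalization
```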

View File

@ -523,6 +523,9 @@ class RagModel(RagPreTrainedModel):
self.question_encoder = question_encoder
self.generator = generator
self.ctx_encoder = None
self.context_encoder_training = False
@add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=RetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@ -588,22 +591,58 @@ class RagModel(RagPreTrainedModel):
n_docs=n_docs,
return_tensors="pt",
)
context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
retriever_outputs["context_input_ids"],
retriever_outputs["context_attention_mask"],
retriever_outputs["retrieved_doc_embeds"],
retriever_outputs["doc_ids"],
)
if self.context_encoder_training:
# set to correct device
retrieved_doc_embeds = retrieved_doc_embeds.to(question_encoder_last_hidden_state)
context_input_ids = context_input_ids.to(input_ids)
context_attention_mask = context_attention_mask.to(input_ids)
(
context_input_ids,
context_attention_mask,
retrieved_doc_embeds,
retrieved_doc_input_ids,
retrieved_doc_attention_mask,
retrieved_doc_ids,
) = (
retriever_outputs["context_input_ids"],
retriever_outputs["context_attention_mask"],
retriever_outputs["retrieved_doc_embeds"],
retriever_outputs["tokenized_doc_ids"],
retriever_outputs["tokenized_doc_attention_mask"],
retriever_outputs["doc_ids"],
)
# compute doc_scores
doc_scores = torch.bmm(
question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
).squeeze(1)
context_input_ids = context_input_ids.to(input_ids)
context_attention_mask = context_attention_mask.to(input_ids)
retrieved_doc_input_ids = retrieved_doc_input_ids.to(input_ids)
retrieved_doc_attention_mask = retrieved_doc_attention_mask.to(input_ids)
retrieved_doc_embeds = self.ctx_encoder(
retrieved_doc_input_ids, attention_mask=retrieved_doc_attention_mask, return_dict=True
).pooler_output
retrieved_doc_embeds = retrieved_doc_embeds.view(
-1, n_docs, question_encoder_last_hidden_state.shape[1]
) # reshaping
# compute doc_scores involving ctx_encoder
doc_scores = torch.bmm(
question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
).squeeze(1)
else:
context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
retriever_outputs["context_input_ids"],
retriever_outputs["context_attention_mask"],
retriever_outputs["retrieved_doc_embeds"],
retriever_outputs["doc_ids"],
)
# set to correct device
retrieved_doc_embeds = retrieved_doc_embeds.to(question_encoder_last_hidden_state)
context_input_ids = context_input_ids.to(input_ids)
context_attention_mask = context_attention_mask.to(input_ids)
# compute doc_scores
doc_scores = torch.bmm(
question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
).squeeze(1)
else:
assert (
context_input_ids is not None
@ -710,6 +749,10 @@ class RagSequenceForGeneration(RagPreTrainedModel):
def set_retriever(self, retriever: RagRetriever):
self.rag.retriever = retriever
def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
self.rag.context_encoder_training = True
self.rag.ctx_encoder = ctx_encoder
@add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1095,6 +1138,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
def set_retriever(self, retriever: RagRetriever):
self.rag.retriever = retriever
def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
self.rag.context_encoder_training = True
self.rag.ctx_encoder = ctx_encoder
def prepare_inputs_for_generation(
self,
decoder_input_ids,

View File

@ -22,6 +22,7 @@ from typing import Iterable, List, Optional, Tuple
import numpy as np
from ...file_utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, requires_backends
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from .configuration_rag import RagConfig
@ -378,6 +379,9 @@ class RagRetriever:
if self._init_retrieval:
self.init_retrieval()
self.ctx_encoder_tokenizer = None
self.return_tokenized_docs = False
@staticmethod
def _build_index(config):
if config.index_name == "legacy":
@ -543,6 +547,11 @@ class RagRetriever:
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
def set_ctx_encoder_tokenizer(self, ctx_encoder_tokenizer: PreTrainedTokenizer):
# used in end2end retriever training
self.ctx_encoder_tokenizer = ctx_encoder_tokenizer
self.return_tokenized_docs = True
def __call__(
self,
question_input_ids: List[List[int]],
@ -594,12 +603,42 @@ class RagRetriever:
docs, input_strings, prefix, n_docs, return_tensors=return_tensors
)
return BatchEncoding(
{
"context_input_ids": context_input_ids,
"context_attention_mask": context_attention_mask,
"retrieved_doc_embeds": retrieved_doc_embeds,
"doc_ids": doc_ids,
},
tensor_type=return_tensors,
)
if self.return_tokenized_docs:
retrieved_doc_text = []
retrieved_doc_title = []
for b_idx in range(len(docs)):
for doc_idx in range(n_docs):
retrieved_doc_text.append(docs[b_idx]["text"][doc_idx])
retrieved_doc_title.append(docs[b_idx]["title"][doc_idx])
tokenized_docs = self.ctx_encoder_tokenizer(
retrieved_doc_title,
retrieved_doc_text,
truncation=True,
padding="longest",
return_tensors=return_tensors,
)
return BatchEncoding(
{
"context_input_ids": context_input_ids,
"context_attention_mask": context_attention_mask,
"retrieved_doc_embeds": retrieved_doc_embeds,
"doc_ids": doc_ids,
"tokenized_doc_ids": tokenized_docs["input_ids"],
"tokenized_doc_attention_mask": tokenized_docs["attention_mask"],
},
tensor_type=return_tensors,
)
else:
return BatchEncoding(
{
"context_input_ids": context_input_ids,
"context_attention_mask": context_attention_mask,
"retrieved_doc_embeds": retrieved_doc_embeds,
"doc_ids": doc_ids,
},
tensor_type=return_tensors,
)
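Taken together with the `set_context_encoder_for_training` hook added to the modeling code above, the new retriever method can be wired up roughly as follows; this is a minimal sketch with illustrative checkpoint names, omitting the Ray/Lightning plumbing that `finetune_rag.py` adds.

```python
# Minimal sketch of the two new hooks working together; checkpoint names are
# illustrative and the dummy index keeps the example small.
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    RagRetriever,
    RagTokenForGeneration,
)

retriever = RagRetriever.from_pretrained("facebook/rag-token-base", index_name="exact", use_dummy_dataset=True)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")

# the retriever now also returns tokenized documents ("tokenized_doc_ids" / "tokenized_doc_attention_mask")
retriever.set_ctx_encoder_tokenizer(ctx_tokenizer)

model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
# the model re-embeds the retrieved documents with a trainable context encoder in forward()
model.set_context_encoder_for_training(ctx_encoder)
```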

View File

@ -26,7 +26,7 @@ import numpy as np
from transformers import BartTokenizer, T5Tokenizer
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import (
require_sentencepiece,
@ -55,6 +55,7 @@ if is_torch_available() and is_datasets_available() and is_faiss_available():
AutoConfig,
AutoModel,
AutoModelForSeq2SeqLM,
DPRContextEncoder,
RagConfig,
RagModel,
RagRetriever,
@ -179,6 +180,10 @@ class RagTestMixin:
def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
@cached_property
def dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
return DPRContextEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
@cached_property
def bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
@ -246,6 +251,46 @@ class RagTestMixin:
# doc scores
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
def check_model_with_end2end_retriever(
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
):
self.assertIsNotNone(config.question_encoder)
self.assertIsNotNone(config.generator)
context_encoder_tokenizer = self.dpr_ctx_encoder_tokenizer
dpr_context_encoder = DPRContextEncoder(config.question_encoder) # DPR is a bi-encoder (twin tower), so the question-encoder config also fits the context encoder
retriever = self.get_retriever(config)
retriever.set_ctx_encoder_tokenizer(context_encoder_tokenizer) # setting the ctx_encoder_tokenizer.
for model_class in [RagTokenForGeneration, RagSequenceForGeneration]:
model = model_class(config, retriever=retriever)
model.set_context_encoder_for_training(dpr_context_encoder) # set the context_encoder for training
model.to(torch_device)
model.eval()
self.assertTrue(model.config.is_encoder_decoder)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
# logits
self.assertEqual(
outputs.logits.shape,
(self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
)
# generator encoder last hidden states
self.assertEqual(
outputs.generator_enc_last_hidden_state.shape,
(self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
)
# doc scores
self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
def check_model_generate_from_context_input_ids(
self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
):
@ -538,6 +583,10 @@ class RagTestMixin:
inputs_dict = self.config_and_inputs
self.check_model_with_retriever(**inputs_dict)
def test_model_with_end2end_retriever(self):
inputs_dict = self.config_and_inputs
self.check_model_with_end2end_retriever(**inputs_dict)
def test_model_without_retriever(self):
inputs_dict = self.config_and_inputs
self.check_model_without_retriever(**inputs_dict)

View File

@ -28,7 +28,7 @@ from transformers.models.bart.configuration_bart import BartConfig
from transformers.models.bart.tokenization_bart import BartTokenizer
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.configuration_dpr import DPRConfig
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
from transformers.models.rag.configuration_rag import RagConfig
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
@ -115,6 +115,9 @@ class RagRetrieverTest(TestCase):
def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
return DPRContextEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
def get_bart_tokenizer(self) -> BartTokenizer:
return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
@ -359,3 +362,26 @@ class RagRetrieverTest(TestCase):
self.assertIsInstance(context_input_ids, torch.Tensor)
self.assertIsInstance(context_attention_mask, torch.Tensor)
self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)
@require_torch
@require_tokenizers
@require_sentencepiece
def test_custom_hf_index_end2end_retriever_call(self):
context_encoder_tokenizer = self.get_dpr_ctx_encoder_tokenizer()
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
retriever.set_ctx_encoder_tokenizer(context_encoder_tokenizer)
question_input_ids = [[5, 7], [10, 11]]
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
self.assertEqual(
len(out), 6
) # check whether the retriever output consists of 6 attributes, including the tokenized docs
self.assertEqual(
all(k in out for k in ("tokenized_doc_ids", "tokenized_doc_attention_mask")), True
) # check for the doc-token-related keys in the dictionary.