Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
RAG-2nd2end-revamp (#11893)
* initial
* code quality test
* code quality
* added test functions in test_modeling_rag.py and test_retrieval_rag.py to test the end2end retriever
* minor change in test_modeling_rag
* fixed tests
* Update examples/research_projects/rag-end2end-retriever/README.md
  typo corrected as suggested by lhoestq
  Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
* Update examples/research_projects/rag-end2end-retriever/finetune_rag.py
  type change suggested by lhoestq
  Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
* Update src/transformers/models/rag/retrieval_rag.py
  Adding this change as mentioned by lhoestq.
  Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
* completed the minor changes suggested by the reviewers

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
This commit is contained in:
parent
ad25fd62bd
commit
9ec0f01b6c
examples/research_projects/rag-end2end-retriever/README.md (new file, 47 lines)
@@ -0,0 +1,47 @@
# End-to-End finetuning of RAG (including DPR retriever) for Question Answering

This finetuning script is actively maintained by [Shamane Siri](https://github.com/shamanez). Feel free to ask questions on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose) and tag @shamanez.

Others that helped out: Patrick von Platen (@patrickvonplaten), Quentin Lhoest (@lhoestq), and Rivindu Weerasekera (@rivinduw)

The original RAG implementation is able to train the question encoder and generator end-to-end.
This extension enables complete end-to-end training of RAG, including the context encoder in the retriever component.
Please read the [accompanying blog post](https://shamanesiri.medium.com/how-to-finetune-the-entire-rag-architecture-including-dpr-retriever-4b4385322552) for details on this implementation.

The original RAG code has also been modified to work with the latest versions of PyTorch Lightning (version 1.2.10) and Ray (version 1.3.0). All other implementation details remain the same as in the [original RAG code](https://github.com/huggingface/transformers/tree/master/examples/research_projects/rag).
Read more about RAG at https://arxiv.org/abs/2005.11401.

This code can be modified to experiment with other research on retrieval-augmented models that also involve training the retriever (e.g. [REALM](https://arxiv.org/abs/2002.08909) and [MARGE](https://arxiv.org/abs/2006.15020)).

To start training, use the bash script (finetune_rag_ray_end2end.sh) in this folder. The script also documents each command-line argument it uses.

# Testing

The following two bash scripts can be used to quickly test the implementation.
1. sh ./test_run/test_rag_new_features.sh
   - Tests the newly added functions (set_context_encoder and set_context_encoder_tokenizer) related to modeling RAG (see the sketch after this list).
   - This is sufficient to check the model's ability to use the set functions correctly.
2. sh ./test_run/test_finetune.sh
   - Tests the full end-to-end fine-tuning ability with a dummy knowledge-base and dummy training dataset (check the test_dir directory).
   - Users can replace the dummy dataset and knowledge-base with their own to run their own fine-tuning.
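The sketch below is illustrative only and is not part of the test scripts. It shows the two new setters as they are actually called in finetune_rag.py (`set_ctx_encoder_tokenizer` on the retriever and `set_context_encoder_for_training` on the RAG model); the checkpoint names are placeholders taken from the RAG documentation.

```python
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    RagRetriever,
    RagTokenForGeneration,
)

# Retriever, plus the tokenizer it needs to re-encode passages during training.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
retriever.set_ctx_encoder_tokenizer(ctx_tokenizer)

# RAG model with a trainable DPR context encoder attached to it.
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
model.set_context_encoder_for_training(ctx_encoder)
```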

# Comparison of end2end RAG (including DPR finetuning) vs. original RAG

We conducted a simple experiment to investigate the effectiveness of this end2end training extension using the SQuAD dataset. Please execute the following steps to reproduce the results.

- Create a knowledge-base using all the context passages in the SQuAD dataset with their respective titles (a minimal sketch follows this list).
- Use the question-answer pairs as training data.
- Train the system for 10 epochs.
- Test the Exact Match (EM) score with the SQuAD dataset's validation set.
- The training dataset, the knowledge-base, and the hyperparameters used in the experiments can be accessed from [here](https://drive.google.com/drive/folders/1qyzV-PaEARWvaU_jjpnU_NUS3U_dSjtG?usp=sharing).
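The sketch below is a minimal, illustrative way to build such a knowledge-base file. It assumes the csv passed via `--csv_path` uses tab-separated `title` and `text` columns, as in the original `examples/rag/use_own_knowledge_dataset.py`; check kb_encode_utils.py and the dummy csv in test_run/dummy-kb for the exact layout this project expects.

```python
import csv

from datasets import load_dataset

squad = load_dataset("squad", split="train")

# Many questions share the same passage, so keep each (title, context) pair once.
passages = sorted({(example["title"], example["context"]) for example in squad})

with open("squad_kb.csv", "w", newline="") as f:
    writer = csv.writer(f, delimiter="\t")
    for title, text in passages:
        writer.writerow([title, text])
```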

# Results

- We trained both models for 10 epochs.

| Model Type           | EM-Score |
| -------------------- | -------- |
| RAG-original         | 28.12    |
| RAG-end2end with DPR | 40.02    |

examples/research_projects/rag-end2end-retriever/callbacks_rag.py (new file, 119 lines)
@@ -0,0 +1,119 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from utils_rag import save_json
|
||||
|
||||
|
||||
def count_trainable_parameters(model):
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
return params
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_checkpoint_callback(output_dir, metric):
|
||||
"""Saves the best model by validation EM score."""
|
||||
if metric == "rouge2":
|
||||
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
||||
elif metric == "bleu":
|
||||
exp = "{val_avg_bleu:.4f}-{step_count}"
|
||||
elif metric == "em":
|
||||
exp = "{val_avg_em:.4f}-{step_count}"
|
||||
elif metric == "loss":
|
||||
exp = "{val_avg_loss:.4f}-{step_count}"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=output_dir,
|
||||
filename=exp,
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
save_top_k=1,
|
||||
every_n_val_epochs=1, # works only with PL > 1.3
|
||||
)
|
||||
|
||||
return checkpoint_callback
|
||||
|
||||
|
||||
def get_early_stopping_callback(metric, patience):
|
||||
return EarlyStopping(
|
||||
monitor=f"val_{metric}", # does this need avg?
|
||||
mode="min" if "loss" in metric else "max",
|
||||
patience=patience,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
class Seq2SeqLoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
@rank_zero_only
|
||||
def _write_logs(
|
||||
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
||||
) -> None:
|
||||
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
||||
metrics = trainer.callback_metrics
|
||||
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
|
||||
# Log results
|
||||
od = Path(pl_module.hparams.output_dir)
|
||||
if type_path == "test":
|
||||
results_file = od / "test_results.txt"
|
||||
generations_file = od / "test_generations.txt"
|
||||
else:
|
||||
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
|
||||
# If people want this it will be easy enough to add back.
|
||||
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
|
||||
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
|
||||
results_file.parent.mkdir(exist_ok=True)
|
||||
generations_file.parent.mkdir(exist_ok=True)
|
||||
with open(results_file, "a+") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key in ["log", "progress_bar", "preds"]:
|
||||
continue
|
||||
val = metrics[key]
|
||||
if isinstance(val, torch.Tensor):
|
||||
val = val.item()
|
||||
msg = f"{key}: {val:.6f}\n"
|
||||
writer.write(msg)
|
||||
|
||||
if not save_generations:
|
||||
return
|
||||
|
||||
if "preds" in metrics:
|
||||
content = "\n".join(metrics["preds"])
|
||||
generations_file.open("w+").write(content)
|
||||
|
||||
@rank_zero_only
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
try:
|
||||
npars = pl_module.model.model.num_parameters()
|
||||
except AttributeError:
|
||||
npars = pl_module.model.num_parameters()
|
||||
|
||||
n_trainable_pars = count_trainable_parameters(pl_module)
|
||||
# mp stands for million parameters
|
||||
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
||||
|
||||
@rank_zero_only
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
return self._write_logs(trainer, pl_module, "test")
|
||||
|
||||
@rank_zero_only
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module):
|
||||
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
||||
# Uncommenting this will save val generations
|
||||
# return self._write_logs(trainer, pl_module, "valid")
|
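# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this file): finetune_rag.py builds
# these callbacks roughly as follows and hands them to its Lightning trainer
# through generic_train(); "em" and the output directory are placeholders.
#
#     logging_callback = Seq2SeqLoggingCallback()
#     checkpoint_callback = get_checkpoint_callback("outputs", "em")
#     early_stopping_callback = get_early_stopping_callback("em", patience=3)
# ---------------------------------------------------------------------------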
examples/research_projects/rag-end2end-retriever/distributed_ray_retriever.py (new file, 185 lines)
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import ray
|
||||
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayRetriever:
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
|
||||
def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
|
||||
if not self.initialized:
|
||||
self.retriever = RagRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
init_retrieval=False,
|
||||
)
|
||||
self.initialized = True
|
||||
|
||||
def init_retrieval(self):
|
||||
self.retriever.index.init_index()
|
||||
|
||||
def clear_object(self):
|
||||
# delete the old self.retriever object before assigning the new index
|
||||
del self.retriever
|
||||
self.initialized = False
|
||||
|
||||
def retrieve(self, question_hidden_states, n_docs):
|
||||
doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
|
||||
doc_dicts = self.retriever.index.get_doc_dicts(doc_ids)
|
||||
return doc_ids, retrieved_doc_embeds, doc_dicts
|
||||
|
||||
|
||||
class RagRayDistributedRetriever(RagRetriever):
|
||||
"""
|
||||
A distributed retriever built on top of the ``Ray`` API, a library
|
||||
for building distributed applications (https://docs.ray.io/en/master/).
|
||||
During training, all training workers initialize their own
|
||||
instance of a `RagRayDistributedRetriever`, and each instance of
|
||||
this distributed retriever shares a common set of Retrieval Ray
|
||||
Actors (https://docs.ray.io/en/master/walkthrough.html#remote-classes-actors)
|
||||
that load the index on separate processes. Ray
|
||||
handles the communication between the `RagRayDistributedRetriever`
|
||||
instances and the remote Ray actors. If training is done in a
|
||||
non-distributed setup, the index will simply be loaded in the same
|
||||
process as the training worker and Ray will not be used.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.RagConfig`):
|
||||
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
|
||||
question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
|
||||
The tokenizer that was used to tokenize the question.
|
||||
It is used to decode the question so that it can be re-encoded, together with the retrieved passages, by the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
|
||||
These actor classes run on remote processes and are responsible for performing the index lookup.
|
||||
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
|
||||
if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
|
||||
raise ValueError(
|
||||
"When using Ray for distributed fine-tuning, "
|
||||
"you'll need to provide the paths instead, "
|
||||
"as the dataset and the index are loaded "
|
||||
"separately. More info in examples/rag/use_own_knowledge_dataset.py "
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
init_retrieval=False,
|
||||
)
|
||||
|
||||
self.retrieval_workers = retrieval_workers
|
||||
self.question_encoder_tokenizer = question_encoder_tokenizer
|
||||
self.generator_tokenizer = generator_tokenizer
|
||||
if len(self.retrieval_workers) > 0:
|
||||
ray.get(
|
||||
[
|
||||
worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
|
||||
for worker in self.retrieval_workers
|
||||
]
|
||||
)
|
||||
|
||||
def init_retrieval(self):
|
||||
"""
|
||||
Retriever initialization function, needs to be called from the
|
||||
training process. This function triggers retrieval initialization
|
||||
for all retrieval actors if using distributed setting, or loads
|
||||
index into current process if training is not distributed.
|
||||
"""
|
||||
logger.info("initializing retrieval")
|
||||
|
||||
if len(self.retrieval_workers) > 0:
|
||||
ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
|
||||
else:
|
||||
# Non-distributed training. Load index into this same process.
|
||||
self.index.init_index()
|
||||
|
||||
def retrieve(self, question_hidden_states, n_docs):
|
||||
"""
|
||||
Retrieves documents for specified ``question_hidden_states``. If
|
||||
running training with multiple workers, a random retrieval actor is
|
||||
selected to perform the index lookup and return the result.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
|
||||
A batch of query vectors to retrieve with.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Output:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
|
||||
The retrieval embeddings of the retrieved docs per query.
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
|
||||
The ids of the documents in the index
|
||||
doc_dicts (:obj:`List[dict]`):
|
||||
The retrieved document dictionaries (title, text, etc.) per query.
|
||||
"""
|
||||
if len(self.retrieval_workers) > 0:
|
||||
# Select a random retrieval actor.
|
||||
random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
|
||||
doc_ids, retrieved_doc_embeds, doc_dicts = ray.get(
|
||||
random_worker.retrieve.remote(question_hidden_states, n_docs)
|
||||
)
|
||||
else:
|
||||
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
|
||||
doc_dicts = self.index.get_doc_dicts(doc_ids)
|
||||
return retrieved_doc_embeds, doc_ids, doc_dicts
|
||||
|
||||
@classmethod
|
||||
def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
|
||||
return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
|
||||
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
||||
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
|
||||
question_encoder_tokenizer = rag_tokenizer.question_encoder
|
||||
generator_tokenizer = rag_tokenizer.generator
|
||||
|
||||
if indexed_dataset is not None:
|
||||
config.index_name = "custom"
|
||||
index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
|
||||
else:
|
||||
index = cls._build_index(config)
|
||||
|
||||
return cls(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
retrieval_workers=actor_handles,
|
||||
index=index,
|
||||
)
|
||||
|
||||
def re_load(self):
|
||||
|
||||
logger.info("re-loading the new dataset with embeddings")
|
||||
# access from the training loop
|
||||
|
||||
ray.get([worker.clear_object.remote() for worker in self.retrieval_workers])
|
||||
|
||||
# build the index object again
|
||||
index = self._build_index(self.config)
|
||||
|
||||
ray.get(
|
||||
[
|
||||
worker.create_rag_retriever.remote(
|
||||
self.config, self.question_encoder_tokenizer, self.generator_tokenizer, index
|
||||
)
|
||||
for worker in self.retrieval_workers
|
||||
]
|
||||
)
|
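# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this file). It mirrors how
# finetune_rag.py wires Ray actors into this retriever; the checkpoint name
# and the number of workers are placeholders.
#
#     import ray
#     from transformers import RagConfig
#
#     ray.init(address="auto")  # connect to a running Ray cluster
#     remote_cls = ray.remote(RayRetriever)
#     actor_handles = [
#         remote_cls.options(name="retrieval_worker_{}".format(i)).remote() for i in range(2)
#     ]
#     config = RagConfig.from_pretrained("facebook/rag-token-nq")
#     retriever = RagRayDistributedRetriever.from_pretrained(
#         "facebook/rag-token-nq", actor_handles, config=config
#     )
#     retriever.init_retrieval()  # loads the index on every retrieval actor
# ---------------------------------------------------------------------------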
examples/research_projects/rag-end2end-retriever/eval_rag.py (new file, 312 lines)
@@ -0,0 +1,312 @@
|
||||
""" Evaluation script for RAG models."""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import BartForConditionalGeneration, RagRetriever, RagSequenceForGeneration, RagTokenForGeneration
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # isort:skip
|
||||
from utils_rag import exact_match_score, f1_score # noqa: E402 # isort:skip
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
def infer_model_type(model_name_or_path):
|
||||
if "token" in model_name_or_path:
|
||||
return "rag_token"
|
||||
if "sequence" in model_name_or_path:
|
||||
return "rag_sequence"
|
||||
if "bart" in model_name_or_path:
|
||||
return "bart"
|
||||
return None
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
return max(metric_fn(prediction, gt) for gt in ground_truths)
|
||||
|
||||
|
||||
def get_scores(args, preds_path, gold_data_path):
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
answers = []
|
||||
|
||||
if args.gold_data_mode == "qa":
|
||||
data = pd.read_csv(gold_data_path, sep="\t", header=None)
|
||||
for answer_list in data[1]:
|
||||
ground_truths = ast.literal_eval(answer_list)
|
||||
answers.append(ground_truths)
|
||||
else:
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
answers = [[reference] for reference in references]
|
||||
|
||||
f1 = em = total = 0
|
||||
for prediction, ground_truths in zip(hypos, answers):
|
||||
total += 1
|
||||
em += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
|
||||
|
||||
em = 100.0 * em / total
|
||||
f1 = 100.0 * f1 / total
|
||||
|
||||
logger.info(f"F1: {f1:.2f}")
|
||||
logger.info(f"EM: {em:.2f}")
|
||||
|
||||
|
||||
def get_precision_at_k(args, preds_path, gold_data_path):
|
||||
k = args.k
|
||||
hypos = [line.strip() for line in open(preds_path, "r").readlines()]
|
||||
references = [line.strip() for line in open(gold_data_path, "r").readlines()]
|
||||
|
||||
em = total = 0
|
||||
for hypo, reference in zip(hypos, references):
|
||||
hypo_provenance = set(hypo.split("\t")[:k])
|
||||
ref_provenance = set(reference.split("\t"))
|
||||
total += 1
|
||||
em += len(hypo_provenance & ref_provenance) / k
|
||||
|
||||
em = 100.0 * em / total
|
||||
logger.info(f"Precision@{k}: {em: .2f}")
|
||||
|
||||
|
||||
def evaluate_batch_retrieval(args, rag_model, questions):
|
||||
def strip_title(title):
|
||||
if title.startswith('"'):
|
||||
title = title[1:]
|
||||
if title.endswith('"'):
|
||||
title = title[:-1]
|
||||
return title
|
||||
|
||||
retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
)["input_ids"].to(args.device)
|
||||
|
||||
question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
|
||||
question_enc_pool_output = question_enc_outputs[0]
|
||||
|
||||
result = rag_model.retriever(
|
||||
retriever_input_ids,
|
||||
question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
|
||||
prefix=rag_model.rag.generator.config.prefix,
|
||||
n_docs=rag_model.config.n_docs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
|
||||
provenance_strings = []
|
||||
for docs in all_docs:
|
||||
provenance = [strip_title(title) for title in docs["title"]]
|
||||
provenance_strings.append("\t".join(provenance))
|
||||
return provenance_strings
|
||||
|
||||
|
||||
def evaluate_batch_e2e(args, rag_model, questions):
|
||||
with torch.no_grad():
|
||||
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
|
||||
questions, return_tensors="pt", padding=True, truncation=True
|
||||
)
|
||||
|
||||
input_ids = inputs_dict.input_ids.to(args.device)
|
||||
attention_mask = inputs_dict.attention_mask.to(args.device)
|
||||
outputs = rag_model.generate( # rag_model overwrites generate
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=args.num_beams,
|
||||
min_length=args.min_length,
|
||||
max_length=args.max_length,
|
||||
early_stopping=False,
|
||||
num_return_sequences=1,
|
||||
bad_words_ids=[[0, 0]],  # BART likes to repeat BOS tokens, don't allow it to generate more than one
|
||||
)
|
||||
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
if args.print_predictions:
|
||||
for q, a in zip(questions, answers):
|
||||
logger.info("Q: {} - A: {}".format(q, a))
|
||||
|
||||
return answers
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart"],
|
||||
type=str,
|
||||
help="RAG model type: rag_sequence, rag_token or bart, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
default=None,
|
||||
choices=["exact", "compressed", "legacy"],
|
||||
type=str,
|
||||
help="RAG model retriever type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the retrieval index",
|
||||
)
|
||||
parser.add_argument("--n_docs", default=5, type=int, help="Number of retrieved docs")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained checkpoints or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_mode",
|
||||
choices=["e2e", "retrieval"],
|
||||
default="e2e",
|
||||
type=str,
|
||||
help="Evaluation mode, e2e calculates exact match and F1 of the downstream task, retrieval calculates precision@k.",
|
||||
)
|
||||
parser.add_argument("--k", default=1, type=int, help="k for the precision@k calculation")
|
||||
parser.add_argument(
|
||||
"--evaluation_set",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a file containing evaluation samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a tab-separated file with gold samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gold_data_mode",
|
||||
default="qa",
|
||||
type=str,
|
||||
choices=["qa", "ans"],
|
||||
help="Format of the gold data file"
|
||||
"qa - a single line in the following format: question [tab] answer_list"
|
||||
"ans - a single line of the gold file contains the expected answer string",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predictions_path",
|
||||
type=str,
|
||||
default="predictions.txt",
|
||||
help="Name of the predictions file, to be stored in the checkpoints directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_all_checkpoints",
|
||||
action="store_true",
|
||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recalculate",
|
||||
help="Recalculate predictions even if the prediction file exists",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_beams",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Number of beams to be used when generating answers",
|
||||
)
|
||||
parser.add_argument("--min_length", default=1, type=int, help="Min length of the generated answers")
|
||||
parser.add_argument("--max_length", default=50, type=int, help="Max length of the generated answers")
|
||||
|
||||
parser.add_argument(
|
||||
"--print_predictions",
|
||||
action="store_true",
|
||||
help="If True, prints predictions while evaluating.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_docs",
|
||||
action="store_true",
|
||||
help="If True, prints docs retried while generating.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
model_kwargs = {}
|
||||
if args.model_type is None:
|
||||
args.model_type = infer_model_type(args.model_name_or_path)
|
||||
assert args.model_type is not None
|
||||
if args.model_type.startswith("rag"):
|
||||
model_class = RagTokenForGeneration if args.model_type == "rag_token" else RagSequenceForGeneration
|
||||
model_kwargs["n_docs"] = args.n_docs
|
||||
if args.index_name is not None:
|
||||
model_kwargs["index_name"] = args.index_name
|
||||
if args.index_path is not None:
|
||||
model_kwargs["index_path"] = args.index_path
|
||||
else:
|
||||
model_class = BartForConditionalGeneration
|
||||
|
||||
checkpoints = (
|
||||
[f.path for f in os.scandir(args.model_name_or_path) if f.is_dir()]
|
||||
if args.eval_all_checkpoints
|
||||
else [args.model_name_or_path]
|
||||
)
|
||||
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
|
||||
score_fn = get_scores if args.eval_mode == "e2e" else get_precision_at_k
|
||||
evaluate_batch_fn = evaluate_batch_e2e if args.eval_mode == "e2e" else evaluate_batch_retrieval
|
||||
|
||||
for checkpoint in checkpoints:
|
||||
if os.path.exists(args.predictions_path) and (not args.recalculate):
|
||||
logger.info("Calculating metrics based on an existing predictions file: {}".format(args.predictions_path))
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
continue
|
||||
|
||||
logger.info("***** Running evaluation for {} *****".format(checkpoint))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
logger.info(" Predictions will be stored under {}".format(args.predictions_path))
|
||||
|
||||
if args.model_type.startswith("rag"):
|
||||
retriever = RagRetriever.from_pretrained(checkpoint, **model_kwargs)
|
||||
model = model_class.from_pretrained(checkpoint, retriever=retriever, **model_kwargs)
|
||||
model.retriever.init_retrieval()
|
||||
else:
|
||||
model = model_class.from_pretrained(checkpoint, **model_kwargs)
|
||||
model.to(args.device)
|
||||
|
||||
with open(args.evaluation_set, "r") as eval_file, open(args.predictions_path, "w") as preds_file:
|
||||
questions = []
|
||||
for line in tqdm(eval_file):
|
||||
questions.append(line.strip())
|
||||
if len(questions) == args.eval_batch_size:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers) + "\n")
|
||||
preds_file.flush()
|
||||
questions = []
|
||||
if len(questions) > 0:
|
||||
answers = evaluate_batch_fn(args, model, questions)
|
||||
preds_file.write("\n".join(answers))
|
||||
preds_file.flush()
|
||||
|
||||
score_fn(args, args.predictions_path, args.gold_data_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
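# ---------------------------------------------------------------------------
# Illustrative invocation (not part of the file); every flag is defined in
# get_args() above and the paths are placeholders.
#
#     python eval_rag.py \
#         --model_name_or_path path/to/rag_checkpoint \
#         --model_type rag_token \
#         --eval_mode e2e \
#         --gold_data_mode qa \
#         --evaluation_set path/to/questions.txt \
#         --gold_data_path path/to/gold.tsv \
#         --predictions_path path/to/predictions.txt \
#         --n_docs 5 --num_beams 4
#
# With --gold_data_mode qa, each gold line is "question<TAB>['answer 1', ...]";
# with ans, each gold line is just the expected answer string.
# ---------------------------------------------------------------------------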
|
examples/research_projects/rag-end2end-retriever/finetune_rag.py (new file, 789 lines)
@@ -0,0 +1,789 @@
|
||||
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from datasets import concatenate_datasets, load_from_disk
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
BartForConditionalGeneration,
|
||||
BatchEncoding,
|
||||
DPRConfig,
|
||||
DPRContextEncoder,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
RagConfig,
|
||||
RagSequenceForGeneration,
|
||||
RagTokenForGeneration,
|
||||
RagTokenizer,
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers import logging as transformers_logging
|
||||
from transformers.integrations import is_ray_available
|
||||
|
||||
|
||||
if is_ray_available():
|
||||
import ray
|
||||
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
|
||||
|
||||
from glob import glob
|
||||
|
||||
from callbacks_rag import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from kb_encode_utils import add_index, embed_update
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from pynvml import nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit
|
||||
from utils_rag import (
|
||||
Seq2SeqDataset,
|
||||
calculate_exact_match,
|
||||
get_git_info,
|
||||
is_rag_model,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
set_extra_model_params,
|
||||
)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
transformers_logging.set_verbosity_info()
|
||||
|
||||
|
||||
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||
isEmUpdateBusy = False
|
||||
isAddIndexBusy = False
|
||||
processes = []
|
||||
threadHandle_index = None
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
class GenerativeQAModule(BaseTransformer):
|
||||
mode = "generative_qa"
|
||||
loss_names = ["loss"]
|
||||
metric_names = ["em"]
|
||||
val_metric = "em"
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
# when loading from a pytorch lightning checkpoint, hparams are passed as dict
|
||||
if isinstance(hparams, dict):
|
||||
hparams = AttrDict(hparams)
|
||||
if hparams.model_type == "rag_sequence":
|
||||
self.model_class = RagSequenceForGeneration
|
||||
elif hparams.model_type == "rag_token":
|
||||
self.model_class = RagTokenForGeneration
|
||||
elif hparams.model_type == "bart":
|
||||
self.model_class = BartForConditionalGeneration
|
||||
else:
|
||||
self.model_class = T5ForConditionalGeneration
|
||||
self.is_rag_model = is_rag_model(hparams.model_type)
|
||||
|
||||
config_class = RagConfig if self.is_rag_model else AutoConfig
|
||||
config = config_class.from_pretrained(hparams.model_name_or_path)
|
||||
|
||||
# set retriever parameters
|
||||
config.index_name = hparams.index_name or config.index_name
|
||||
config.passages_path = hparams.passages_path or config.passages_path
|
||||
config.index_path = hparams.index_path or config.index_path
|
||||
config.use_dummy_dataset = hparams.use_dummy_dataset
|
||||
|
||||
# set extra_model_params for generator configs and load_model
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
|
||||
if self.is_rag_model:
|
||||
if hparams.prefix is not None:
|
||||
config.generator.prefix = hparams.prefix
|
||||
config.label_smoothing = hparams.label_smoothing
|
||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||
if hparams.distributed_retriever == "ray":
|
||||
# The Ray retriever needs the handles to the retriever actors.
|
||||
retriever = RagRayDistributedRetriever.from_pretrained(
|
||||
hparams.model_name_or_path, hparams.actor_handles, config=config
|
||||
)
|
||||
|
||||
if hparams.end2end:
|
||||
ctx_encoder_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
|
||||
"facebook/dpr-ctx_encoder-multiset-base"
|
||||
)
|
||||
retriever.set_ctx_encoder_tokenizer(ctx_encoder_tokenizer)
|
||||
else:
|
||||
logger.info("please use RAY as the distributed retrieval method")
|
||||
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||
if hparams.end2end:
|
||||
ctx_encoder = DPRContextEncoder.from_pretrained(hparams.context_encoder_name)
|
||||
model.set_context_encoder_for_training(ctx_encoder)
|
||||
prefix = config.question_encoder.prefix
|
||||
else:
|
||||
if hparams.prefix is not None:
|
||||
config.prefix = hparams.prefix
|
||||
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
prefix = config.prefix
|
||||
|
||||
tokenizer = (
|
||||
RagTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
if self.is_rag_model
|
||||
else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
|
||||
)
|
||||
|
||||
self.config_dpr = DPRConfig.from_pretrained(hparams.context_encoder_name)
|
||||
self.custom_config = hparams
|
||||
self.context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(hparams.context_encoder_name)
|
||||
|
||||
super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)
|
||||
|
||||
save_git_info(self.hparams.output_dir)
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
self.dpr_ctx_check_dir = str(Path(self.hparams.output_dir)) + "/dpr_ctx_checkpoint"
|
||||
self.metrics_save_path = Path(self.output_dir) / "metrics.json"
|
||||
self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
|
||||
pickle_save(self.hparams, self.hparams_save_path)
|
||||
self.step_count = 0
|
||||
self.metrics = defaultdict(list)
|
||||
|
||||
self.dataset_kwargs: dict = dict(
|
||||
data_dir=self.hparams.data_dir,
|
||||
max_source_length=self.hparams.max_source_length,
|
||||
prefix=prefix or "",
|
||||
)
|
||||
n_observations_per_split = {
|
||||
"train": self.hparams.n_train,
|
||||
"val": self.hparams.n_val,
|
||||
"test": self.hparams.n_test,
|
||||
}
|
||||
self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
|
||||
self.target_lens = {
|
||||
"train": self.hparams.max_target_length,
|
||||
"val": self.hparams.val_max_target_length,
|
||||
"test": self.hparams.test_max_target_length,
|
||||
}
|
||||
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
|
||||
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
|
||||
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.distributed_port = self.hparams.distributed_port
|
||||
|
||||
# For single GPU training, init_ddp_connection is not called.
|
||||
# So we need to initialize the retrievers here.
|
||||
if hparams.gpus <= 1:
|
||||
if hparams.distributed_retriever == "ray":
|
||||
self.model.retriever.init_retrieval()
|
||||
else:
|
||||
logger.info("please use RAY as the distributed retrieval method")
|
||||
|
||||
self.distributed_retriever = hparams.distributed_retriever
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
|
||||
def ids_to_clean_text(self, generated_ids: List[int]):
|
||||
gen_text = self.tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
return lmap(str.strip, gen_text)
|
||||
|
||||
def _step(self, batch: dict) -> Tuple:
|
||||
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
|
||||
rag_kwargs = {}
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(target_ids)
|
||||
lm_labels = target_ids
|
||||
elif isinstance(self.model, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous()
|
||||
lm_labels = target_ids[:, 1:].clone()
|
||||
else:
|
||||
assert self.is_rag_model
|
||||
generator = self.model.rag.generator
|
||||
if isinstance(generator, T5ForConditionalGeneration):
|
||||
decoder_start_token_id = generator.config.decoder_start_token_id
|
||||
decoder_input_ids = (
|
||||
torch.cat(
|
||||
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
|
||||
dim=1,
|
||||
)
|
||||
if target_ids.shape[0] < self.target_lens["train"]
|
||||
else generator._shift_right(target_ids)
|
||||
)
|
||||
elif isinstance(generator, BartForConditionalGeneration):
|
||||
decoder_input_ids = target_ids
|
||||
lm_labels = decoder_input_ids
|
||||
rag_kwargs["reduce_loss"] = True
|
||||
|
||||
assert decoder_input_ids is not None
|
||||
|
||||
outputs = self(
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
use_cache=False,
|
||||
labels=lm_labels,
|
||||
**rag_kwargs,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
return (loss,)
|
||||
|
||||
@property
|
||||
def pad(self) -> int:
|
||||
raise NotImplementedError("pad not implemented")
|
||||
|
||||
def training_step(self, batch, batch_idx) -> Dict:
|
||||
|
||||
global isEmUpdateBusy  # used to check whether the entire embedding update process is finished or not
|
||||
global isAddIndexBusy  # used to check whether the entire indexing process is finished or not
|
||||
global processes  # used to keep the embedding-update worker processes
|
||||
global threadHandle_index  # used to keep the handle of the index-building process
|
||||
|
||||
if (self.trainer.global_rank == 0) and (self.custom_config.end2end):
|
||||
|
||||
if (not batch_idx == 0) and (batch_idx % self.custom_config.indexing_freq == 0):
|
||||
free_gpu_list = []
|
||||
nvmlInit()
|
||||
deviceCount = nvmlDeviceGetCount()
|
||||
|
||||
my_list = json.loads(self.custom_config.gpu_order)
|
||||
|
||||
for i in range(deviceCount):
|
||||
handle = nvmlDeviceGetHandleByIndex(i)
|
||||
info = nvmlDeviceGetMemoryInfo(handle)
|
||||
|
||||
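# consider a GPU free if it is currently using less than ~15 MB of memory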
if info.used / 1e6 < 15:
|
||||
position = my_list.index(i)
|
||||
free_gpu_list.append("cuda:" + str(position))
|
||||
|
||||
if len(free_gpu_list) >= self.custom_config.index_gpus:
|
||||
has_free_gpus = True
|
||||
|
||||
else:
|
||||
has_free_gpus = False
|
||||
|
||||
if (not isEmUpdateBusy) and has_free_gpus:
|
||||
|
||||
model_copy = type(self.model.rag.ctx_encoder)(
|
||||
self.config_dpr
|
||||
)  # get a new instance; this will be loaded on the CPU
|
||||
model_copy.load_state_dict(self.model.rag.ctx_encoder.state_dict()) # copy weights
|
||||
|
||||
processes = []
|
||||
|
||||
if len(free_gpu_list) > self.custom_config.index_gpus:
|
||||
cuda_devices = random.sample(free_gpu_list, self.custom_config.index_gpus)
|
||||
else:
|
||||
cuda_devices = free_gpu_list
|
||||
|
||||
num_processes = len(cuda_devices)
|
||||
|
||||
for rank in range(num_processes):
|
||||
logger.info("Iniitializing embedding calculation process rank{}".format(rank))
|
||||
device = cuda_devices[rank]
|
||||
p = multiprocessing.Process(
|
||||
target=embed_update,
|
||||
args=(
|
||||
copy.deepcopy(model_copy),
|
||||
num_processes,
|
||||
device,
|
||||
rank,
|
||||
self.custom_config.shard_dir,
|
||||
self.custom_config.csv_path,
|
||||
),
|
||||
)
|
||||
processes.append(p)
|
||||
|
||||
for p in processes:
|
||||
p.start()
|
||||
|
||||
isEmUpdateBusy = True
|
||||
|
||||
if isEmUpdateBusy and (not isAddIndexBusy):
|
||||
index_process_list = [processes[k].is_alive() for k in range(self.custom_config.index_gpus)]
|
||||
if (
|
||||
sum(index_process_list) == 0
|
||||
):  # If the entire list is False, all embedding calculation processes have finished
|
||||
logger.info("Start adding the index")
|
||||
threadHandle_index = multiprocessing.Process(
|
||||
target=add_index,
|
||||
args=(
|
||||
self.custom_config.shard_dir,
|
||||
self.config.index_path,
|
||||
),
|
||||
)
|
||||
threadHandle_index.start()
|
||||
isAddIndexBusy = True
|
||||
|
||||
# check when index building has started
|
||||
if isAddIndexBusy:
|
||||
|
||||
# check still the index_building process is happening
|
||||
if not threadHandle_index.is_alive():
|
||||
|
||||
logger.info("Merging the dataset shards")
|
||||
saved_dataset_shards = []
|
||||
|
||||
for address in glob(str(self.custom_config.shard_dir) + "/*/"):
|
||||
saved_dataset_shards.append(load_from_disk(address))
|
||||
|
||||
concat = concatenate_datasets(saved_dataset_shards)
|
||||
concat.save_to_disk(self.config.passages_path) # here we update the main passage file on the disk
|
||||
logger.info("done updating the dataset")
|
||||
|
||||
# if you load the index from disk, make sure to update the index file here; otherwise it is fine to update the index file from the worker.
|
||||
# logger.info("then updating the index")
|
||||
# shutil.copy(self.custom_config.temp_index, self.config.idex_path)
|
||||
|
||||
logger.info("Loading new passages and iniitalzing new index")
|
||||
self.trainer.model.module.module.model.rag.retriever.re_load()
|
||||
self.trainer.model.module.module.model.rag.retriever.init_retrieval()
|
||||
|
||||
isEmUpdateBusy = False
|
||||
isAddIndexBusy = False
|
||||
|
||||
self.trainer.accelerator_connector.accelerator.barrier(
|
||||
"barrier"
|
||||
)  # wait until the index and KB get re-initialized.
|
||||
|
||||
loss_tensors = self._step(batch)
|
||||
|
||||
logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
# tokens per batch
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
logs["tpb"] = (
|
||||
batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
|
||||
)
|
||||
self.log("loss", loss_tensors[0])
|
||||
return loss_tensors[0]
|
||||
|
||||
def validation_step(self, batch, batch_idx) -> Dict:
|
||||
return self._generative_step(batch)
|
||||
|
||||
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
|
||||
self.step_count += 1
|
||||
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
|
||||
loss = losses["loss"]
|
||||
gen_metrics = {
|
||||
k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
|
||||
}
|
||||
metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
|
||||
gen_metrics.update({k: v.item() for k, v in losses.items()})
|
||||
|
||||
# fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
|
||||
metrics_tensor = metrics_tensor / dist.get_world_size()
|
||||
gen_metrics.update({self.val_metric: metrics_tensor.item()})
|
||||
|
||||
losses.update(gen_metrics)
|
||||
metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
|
||||
metrics["step_count"] = self.step_count
|
||||
self.save_metrics(metrics, prefix) # writes to self.metrics_save_path
|
||||
|
||||
log_dict = {
|
||||
"val_avg_em": metrics["val_avg_em"],
|
||||
"step_count": metrics["step_count"],
|
||||
"val_avg_loss": metrics["val_avg_loss"],
|
||||
"val_loss": loss,
|
||||
"val_em": metrics_tensor,
|
||||
}
|
||||
self.log_dict(log_dict)
|
||||
|
||||
def save_metrics(self, latest_metrics, type_path) -> None:
|
||||
self.metrics[type_path].append(latest_metrics)
|
||||
save_json(self.metrics, self.metrics_save_path)
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> Dict:
|
||||
return calculate_exact_match(preds, target)
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
start_time = time.time()
|
||||
batch = BatchEncoding(batch).to(device=self.model.device)
|
||||
generated_ids = self.model.generate(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
do_deduplication=False, # rag specific parameter
|
||||
use_cache=True,
|
||||
min_length=1,
|
||||
max_length=self.target_lens["val"],
|
||||
)
|
||||
gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
|
||||
preds: List[str] = self.ids_to_clean_text(generated_ids)
|
||||
target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
|
||||
# print(preds,target)
|
||||
loss_tensors = self._step(batch)
|
||||
base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
|
||||
gen_metrics: Dict = self.calc_generative_metrics(preds, target)
|
||||
|
||||
summ_len = np.mean(lmap(len, generated_ids))
|
||||
base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
|
||||
return base_metrics
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self._generative_step(batch)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = Seq2SeqDataset(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
max_target_length=max_target_length,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
|
||||
dataset = self.get_dataset(type_path)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate_fn,
|
||||
shuffle=shuffle,
|
||||
num_workers=self.num_workers,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
def train_dataloader(self) -> DataLoader:
|
||||
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
|
||||
return dataloader
|
||||
|
||||
def val_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
def test_dataloader(self) -> DataLoader:
|
||||
return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
|
||||
self.model.config.save_step = self.step_count
|
||||
# self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
if self.custom_config.end2end:
|
||||
|
||||
modified_state_dict = self.model.state_dict()
|
||||
for key in self.model.state_dict().keys():
|
||||
if key.split(".")[1] == "ctx_encoder":
|
||||
del modified_state_dict[key]
|
||||
self.model.save_pretrained(save_directory=save_path, state_dict=modified_state_dict)
|
||||
|
||||
save_path_dpr = os.path.join(self.dpr_ctx_check_dir, "checkpoint{}".format(self.step_count))
|
||||
self.model.rag.ctx_encoder.save_pretrained(save_path_dpr)
|
||||
self.context_tokenizer.save_pretrained(save_path_dpr)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
BaseTransformer.add_model_specific_args(parser, root_dir)
|
||||
add_generic_args(parser, root_dir)
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_max_target_length",
|
||||
default=25,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
|
||||
parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
|
||||
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
|
||||
parser.add_argument(
|
||||
"--prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Prefix added at the beginning of each text, typically used with T5-based models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--early_stopping_patience",
|
||||
type=int,
|
||||
default=-1,
|
||||
required=False,
|
||||
help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
choices=["rag_sequence", "rag_token", "bart", "t5"],
|
||||
type=str,
|
||||
help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_encoder_name",
|
||||
default="facebook/dpr-ctx_encoder-multiset-base",
|
||||
type=str,
|
||||
help="Name of the pre-trained context encoder checkpoint from the DPR",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv_path",
|
||||
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset.csv"),
|
||||
type=str,
|
||||
help="path of the raw KB csv",
|
||||
)
|
||||
parser.add_argument("--end2end", action="store_true", help="whether to train the system end2end or not")
|
||||
parser.add_argument("--index_gpus", type=int, help="how many GPUs used in re-encoding process")
|
||||
parser.add_argument(
|
||||
"--shard_dir",
|
||||
type=str,
|
||||
default=str(Path(__file__).parent / "test_run" / "kb-shards"),
|
||||
help="directory used to keep temporary shards during the re-encode process",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gpu_order",
|
||||
type=str,
|
||||
help="order of the GPU used during the fine-tuning. Used to finding free GPUs during the re-encode process. I do not have many GPUs :)",
|
||||
)
|
||||
|
||||
parser.add_argument("--indexing_freq", type=int, help="frequency of re-encode process")
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def add_retriever_specific_args(parser):
|
||||
parser.add_argument(
|
||||
"--index_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the orignal one)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--passages_path",
|
||||
type=str,
|
||||
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset"),
|
||||
help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index_path",
|
||||
type=str,
|
||||
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset_hnsw_index.faiss"),
|
||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed_retriever",
|
||||
choices=["ray", "pytorch"],
|
||||
type=str,
|
||||
default="ray",
|
||||
help="What implementation to use for distributed retriever? If "
|
||||
"pytorch is selected, the index is loaded on training "
|
||||
"worker 0, and torch.distributed is used to handle "
|
||||
"communication between training worker 0, and the other "
|
||||
"training workers. If ray is selected, the Ray library is "
|
||||
"used to create load the index on separate processes, "
|
||||
"and Ray handles the communication between the training "
|
||||
"workers and the retrieval actors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dummy_dataset",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def add_ray_specific_args(parser):
|
||||
# Ray cluster address.
|
||||
parser.add_argument(
|
||||
"--ray-address",
|
||||
default="auto",
|
||||
type=str,
|
||||
help="The address of the Ray cluster to connect to. If not "
|
||||
"specified, Ray will attempt to automatically detect the "
|
||||
"cluster. Has no effect if pytorch is used as the distributed "
|
||||
"retriever.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_retrieval_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of retrieval actors to use when Ray is selected"
|
||||
"for the distributed retriever. Has no effect when "
|
||||
"distributed_retriever is set to pytorch.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args=None, model=None) -> GenerativeQAModule:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||
args = args or parser.parse_args()
|
||||
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
Path(args.output_dir + "/dpr_ctx_checkpoint").mkdir(
|
||||
exist_ok=True
|
||||
)  # save the DPR context encoder separately for future use
|
||||
print(args.shard_dir)
|
||||
if os.path.exists(args.shard_dir):  # we do not need previous KB shards used in dataset re-encoding and re-indexing
|
||||
shutil.rmtree(args.shard_dir)
|
||||
Path(args.shard_dir).mkdir(exist_ok=True)
|
||||
|
||||
if os.path.exists(
|
||||
args.cache_dir
|
||||
):  # we do not need previous cache files used in dataset re-encoding and re-indexing
|
||||
shutil.rmtree(args.cache_dir)
|
||||
Path(args.cache_dir).mkdir(exist_ok=True)
|
||||
|
||||
named_actors = []
|
||||
if args.distributed_retriever == "ray" and args.gpus > 1:
|
||||
if not is_ray_available():
|
||||
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
|
||||
# Connect to an existing Ray cluster.
|
||||
try:
|
||||
ray.init(address=args.ray_address)
|
||||
except (ConnectionError, ValueError):
|
||||
logger.warning(
|
||||
"Connection to Ray cluster failed. Make sure a Ray"
|
||||
"cluster is running by either using Ray's cluster "
|
||||
"launcher (`ray up`) or by manually starting Ray on "
|
||||
"each node via `ray start --head` for the head node "
|
||||
"and `ray start --address='<ip address>:6379'` for "
|
||||
"additional nodes. See "
|
||||
"https://docs.ray.io/en/master/cluster/index.html "
|
||||
"for more info."
|
||||
)
|
||||
raise
|
||||
|
||||
# Create Ray actors only for rank 0.
|
||||
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and (
|
||||
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0
|
||||
):
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
named_actors = [
|
||||
remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
|
||||
for i in range(args.num_retrieval_workers)
|
||||
]
|
||||
else:
|
||||
logger.info(
|
||||
"Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(
|
||||
os.environ["NODE_RANK"], os.environ["LOCAL_RANK"]
|
||||
)
|
||||
)
|
||||
named_actors = [ray.get_actor("retrieval_worker_{}".format(i)) for i in range(args.num_retrieval_workers)]
|
||||
args.actor_handles = named_actors
|
||||
assert args.actor_handles == named_actors
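# every rank now holds handles to the same named retrieval actors, so retrieval
# requests from any training worker go to the shared Ray retrieval workers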
|
||||
|
||||
if model is None:
|
||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger_name == "default"
|
||||
or args.fast_dev_run
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
training_logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
training_logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
es_callback = (
|
||||
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
if args.early_stopping_patience >= 0
|
||||
else False
|
||||
)
|
||||
|
||||
trainer: pl.Trainer = generic_train(
|
||||
model,
|
||||
args,
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=training_logger,
|
||||
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
|
||||
)
|
||||
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
multiprocessing.set_start_method("spawn")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||
parser = GenerativeQAModule.add_ray_specific_args(parser)
|
||||
|
||||
# Pytorch Lightning Profiler
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
68
examples/research_projects/rag-end2end-retriever/finetune_rag_ray_end2end.sh
Executable file
@ -0,0 +1,68 @@
|
||||
# Sample script to finetune RAG using Ray for distributed retrieval.
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# Creates the custom knowledge base
|
||||
python use_own_knowledge_dataset.py \
|
||||
--csv_path /DIR/SQUAD-KB/squad-kb.csv \
|
||||
--output_dir /DIR/SQUAD-KB
|
||||
|
||||
# Start a single-node Ray cluster.
|
||||
ray start --head
|
||||
|
||||
# A sample finetuning run; you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options
|
||||
|
||||
|
||||
|
||||
python finetune_rag.py \
|
||||
--data_dir /DIR/squad-training-data \
|
||||
--output_dir /DIR/model_checkpoints \
|
||||
--model_name_or_path facebook/rag-token-base \
|
||||
--model_type rag_token \
|
||||
--fp16 \
|
||||
--gpus 2 \
|
||||
--profile \
|
||||
--do_train \
|
||||
--end2end \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--train_batch_size 4 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 10 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--distributed_retriever ray \
|
||||
--num_retrieval_workers 4 \
|
||||
--passages_path /DIR/SQUAD-KB/my_knowledge_dataset \
|
||||
--index_path /DIR/SQUAD-KB/my_knowledge_dataset_hnsw_index.faiss \
|
||||
--index_name custom \
|
||||
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
|
||||
--csv_path /DIR/SQUAD-KB/squad-kb.csv \
|
||||
--index_gpus 1 \
|
||||
--gpu_order [5,6,7,8,9,0,1,2,3,4] \
|
||||
--shard_dir ./test_dir/kb-shards \
|
||||
--indexing_freq 500
|
||||
|
||||
|
||||
|
||||
# Stop the Ray cluster.
|
||||
ray stop
|
||||
|
||||
|
||||
# This script was used to test the SQuAD data.
|
||||
# Change the directory paths according to your preference.
|
||||
# Please use the same device order when running: CUDA_VISIBLE_DEVICES=5,6,7,8,9,0,1,2,3,4 sh finetune_rag_ray_end2end.sh
|
@ -0,0 +1,81 @@
|
||||
import os
|
||||
from functools import partial
|
||||
from glob import glob
|
||||
|
||||
from datasets import Features, Sequence, Value, concatenate_datasets, load_dataset, load_from_disk
|
||||
|
||||
import faiss
|
||||
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
|
||||
|
||||
|
||||
def split_text(text, n=100, character=" "):
|
||||
"""Split the text every ``n``-th occurrence of ``character``"""
|
||||
text = text.split(character)
|
||||
return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
|
||||
|
||||
|
||||
def split_documents(documents):
|
||||
"""Split documents into passages"""
|
||||
titles, texts = [], []
|
||||
for title, text in zip(documents["title"], documents["text"]):
|
||||
if text is not None:
|
||||
for passage in split_text(text):
|
||||
titles.append(title if title is not None else "")
|
||||
texts.append(passage)
|
||||
return {"title": titles, "text": texts}
|
||||
|
||||
|
||||
def embed_update(ctx_encoder, total_processes, device, process_num, shard_dir, csv_path):
|
||||
|
||||
kb_dataset = load_dataset(
|
||||
"csv", data_files=[csv_path], split="train", delimiter="\t", column_names=["title", "text"]
|
||||
)
|
||||
kb_dataset = kb_dataset.map(
|
||||
split_documents, batched=True, num_proc=1
|
||||
) # if you want, you can load an already split csv
|
||||
kb_list = [kb_dataset.shard(total_processes, i, contiguous=True) for i in range(total_processes)]
|
||||
data_shard = kb_list[process_num]
|
||||
|
||||
arrow_folder = "data_" + str(process_num)
|
||||
passages_path = os.path.join(shard_dir, arrow_folder)
|
||||
|
||||
context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
|
||||
ctx_encoder = ctx_encoder.to(device=device)
|
||||
|
||||
def embed(
|
||||
documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast, device
|
||||
) -> dict:
|
||||
"""Compute the DPR embeddings of document passages"""
|
||||
input_ids = ctx_tokenizer(
|
||||
documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
|
||||
)["input_ids"]
|
||||
embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
|
||||
return {"embeddings": embeddings.detach().cpu().numpy()}
|
||||
|
||||
new_features = Features(
|
||||
{"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
|
||||
) # optional, save as float32 instead of float64 to save space
|
||||
|
||||
dataset = data_shard.map(
|
||||
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=context_tokenizer, device=device),
|
||||
batched=True,
|
||||
batch_size=16,
|
||||
features=new_features,
|
||||
)
|
||||
dataset.save_to_disk(passages_path)
|
||||
|
||||
|
||||
def add_index(shard_dir, index_path):
|
||||
data_shard_list = []
|
||||
|
||||
for shard_address in glob(str(shard_dir) + "/*/"):
|
||||
data_shard_list.append(load_from_disk(shard_address))
|
||||
|
||||
concat = concatenate_datasets(data_shard_list)
|
||||
faiss.omp_set_num_threads(96)
|
||||
|
||||
index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
|
||||
concat.add_faiss_index("embeddings", custom_index=index)
|
||||
concat.get_index("embeddings").save(
|
||||
index_path
|
||||
) # since we load the index into memory, we can directly update the index on disk
|
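The two helpers above are meant to be launched once per GPU during re-encoding and then merged into a single index. A minimal driver sketch follows (the `reencode_and_reindex` wrapper and the multiprocessing wiring are illustrative assumptions, not part of this file; `embed_update` and `add_index` are the functions defined above):

import multiprocessing as mp

from transformers import DPRContextEncoder


def reencode_and_reindex(shard_dir, csv_path, index_path, devices):
    # one embedding process per device, each writing its own arrow shard into shard_dir
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
    processes = []
    for rank, device in enumerate(devices):
        p = mp.Process(
            target=embed_update,
            args=(ctx_encoder, len(devices), device, rank, shard_dir, csv_path),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    # merge the shards and write the HNSW index to disk
    add_index(shard_dir, index_path)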
@ -0,0 +1,415 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.plugins.training_type import DDPPlugin
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from transformers import (
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForPreTraining,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
from transformers.optimization import (
|
||||
Adafactor,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
)
|
||||
from transformers.utils.versions import require_version_examples
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
require_version_examples("pytorch_lightning>=1.0.4")
|
||||
|
||||
MODEL_MODES = {
|
||||
"base": AutoModel,
|
||||
"sequence-classification": AutoModelForSequenceClassification,
|
||||
"question-answering": AutoModelForQuestionAnswering,
|
||||
"pretraining": AutoModelForPreTraining,
|
||||
"token-classification": AutoModelForTokenClassification,
|
||||
"language-modeling": AutoModelWithLMHead,
|
||||
"summarization": AutoModelForSeq2SeqLM,
|
||||
"translation": AutoModelForSeq2SeqLM,
|
||||
}
|
||||
|
||||
|
||||
# update this and the import above to support new schedulers from transformers.optimization
|
||||
arg_to_scheduler = {
|
||||
"linear": get_linear_schedule_with_warmup,
|
||||
"cosine": get_cosine_schedule_with_warmup,
|
||||
"cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
|
||||
"polynomial": get_polynomial_decay_schedule_with_warmup,
|
||||
# '': get_constant_schedule, # not supported for now
|
||||
# '': get_constant_schedule_with_warmup, # not supported for now
|
||||
}
|
||||
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
|
||||
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"
|
||||
|
||||
|
||||
class BaseTransformer(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
hparams: argparse.Namespace,
|
||||
num_labels=None,
|
||||
mode="base",
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
# TODO: move to self.save_hyperparameters()
|
||||
# self.save_hyperparameters()
|
||||
# can also expand arguments into trainer signature for easier reading
|
||||
|
||||
self.save_hyperparameters(hparams)
|
||||
self.step_count = 0
|
||||
self.output_dir = Path(self.hparams.output_dir)
|
||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||
if config is None:
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
else:
|
||||
self.config: PretrainedConfig = config
|
||||
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
if getattr(self.hparams, p, None):
|
||||
assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
|
||||
setattr(self.config, p, getattr(self.hparams, p))
|
||||
|
||||
if tokenizer is None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
else:
|
||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||
self.model_type = MODEL_MODES[mode]
|
||||
if model is None:
|
||||
self.model = self.model_type.from_pretrained(
|
||||
self.hparams.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||
config=self.config,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
else:
|
||||
self.model = model
|
||||
|
||||
def load_hf_checkpoint(self, *args, **kwargs):
|
||||
self.model = self.model_type.from_pretrained(*args, **kwargs)
|
||||
|
||||
def get_lr_scheduler(self):
|
||||
get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
|
||||
scheduler = get_schedule_func(
|
||||
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps()
|
||||
)
|
||||
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
|
||||
return scheduler
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""Prepare optimizer and schedule (linear warmup and decay)"""
|
||||
model = self.model
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)
|
||||
], # check these named parameters
|
||||
"weight_decay": self.hparams.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
if self.hparams.adafactor:
|
||||
optimizer = Adafactor(
|
||||
optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
|
||||
)
|
||||
|
||||
else:
|
||||
optimizer = AdamW(
|
||||
optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
|
||||
)
|
||||
self.opt = optimizer
|
||||
|
||||
scheduler = self.get_lr_scheduler()
|
||||
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def test_step(self, batch, batch_nb):
|
||||
return self.validation_step(batch, batch_nb)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_end(outputs)
|
||||
|
||||
def total_steps(self) -> int:
|
||||
"""The number of total training steps that will be run. Used for lr scheduler purposes."""
|
||||
num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores
|
||||
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
||||
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
|
||||
|
||||
def setup(self, stage):
|
||||
if stage == "test":
|
||||
self.dataset_size = len(self.test_dataloader().dataset)
|
||||
else:
|
||||
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||
self.dataset_size = len(self.train_dataloader().dataset)
|
||||
|
||||
def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False):
|
||||
raise NotImplementedError("You must implement this for your task")
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.train_loader
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)
|
||||
|
||||
def _feature_file(self, mode):
|
||||
return os.path.join(
|
||||
self.hparams.data_dir,
|
||||
"cached_{}_{}_{}".format(
|
||||
mode,
|
||||
list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
|
||||
str(self.hparams.max_seq_length),
|
||||
),
|
||||
)
|
||||
|
||||
@pl.utilities.rank_zero_only
|
||||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
||||
save_path = self.output_dir.joinpath("best_tfmr")
|
||||
self.model.config.save_step = self.step_count
|
||||
self.model.save_pretrained(save_path)
|
||||
self.tokenizer.save_pretrained(save_path)
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default=str(Path(__file__).parent / "test_run" / "cache"),
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_layerdrop",
|
||||
type=float,
|
||||
help="Encoder layer dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_layerdrop",
|
||||
type=float,
|
||||
help="Decoder layer dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dropout",
|
||||
type=float,
|
||||
help="Dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention_dropout",
|
||||
type=float,
|
||||
help="Attention dropout probability (Optional). Goes into model.config",
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
default="linear",
|
||||
choices=arg_to_scheduler_choices,
|
||||
metavar=arg_to_scheduler_metavar,
|
||||
type=str,
|
||||
help="Learning rate scheduler",
|
||||
)
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
|
||||
parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
|
||||
parser.add_argument("--train_batch_size", default=32, type=int)
|
||||
parser.add_argument("--eval_batch_size", default=32, type=int)
|
||||
parser.add_argument("--adafactor", action="store_true")
|
||||
|
||||
|
||||
class InitCallback(pl.Callback):
|
||||
# this process can also be done with the PL DDP plugin.
|
||||
# But it is still experimental (check the original RAG example; I updated that with the plugin (shamanez))
|
||||
def on_sanity_check_start(self, trainer, pl_module):
|
||||
if (
|
||||
trainer.is_global_zero and trainer.global_rank == 0
|
||||
): # we initialize the retriever only on the master worker with RAY. In the new pytorch-lightning, accelerators are removed.
|
||||
pl_module.model.rag.retriever.init_retrieval() # better to use hook functions.
|
||||
|
||||
|
||||
class CheckParamCallback(pl.Callback):
|
||||
# check whether newly added model parameters are differentiable
|
||||
def on_after_backward(self, trainer, pl_module):
|
||||
# print(pl_module.model.rag)
|
||||
for name, param in pl_module.model.rag.named_parameters():
|
||||
if param.grad is None:
|
||||
print(name)
|
||||
|
||||
|
||||
class LoggingCallback(pl.Callback):
|
||||
def on_batch_end(self, trainer, pl_module):
|
||||
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
|
||||
lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
|
||||
pl_module.logger.log_metrics(lrs)
|
||||
|
||||
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
rank_zero_info("***** Validation results *****")
|
||||
metrics = trainer.callback_metrics
|
||||
# Log results
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
||||
rank_zero_info("***** Test results *****")
|
||||
metrics = trainer.callback_metrics
|
||||
# Log and save results to file
|
||||
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
|
||||
with open(output_test_results_file, "w") as writer:
|
||||
for key in sorted(metrics):
|
||||
if key not in ["log", "progress_bar"]:
|
||||
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
||||
writer.write("{} = {}\n".format(key, str(metrics[key])))
|
||||
|
||||
|
||||
def add_generic_args(parser, root_dir) -> None:
|
||||
# To allow all pl args uncomment the following line
|
||||
# parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=str(Path(__file__).parent / "test_run" / "model_checkpoints"),
|
||||
type=str,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O2",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
|
||||
parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
dest="accumulate_grad_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=str(Path(__file__).parent / "test_run" / "dummy-train-data"),
|
||||
type=str,
|
||||
help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
|
||||
)
|
||||
|
||||
|
||||
def generic_train(
|
||||
model: BaseTransformer,
|
||||
args: argparse.Namespace,
|
||||
early_stopping_callback=None,
|
||||
logger=True, # can pass WandbLogger() here
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
# init model
|
||||
odir = Path(model.hparams.output_dir)
|
||||
odir.mkdir(exist_ok=True)
|
||||
|
||||
# add custom checkpoints
|
||||
if checkpoint_callback is None:
|
||||
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
||||
filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
|
||||
)
|
||||
if early_stopping_callback:
|
||||
extra_callbacks.append(early_stopping_callback)
|
||||
if logging_callback is None:
|
||||
logging_callback = LoggingCallback()
|
||||
|
||||
train_params = {}
|
||||
|
||||
# TODO: remove with PyTorch 1.6 since pl uses native amp
|
||||
if args.fp16:
|
||||
train_params["precision"] = 16
|
||||
train_params["amp_level"] = args.fp16_opt_level
|
||||
|
||||
if args.gpus > 1:
|
||||
train_params["accelerator"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
# train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
||||
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None)
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
callbacks=[logging_callback] + extra_callbacks + [InitCallback()] + [checkpoint_callback],
|
||||
logger=logger,
|
||||
plugins=[DDPPlugin(find_unused_parameters=True)], # this is needed in the new pytorch-lightning version
|
||||
val_check_interval=1,
|
||||
num_sanity_val_steps=2,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
if args.do_train:
|
||||
trainer.fit(model)
|
||||
|
||||
# else:
|
||||
# print("RAG modeling tests with new set functions successfuly executed!")
|
||||
return trainer
|
@ -0,0 +1,7 @@
|
||||
faiss-cpu >= 1.7.0
|
||||
datasets >= 1.6.2
|
||||
psutil >= 5.7.0
|
||||
torch >= 1.4.0
|
||||
pytorch-lightning == 1.3.1
|
||||
nvidia-ml-py3 == 7.352.0
|
||||
ray >= 1.3.0
|
@ -0,0 +1,2 @@
|
||||
Aaron Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to the Book of Exodus, Aaron first functioned as Moses' assistant. Because Moses complained that he could not speak well, God appointed Aaron as Moses' "prophet" (Exodus 4:10-17; 7:1). At the command of Moses, he let his rod turn into a snake. Then he stretched out his rod in order to bring on the first three plagues. After that, Moses tended to act and speak for himself. During the journey in the wilderness, Aaron was not always prominent or active. At the battle with Amalek, he was chosen with Hur to support the hand of Moses that held the "rod of God". When the revelation was given to Moses at biblical Mount Sinai, he headed the elders of Israel who accompanied Moses on the way to the summit.
|
||||
"Pokémon" Pokémon , also known as in Japan, is a media franchise managed by The Pokémon Company, a Japanese consortium between Nintendo, Game Freak, and Creatures. The franchise copyright is shared by all three companies, but Nintendo is the sole owner of the trademark. The franchise was created by Satoshi Tajiri in 1995, and is centered on fictional creatures called "Pokémon", which humans, known as Pokémon Trainers, catch and train to battle each other for sport. The English slogan for the franchise is "Gotta Catch 'Em All". Works within the franchise are set in the Pokémon universe. The franchise began as "Pokémon Red" and "Green" (released outside of Japan as "Pokémon Red" and "Blue"), a pair of video games for the original Game Boy that were developed by Game Freak and published by Nintendo in February 1996. "Pokémon" has since gone on to become the highest-grossing media franchise of all time, with over in revenue up until March 2017. The original video game series is the second best-selling video game franchise (behind Nintendo's "Mario" franchise) with more than 300million copies sold and over 800million mobile downloads. In addition, the "Pokémon" franchise includes the world's top-selling toy brand, the top-selling trading card game with over 25.7billion cards sold, an anime television series that has become the most successful video game adaptation with over 20 seasons and 1,000 episodes in 124 countries, as well as an anime film series, a , books, manga comics, music, and merchandise. The franchise is also represented in other Nintendo media, such as the "Super Smash Bros." series. In November 2005, 4Kids Entertainment, which had managed the non-game related licensing of "Pokémon", announced that it had agreed not to renew the "Pokémon" representation agreement. The Pokémon Company International oversees all "Pokémon" licensing outside Asia.
|
|
@ -0,0 +1,48 @@
|
||||
What does Moses' rod turn into ?
|
||||
Who is Aron?
|
||||
Where did Moses grow up ?
|
||||
What happens at the command of the Moses ?
|
||||
Who manages the Pokémon ?
|
||||
Who owned the Pokémon trademark ?
|
||||
What else include in Pokémon franchise ?
|
||||
How many seasons in Pokémon animme series ?
|
||||
What does Moses' rod turn into ?
|
||||
Who is Aron?
|
||||
Where did Moses grow up ?
|
||||
What happens at the command of the Moses ?
|
||||
Who manages the Pokémon ?
|
||||
Who owned the Pokémon trademark ?
|
||||
What else include in Pokémon franchise ?
|
||||
How many seasons in Pokémon animme series ?
|
||||
What does Moses' rod turn into ?
|
||||
Who is Aron?
|
||||
Where did Moses grow up ?
|
||||
What happens at the command of the Moses ?
|
||||
Who manages the Pokémon ?
|
||||
Who owned the Pokémon trademark ?
|
||||
What else include in Pokémon franchise ?
|
||||
How many seasons in Pokémon animme series ?
|
||||
What does Moses' rod turn into ?
|
||||
Who is Aron?
|
||||
Where did Moses grow up ?
|
||||
What happens at the command of the Moses ?
|
||||
Who manages the Pokémon ?
|
||||
Who owned the Pokémon trademark ?
|
||||
What else include in Pokémon franchise ?
|
||||
How many seasons in Pokémon animme series ?
|
||||
What does Moses' rod turn into ?
|
||||
Who is Aron?
|
||||
Where did Moses grow up ?
|
||||
What happens at the command of the Moses ?
|
||||
Who manages the Pokémon ?
|
||||
Who owned the Pokémon trademark ?
|
||||
What else include in Pokémon franchise ?
|
||||
How many seasons in Pokémon animme series ?
|
||||
What does Moses' rod turn into ?
|
||||
Who is Aron?
|
||||
Where did Moses grow up ?
|
||||
What happens at the command of the Moses ?
|
||||
Who manages the Pokémon ?
|
||||
Who owned the Pokémon trademark ?
|
||||
What else include in Pokémon franchise ?
|
||||
How many seasons in Pokémon animme series ?
|
@ -0,0 +1,48 @@
|
||||
to a snake
|
||||
Moses' assistant
|
||||
Egyptian royal court
|
||||
let his rod turn in to a snake
|
||||
The Pokémon Company
|
||||
Nintendo
|
||||
world's top-selling toy brand, the top-selling trading card game
|
||||
over 20 seasons
|
||||
to a snake
|
||||
Moses' assistant
|
||||
Egyptian royal court
|
||||
let his rod turn in to a snake
|
||||
The Pokémon Company
|
||||
Nintendo
|
||||
world's top-selling toy brand, the top-selling trading card game
|
||||
over 20 seasons
|
||||
to a snake
|
||||
Moses' assistant
|
||||
Egyptian royal court
|
||||
let his rod turn in to a snake
|
||||
The Pokémon Company
|
||||
Nintendo
|
||||
world's top-selling toy brand, the top-selling trading card game
|
||||
over 20 seasons
|
||||
to a snake
|
||||
Moses' assistant
|
||||
Egyptian royal court
|
||||
let his rod turn in to a snake
|
||||
The Pokémon Company
|
||||
Nintendo
|
||||
world's top-selling toy brand, the top-selling trading card game
|
||||
over 20 seasons
|
||||
to a snake
|
||||
Moses' assistant
|
||||
Egyptian royal court
|
||||
let his rod turn in to a snake
|
||||
The Pokémon Company
|
||||
Nintendo
|
||||
world's top-selling toy brand, the top-selling trading card game
|
||||
over 20 seasons
|
||||
to a snake
|
||||
Moses' assistant
|
||||
Egyptian royal court
|
||||
let his rod turn in to a snake
|
||||
The Pokémon Company
|
||||
Nintendo
|
||||
world's top-selling toy brand, the top-selling trading card game
|
||||
over 20 seasons
|
@ -0,0 +1,8 @@
|
||||
What does Moses' rod turn into ?
|
||||
Who is Aron?
|
||||
Where did Moses grow up ?
|
||||
What happens at the command of the Moses ?
|
||||
Who manages the Pokémon ?
|
||||
Who owned the Pokémon trademark ?
|
||||
What else include in Pokémon franchise ?
|
||||
How many seasons in Pokémon animme series ?
|
@ -0,0 +1,8 @@
|
||||
to a snake
|
||||
Moses' assistant
|
||||
Egyptian royal court
|
||||
let his rod turn in to a snake
|
||||
The Pokémon Company
|
||||
Nintendo
|
||||
world's top-selling toy brand, the top-selling trading card game
|
||||
over 20 seasons
|
54
examples/research_projects/rag-end2end-retriever/test_run/test_finetune.sh
Executable file
@ -0,0 +1,54 @@
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# Creates the custom knowledge base
|
||||
python use_own_knowledge_dataset.py
|
||||
|
||||
|
||||
# Start a single-node Ray cluster.
|
||||
ray start --head
|
||||
|
||||
# A sample finetuning run; you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options
|
||||
|
||||
|
||||
|
||||
python finetune_rag.py \
|
||||
--model_name_or_path facebook/rag-token-base \
|
||||
--model_type rag_token \
|
||||
--fp16 \
|
||||
--gpus 2 \
|
||||
--profile \
|
||||
--do_train \
|
||||
--end2end \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--train_batch_size 1 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 10 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--distributed_retriever ray \
|
||||
--num_retrieval_workers 4 \
|
||||
--index_name custom \
|
||||
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
|
||||
--index_gpus 1 \
|
||||
--gpu_order [6,7,8,9,0,1,2,3,5,4] \
|
||||
--indexing_freq 5
|
||||
|
||||
|
||||
|
||||
# Stop the Ray cluster.
|
||||
ray stop
|
@ -0,0 +1,16 @@
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
python use_own_knowledge_dataset.py
|
||||
|
||||
ray start --head
|
||||
python finetune_rag.py \
|
||||
--model_name_or_path facebook/rag-token-base \
|
||||
--model_type rag_token \
|
||||
--context_encoder_name facebook/dpr-ctx_encoder-multiset-base \
|
||||
--fp16 \
|
||||
--gpus 1 \
|
||||
--profile \
|
||||
--end2end \
|
||||
--index_name custom
|
||||
|
||||
ray stop
|
@ -0,0 +1,171 @@
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Features, Sequence, Value, load_dataset
|
||||
|
||||
import faiss
|
||||
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast, HfArgumentParser
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
torch.set_grad_enabled(False)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def split_text(text: str, n=100, character=" ") -> List[str]:
|
||||
"""Split the text every ``n``-th occurrence of ``character``"""
|
||||
text = text.split(character)
|
||||
return [character.join(text[i : i + n]).strip() for i in range(0, len(text), n)]
|
||||
|
||||
|
||||
def split_documents(documents: dict) -> dict:
|
||||
"""Split documents into passages"""
|
||||
titles, texts = [], []
|
||||
for title, text in zip(documents["title"], documents["text"]):
|
||||
if text is not None:
|
||||
for passage in split_text(text):
|
||||
titles.append(title if title is not None else "")
|
||||
texts.append(passage)
|
||||
return {"title": titles, "text": texts}
|
||||
|
||||
|
||||
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
|
||||
"""Compute the DPR embeddings of document passages"""
|
||||
input_ids = ctx_tokenizer(
|
||||
documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
|
||||
)["input_ids"]
|
||||
embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
|
||||
return {"embeddings": embeddings.detach().cpu().numpy()}
|
||||
|
||||
|
||||
def main(
|
||||
rag_example_args: "RagExampleArguments",
|
||||
processing_args: "ProcessingArguments",
|
||||
index_hnsw_args: "IndexHnswArguments",
|
||||
):
|
||||
|
||||
######################################
|
||||
logger.info("Step 1 - Create the dataset")
|
||||
######################################
|
||||
|
||||
# The dataset needed for RAG must have three columns:
|
||||
# - title (string): title of the document
|
||||
# - text (string): text of a passage of the document
|
||||
# - embeddings (array of dimension d): DPR representation of the passage
|
||||
# Let's say you have documents in tab-separated csv files with columns "title" and "text"
|
||||
assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"
|
||||
|
||||
# You can load a Dataset object this way
|
||||
dataset = load_dataset(
|
||||
"csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
|
||||
)
|
||||
|
||||
# More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files
|
||||
|
||||
# Then split the documents into passages of 100 words
|
||||
dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)
|
||||
|
||||
# And compute the embeddings
|
||||
ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
|
||||
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
|
||||
new_features = Features(
|
||||
{"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
|
||||
) # optional, save as float32 instead of float64 to save space
|
||||
dataset = dataset.map(
|
||||
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
|
||||
batched=True,
|
||||
batch_size=processing_args.batch_size,
|
||||
features=new_features,
|
||||
)
|
||||
|
||||
# And finally save your dataset
|
||||
passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
|
||||
dataset.save_to_disk(passages_path)
|
||||
# from datasets import load_from_disk
|
||||
# dataset = load_from_disk(passages_path) # to reload the dataset
|
||||
|
||||
######################################
|
||||
logger.info("Step 2 - Index the dataset")
|
||||
######################################
|
||||
|
||||
# Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
|
||||
index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
|
||||
dataset.add_faiss_index("embeddings", custom_index=index)
|
||||
|
||||
# And save the index
|
||||
index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
|
||||
dataset.get_index("embeddings").save(index_path)
|
||||
# dataset.load_faiss_index("embeddings", index_path) # to reload the index
|
||||
|
||||
|
||||
@dataclass
|
||||
class RagExampleArguments:
|
||||
csv_path: str = field(
|
||||
default=str(Path(__file__).parent / "test_run" / "dummy-kb" / "my_knowledge_dataset.csv"),
|
||||
metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
|
||||
)
|
||||
question: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
|
||||
)
|
||||
rag_model_name: str = field(
|
||||
default="facebook/rag-sequence-nq",
|
||||
metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
|
||||
)
|
||||
dpr_ctx_encoder_model_name: str = field(
|
||||
default="facebook/dpr-ctx_encoder-multiset-base",
|
||||
metadata={
|
||||
"help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
|
||||
},
|
||||
)
|
||||
output_dir: Optional[str] = field(
|
||||
default=str(Path(__file__).parent / "test_run" / "dummy-kb"),
|
||||
metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingArguments:
|
||||
num_proc: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The number of processes to use to split the documents into passages. Default is single process."
|
||||
},
|
||||
)
|
||||
batch_size: int = field(
|
||||
default=16,
|
||||
metadata={
|
||||
"help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexHnswArguments:
|
||||
d: int = field(
|
||||
default=768,
|
||||
metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
|
||||
)
|
||||
m: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The number of bi-directional links created for every new element during the HNSW index construction."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
|
||||
rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
|
||||
main(rag_example_args, processing_args, index_hnsw_args)
|
244
examples/research_projects/rag-end2end-retriever/utils_rag.py
Normal file
@ -0,0 +1,244 @@
|
||||
import itertools
|
||||
import json
|
||||
import linecache
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import socket
|
||||
import string
|
||||
from collections import Counter
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List
|
||||
|
||||
import git
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers import BartTokenizer, RagTokenizer, T5Tokenizer
|
||||
|
||||
|
||||
def encode_line(tokenizer, line, max_length, padding_side, pad_to_max_length=True, return_tensors="pt"):
|
||||
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) and not line.startswith(" ") else {}
|
||||
tokenizer.padding_side = padding_side
|
||||
return tokenizer(
|
||||
[line],
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
return_tensors=return_tensors,
|
||||
add_special_tokens=True,
|
||||
**extra_kw,
|
||||
)
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids,
|
||||
pad_token_id,
|
||||
attention_mask=None,
|
||||
):
|
||||
"""Remove columns that are populated exclusively by pad_token_id"""
|
||||
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
||||
if attention_mask is None:
|
||||
return input_ids[:, keep_column_mask]
|
||||
else:
|
||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
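# e.g. with pad_token_id=0, input_ids [[7, 8, 0], [9, 0, 0]] is trimmed to [[7, 8], [9, 0]]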
|
||||
|
||||
|
||||
class Seq2SeqDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
data_dir,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
type_path="train",
|
||||
n_obs=None,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
prefix="",
|
||||
):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
||||
self.src_lens = self.get_char_lens(self.src_file)
|
||||
self.max_source_length = max_source_length
|
||||
self.max_target_length = max_target_length
|
||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||
self.tokenizer = tokenizer
|
||||
self.prefix = prefix
|
||||
if n_obs is not None:
|
||||
self.src_lens = self.src_lens[:n_obs]
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src_lens)
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
|
||||
# Need to add eos token manually for T5
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
source_line += self.tokenizer.eos_token
|
||||
tgt_line += self.tokenizer.eos_token
|
||||
|
||||
# Pad source and target to the right
|
||||
source_tokenizer = (
|
||||
self.tokenizer.question_encoder if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
|
||||
)
|
||||
target_tokenizer = self.tokenizer.generator if isinstance(self.tokenizer, RagTokenizer) else self.tokenizer
|
||||
|
||||
source_inputs = encode_line(source_tokenizer, source_line, self.max_source_length, "right")
|
||||
target_inputs = encode_line(target_tokenizer, tgt_line, self.max_target_length, "right")
|
||||
|
||||
source_ids = source_inputs["input_ids"].squeeze()
|
||||
target_ids = target_inputs["input_ids"].squeeze()
|
||||
src_mask = source_inputs["attention_mask"].squeeze()
|
||||
return {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": src_mask,
|
||||
"decoder_input_ids": target_ids,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
|
||||
tgt_pad_token_id = (
|
||||
self.tokenizer.generator.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
src_pad_token_id = (
|
||||
self.tokenizer.question_encoder.pad_token_id
|
||||
if isinstance(self.tokenizer, RagTokenizer)
|
||||
else self.tokenizer.pad_token_id
|
||||
)
|
||||
y = trim_batch(target_ids, tgt_pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, src_pad_token_id, attention_mask=masks)
|
||||
batch = {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": source_mask,
|
||||
"decoder_input_ids": y,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||
|
||||
|
||||
def save_git_info(folder_path: str) -> None:
|
||||
"""Save git information to output_dir/git_log.json"""
|
||||
repo_infos = get_git_info()
|
||||
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
|
||||
|
||||
|
||||
def save_json(content, path, indent=4, **json_dump_kwargs):
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f, indent=indent, **json_dump_kwargs)
|
||||
|
||||
|
||||
def load_json(path):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_git_info():
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
"hostname": str(socket.gethostname()),
|
||||
}
|
||||
return repo_infos
|
||||
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
|
||||
def pickle_save(obj, path):
|
||||
"""pickle.dump(obj, path)"""
|
||||
with open(path, "wb") as f:
|
||||
return pickle.dump(obj, f)
|
||||
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return normalize_answer(prediction) == normalize_answer(ground_truth)
|
||||
|
||||
|
||||
def calculate_exact_match(output_lns: List[str], reference_lns: List[str]) -> Dict:
|
||||
assert len(output_lns) == len(reference_lns)
|
||||
em = 0
|
||||
for hypo, pred in zip(output_lns, reference_lns):
|
||||
em += exact_match_score(hypo, pred)
|
||||
if len(output_lns) > 0:
|
||||
em /= len(output_lns)
|
||||
return {"em": em}
|
||||
|
||||
|
||||
def is_rag_model(model_prefix):
|
||||
return model_prefix.startswith("rag")
|
||||
|
||||
|
||||
def set_extra_model_params(extra_params, hparams, config):
|
||||
equivalent_param = {p: p for p in extra_params}
|
||||
# T5 models don't have `dropout` param, they have `dropout_rate` instead
|
||||
equivalent_param["dropout"] = "dropout_rate"
|
||||
for p in extra_params:
|
||||
if getattr(hparams, p, None):
|
||||
if not hasattr(config, p) and not hasattr(config, equivalent_param[p]):
|
||||
logger.info("config doesn't have a `{}` attribute".format(p))
|
||||
delattr(hparams, p)
|
||||
continue
|
||||
set_p = p if hasattr(config, p) else equivalent_param[p]
|
||||
setattr(config, set_p, getattr(hparams, p))
|
||||
delattr(hparams, p)
|
||||
return hparams, config
|
@ -523,6 +523,9 @@ class RagModel(RagPreTrainedModel):
|
||||
self.question_encoder = question_encoder
|
||||
self.generator = generator
|
||||
|
||||
self.ctx_encoder = None
|
||||
self.context_encoder_training = False
|
||||
|
||||
@add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=RetrievAugLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@ -588,22 +591,58 @@ class RagModel(RagPreTrainedModel):
|
||||
n_docs=n_docs,
|
||||
return_tensors="pt",
|
||||
)
|
||||
context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
|
||||
retriever_outputs["context_input_ids"],
|
||||
retriever_outputs["context_attention_mask"],
|
||||
retriever_outputs["retrieved_doc_embeds"],
|
||||
retriever_outputs["doc_ids"],
|
||||
)
|
||||
if self.context_encoder_training:
|
||||
|
||||
# set to correct device
|
||||
retrieved_doc_embeds = retrieved_doc_embeds.to(question_encoder_last_hidden_state)
|
||||
context_input_ids = context_input_ids.to(input_ids)
|
||||
context_attention_mask = context_attention_mask.to(input_ids)
|
||||
(
|
||||
context_input_ids,
|
||||
context_attention_mask,
|
||||
retrieved_doc_embeds,
|
||||
retrived_doc_input_ids,
|
||||
retrived_doc_attention_mask,
|
||||
retrieved_doc_ids,
|
||||
) = (
|
||||
retriever_outputs["context_input_ids"],
|
||||
retriever_outputs["context_attention_mask"],
|
||||
retriever_outputs["retrieved_doc_embeds"],
|
||||
retriever_outputs["tokenized_doc_ids"],
|
||||
retriever_outputs["tokenized_doc_attention_mask"],
|
||||
retriever_outputs["doc_ids"],
|
||||
)
|
||||
|
||||
# compute doc_scores
|
||||
doc_scores = torch.bmm(
|
||||
question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
|
||||
).squeeze(1)
|
||||
context_input_ids = context_input_ids.to(input_ids)
|
||||
context_attention_mask = context_attention_mask.to(input_ids)
|
||||
|
||||
retrived_doc_input_ids = retrived_doc_input_ids.to(input_ids)
|
||||
retrived_doc_attention_mask = retrived_doc_attention_mask.to(input_ids)
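# re-encode the retrieved passages with the trainable DPR context encoder so that
# gradients can flow back into it during end2end training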
|
||||
retrieved_doc_embeds = self.ctx_encoder(
|
||||
retrived_doc_input_ids, attention_mask=retrived_doc_attention_mask, return_dict=True
|
||||
).pooler_output
|
||||
retrieved_doc_embeds = retrieved_doc_embeds.view(
|
||||
-1, n_docs, question_encoder_last_hidden_state.shape[1]
|
||||
) # reshaping
|
||||
|
||||
# compute doc_scores involving ctx_encoder
|
||||
doc_scores = torch.bmm(
|
||||
question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
|
||||
).squeeze(1)
|
||||
|
||||
else:
|
||||
context_input_ids, context_attention_mask, retrieved_doc_embeds, retrieved_doc_ids = (
|
||||
retriever_outputs["context_input_ids"],
|
||||
retriever_outputs["context_attention_mask"],
|
||||
retriever_outputs["retrieved_doc_embeds"],
|
||||
retriever_outputs["doc_ids"],
|
||||
)
|
||||
|
||||
# set to correct device
|
||||
retrieved_doc_embeds = retrieved_doc_embeds.to(question_encoder_last_hidden_state)
|
||||
context_input_ids = context_input_ids.to(input_ids)
|
||||
context_attention_mask = context_attention_mask.to(input_ids)
|
||||
|
||||
# compute doc_scores
|
||||
doc_scores = torch.bmm(
|
||||
question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
|
||||
).squeeze(1)
|
||||
else:
|
||||
assert (
|
||||
context_input_ids is not None
|
||||
@ -710,6 +749,10 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
def set_retriever(self, retriever: RagRetriever):
|
||||
self.rag.retriever = retriever
|
||||
|
||||
def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
|
||||
self.rag.context_encoder_training = True
|
||||
self.rag.ctx_encoder = ctx_encoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(RAG_FORWARD_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=RetrievAugLMMarginOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -1095,6 +1138,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
     def set_retriever(self, retriever: RagRetriever):
         self.rag.retriever = retriever
 
+    def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
+        self.rag.context_encoder_training = True
+        self.rag.ctx_encoder = ctx_encoder
+
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
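Both generation classes now expose the same pair of setters, which the finetuning script uses to make the document side of retrieval trainable. A minimal sketch of how they might be wired together (the checkpoint names and the dummy index are assumptions for illustration; `finetune_rag_ray_end2end.sh` does the real configuration):

```python
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    RagRetriever,
    RagTokenForGeneration,
)

# assumed public checkpoints; the dummy dataset keeps the example lightweight
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
)
retriever.set_ctx_encoder_tokenizer(
    DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
)

model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)
model.set_context_encoder_for_training(
    DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
)
# from here on, every forward pass re-encodes the retrieved passages with the
# trainable context encoder, so its parameters receive gradients during finetuning
```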
src/transformers/models/rag/retrieval_rag.py

@@ -22,6 +22,7 @@ from typing import Iterable, List, Optional, Tuple
 import numpy as np
 
 from ...file_utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, requires_backends
+from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_utils_base import BatchEncoding
 from ...utils import logging
 from .configuration_rag import RagConfig
@@ -378,6 +379,9 @@ class RagRetriever:
         if self._init_retrieval:
             self.init_retrieval()
 
+        self.ctx_encoder_tokenizer = None
+        self.return_tokenized_docs = False
+
     @staticmethod
     def _build_index(config):
         if config.index_name == "legacy":
@@ -543,6 +547,11 @@ class RagRetriever:
         doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
         return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
 
+    def set_ctx_encoder_tokenizer(self, ctx_encoder_tokenizer: PreTrainedTokenizer):
+        # used in end2end retriever training
+        self.ctx_encoder_tokenizer = ctx_encoder_tokenizer
+        self.return_tokenized_docs = True
+
     def __call__(
         self,
         question_input_ids: List[List[int]],
@@ -594,12 +603,42 @@ class RagRetriever:
             docs, input_strings, prefix, n_docs, return_tensors=return_tensors
         )
 
-        return BatchEncoding(
-            {
-                "context_input_ids": context_input_ids,
-                "context_attention_mask": context_attention_mask,
-                "retrieved_doc_embeds": retrieved_doc_embeds,
-                "doc_ids": doc_ids,
-            },
-            tensor_type=return_tensors,
-        )
+        if self.return_tokenized_docs:
+            retrived_doc_text = []
+            retrived_doc_title = []
+
+            for b_idx in range(len(docs)):
+                for doc_idx in range(n_docs):
+                    retrived_doc_text.append(docs[b_idx]["text"][doc_idx])
+                    retrived_doc_title.append(docs[b_idx]["title"][doc_idx])
+
+            tokenized_docs = self.ctx_encoder_tokenizer(
+                retrived_doc_title,
+                retrived_doc_text,
+                truncation=True,
+                padding="longest",
+                return_tensors=return_tensors,
+            )
+
+            return BatchEncoding(
+                {
+                    "context_input_ids": context_input_ids,
+                    "context_attention_mask": context_attention_mask,
+                    "retrieved_doc_embeds": retrieved_doc_embeds,
+                    "doc_ids": doc_ids,
+                    "tokenized_doc_ids": tokenized_docs["input_ids"],
+                    "tokenized_doc_attention_mask": tokenized_docs["attention_mask"],
+                },
+                tensor_type=return_tensors,
+            )
+
+        else:
+            return BatchEncoding(
+                {
+                    "context_input_ids": context_input_ids,
+                    "context_attention_mask": context_attention_mask,
+                    "retrieved_doc_embeds": retrieved_doc_embeds,
+                    "doc_ids": doc_ids,
+                },
+                tensor_type=return_tensors,
+            )
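A quick way to see the two return paths in action is to call a retriever directly, first without and then with a context-encoder tokenizer attached (the checkpoint names, the dummy index, and the random question embeddings below are assumptions for illustration):

```python
import numpy as np
from transformers import DPRContextEncoderTokenizer, RagRetriever

retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
)

question_input_ids = [[5, 7], [10, 11]]                               # dummy token ids
question_hidden_states = np.random.randn(2, 768).astype(np.float32)  # dummy DPR question embeddings

out = retriever(
    question_input_ids, question_hidden_states, prefix=retriever.config.generator.prefix, n_docs=2, return_tensors="pt"
)
print(sorted(out))  # context_attention_mask, context_input_ids, doc_ids, retrieved_doc_embeds

retriever.set_ctx_encoder_tokenizer(
    DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
)
out = retriever(
    question_input_ids, question_hidden_states, prefix=retriever.config.generator.prefix, n_docs=2, return_tensors="pt"
)
print(sorted(out))  # ...plus tokenized_doc_attention_mask and tokenized_doc_ids
```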
tests/test_modeling_rag.py

@@ -26,7 +26,7 @@ import numpy as np
 from transformers import BartTokenizer, T5Tokenizer
 from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
 from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
-from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
+from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
 from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
 from transformers.testing_utils import (
     require_sentencepiece,
@@ -55,6 +55,7 @@ if is_torch_available() and is_datasets_available() and is_faiss_available():
         AutoConfig,
         AutoModel,
         AutoModelForSeq2SeqLM,
+        DPRContextEncoder,
         RagConfig,
         RagModel,
         RagRetriever,
@@ -179,6 +180,10 @@ class RagTestMixin:
     def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
         return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
 
+    @cached_property
+    def dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
+        return DPRContextEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
+
     @cached_property
     def bart_tokenizer(self) -> BartTokenizer:
         return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
@@ -246,6 +251,46 @@ class RagTestMixin:
             # doc scores
             self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
 
+    def check_model_with_end2end_retriever(
+        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
+    ):
+        self.assertIsNotNone(config.question_encoder)
+        self.assertIsNotNone(config.generator)
+
+        context_encoder_tokenizer = self.dpr_ctx_encoder_tokenizer
+        dpr_context_encoder = DPRContextEncoder(config.question_encoder)  # dpr is a twin tower
+
+        retriever = self.get_retriever(config)
+        retriever.set_ctx_encoder_tokenizer(context_encoder_tokenizer)  # setting the ctx_encoder_tokenizer.
+
+        for model_class in [RagTokenForGeneration, RagSequenceForGeneration]:
+            model = model_class(config, retriever=retriever)
+            model.set_context_encoder_for_training(dpr_context_encoder)  # set the context_encoder for training
+            model.to(torch_device)
+            model.eval()
+
+            self.assertTrue(model.config.is_encoder_decoder)
+
+            outputs = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attention_mask=decoder_attention_mask,
+            )
+
+            # logits
+            self.assertEqual(
+                outputs.logits.shape,
+                (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
+            )
+            # generator encoder last hidden states
+            self.assertEqual(
+                outputs.generator_enc_last_hidden_state.shape,
+                (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
+            )
+            # doc scores
+            self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))
+
     def check_model_generate_from_context_input_ids(
         self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
     ):
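The `# dpr is a twin tower` comment is worth unpacking: DPR's question encoder and context encoder are two BERT towers built from the same `DPRConfig`, which is why the test can instantiate a context encoder straight from `config.question_encoder`. A minimal sketch of that relationship (the tiny config values are arbitrary):

```python
from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder

# both towers share one config class; only the learned weights differ
config = DPRConfig(
    vocab_size=100,
    hidden_size=16,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=32,
    projection_dim=0,
)
question_encoder = DPRQuestionEncoder(config)
ctx_encoder = DPRContextEncoder(config)

# each tower produces one pooled vector per input sequence
print(question_encoder.config.hidden_size, ctx_encoder.config.hidden_size)  # 16 16
```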
@@ -538,6 +583,10 @@ class RagTestMixin:
         inputs_dict = self.config_and_inputs
         self.check_model_with_retriever(**inputs_dict)
 
+    def test_model_with_end2end_retriever(self):
+        inputs_dict = self.config_and_inputs
+        self.check_model_with_end2end_retriever(**inputs_dict)
+
     def test_model_without_retriever(self):
         inputs_dict = self.config_and_inputs
         self.check_model_without_retriever(**inputs_dict)
tests/test_retrieval_rag.py

@@ -28,7 +28,7 @@ from transformers.models.bart.configuration_bart import BartConfig
 from transformers.models.bart.tokenization_bart import BartTokenizer
 from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
 from transformers.models.dpr.configuration_dpr import DPRConfig
-from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
+from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
 from transformers.models.rag.configuration_rag import RagConfig
 from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
 from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
@@ -115,6 +115,9 @@ class RagRetrieverTest(TestCase):
     def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
         return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
 
+    def get_dpr_ctx_encoder_tokenizer(self) -> DPRContextEncoderTokenizer:
+        return DPRContextEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))
+
     def get_bart_tokenizer(self) -> BartTokenizer:
         return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))
 
@@ -359,3 +362,26 @@ class RagRetrieverTest(TestCase):
         self.assertIsInstance(context_input_ids, torch.Tensor)
         self.assertIsInstance(context_attention_mask, torch.Tensor)
         self.assertIsInstance(retrieved_doc_embeds, torch.Tensor)
+
+    @require_torch
+    @require_tokenizers
+    @require_sentencepiece
+    def test_custom_hf_index_end2end_retriever_call(self):
+
+        context_encoder_tokenizer = self.get_dpr_ctx_encoder_tokenizer()
+        n_docs = 1
+        retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
+        retriever.set_ctx_encoder_tokenizer(context_encoder_tokenizer)
+
+        question_input_ids = [[5, 7], [10, 11]]
+        hidden_states = np.array(
+            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
+        )
+        out = retriever(question_input_ids, hidden_states, prefix=retriever.config.generator.prefix, n_docs=n_docs)
+
+        self.assertEqual(
+            len(out), 6
+        )  # check whether the retriever output consists of 6 attributes including tokenized docs
+        self.assertEqual(
+            all(k in out for k in ("tokenized_doc_ids", "tokenized_doc_attention_mask")), True
+        )  # check for doc token related keys in dictionary.