Mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00

* added rag WIP
* path fix
* Formatting / renaming prior to actual work
* added rag WIP
* path fix
* Formatting / renaming prior to actual work
* added rag WIP
* path fix
* Formatting / renaming prior to actual work
* added rag WIP
* Formatting / renaming prior to actual work
* First commit
* improve comments
* Retrieval evaluation scripts
* refactor to include modeling outputs + MPI retriever
* Fix rag-token model + refactor
* Various fixes + finetuning logic
* use_bos fix
* Retrieval refactor
* Finetuning refactoring and cleanup
* Add documentation and cleanup
* Remove set_up_rag_env.sh file
* Fix retrieval with HF index
* Fix import errors
* Fix quality errors
* Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867
* fix quality
* Fix RAG Sequence generation
* minor cleanup plus initial tests
* fix test
* fix tests 2
* Comments fix
* post-merge fixes
* Improve readme + post-rebase refactor
* Extra dependencies for tests
* Fix tests
* Fix tests 2
* Refactor test requirements
* Fix tests 3
* Post-rebase refactor
* rename nlp->datasets
* RAG integration tests
* add tokenizer to slow integration test and allow retriever to run on cpu
* add tests; fix position ids warning
* change structure
* change structure
* add from encoder generator
* save working solution
* make all integration tests pass
* add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained
* don't save paths
* delete unnecessary imports
* pass config to AutoTokenizer.from_pretrained for Rag tokenizers
* init wiki_dpr only once
* hardcode legacy index and passages paths (todo: add the right urls)
* finalize config
* finalize retriever api and config api
* LegacyIndex index download refactor
* add dpr to autotokenizer
* make from pretrained more flexible
* fix ragfortokengeneration
* small name changes in tokenizer
* add labels to models
* change default index name
* add retrieval tests
* finish token generate
* align test with previous version and make all tests pass
* add tests
* finalize tests
* implement Thom's suggestions
* add first version of test
* make first tests work
* make retriever platform agnostic
* naming
* style
* add legacy index URL
* docstrings + simple retrieval test for distributed
* clean model api
* add doc_ids to retriever's outputs
* fix retrieval tests
* finish model outputs
* finalize model api
* fix generate problem for rag
* fix generate for other models
* fix some tests
* save intermediate
* set generate to default
* big refactor generate
* delete rag_api
* correct pip faiss install
* fix auto tokenization test
* fix faiss install
* fix test
* move the distributed logic to examples
* model page
* docs
* finish tests
* fix dependencies
* fix import in __init__
* Refactor eval_rag and finetune scripts
* start docstring
* add psutil to test
* fix tf test
* move require torch to top
* fix retrieval test
* align naming
* finish automodel
* fix repo consistency
* test ragtokenizer save/load
* add rag model output docs
* fix ragtokenizer save/load from pretrained
* fix tokenizer dir
* remove torch in retrieval
* fix docs
* fix finetune scripts
* finish model docs
* finish docs
* remove auto model for now
* add require torch
* remove solved todos
* integrate Sylvain's suggestions
* Sam's comments
* correct mistake on purpose
* improve README
* Add generation test cases
* fix rag token
* clean token generate
* fix test
* add note to test
* fix attention mask
* add t5 test for rag
* Fix handling prefix in finetune.py
* don't overwrite index_name

Co-authored-by: Patrick Lewis <plewis@fb.com>
Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
157 lines
6.0 KiB
Python
import json
import os
import shutil
import sys
import tempfile
import unittest
from unittest import TestCase
from unittest.mock import patch

import numpy as np
from datasets import Dataset

import faiss
from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES


sys.path.append(os.path.join(os.getcwd()))  # noqa: E402  # isort:skip

from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever  # noqa: E402  # isort:skip

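# Note: the sys.path append above assumes the tests are launched from the repository
# root; examples.rag.distributed_retriever ships with the repository but is not part
# of the installed transformers package, so it must be imported from the working directory.
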
def require_distributed_retrieval(test_case):
    """
    Decorator marking a test that requires a set of dependencies necessary to perform retrieval with
    :class:`~transformers.RagRetriever`.

    These tests are skipped when the respective libraries are not installed.
    """
    if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
        test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
    return test_case

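# Applied to the test class below, unittest.skip marks every test in the class as
# skipped when any of the four dependencies is unavailable.
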
@require_distributed_retrieval
class RagRetrieverTest(TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        self.retrieval_vector_size = 8

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
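        # DPRQuestionEncoderTokenizer is BERT/WordPiece-based, so a plain
        # newline-separated vocab file is all that from_pretrained needs here.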

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))
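        # BartTokenizer is a byte-level BPE tokenizer, hence the JSON vocab plus a
        # merges file; "\u0120" (Ġ) is the byte-level encoding of a leading space.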

    def get_dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    def get_bart_tokenizer(self) -> BartTokenizer:
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def get_dummy_pytorch_distributed_retriever(self, init_retrieval, port=12345) -> RagPyTorchDistributedRetriever:
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],
                "text": ["foo", "bar"],
                "title": ["Foo", "Bar"],
                "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
        config = RagConfig(
            retrieval_vector_size=self.retrieval_vector_size,
            question_encoder=DPRConfig().to_dict(),
            generator=BartConfig().to_dict(),
        )
        with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagPyTorchDistributedRetriever(
                config,
                question_encoder_tokenizer=self.get_dpr_tokenizer(),
                generator_tokenizer=self.get_bart_tokenizer(),
            )
        if init_retrieval:
            retriever.init_retrieval(port)
        return retriever
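    # Note: patching transformers.retrieval_rag.load_dataset above keeps the test
    # hermetic: the retriever indexes the tiny two-passage in-memory dataset instead
    # of downloading the real wiki_dpr index.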

    def test_pytorch_distributed_retriever_retrieve(self):
        n_docs = 1
        retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
        hidden_states = np.array(
            [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
        )
        retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
        self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
        self.assertEqual(len(doc_dicts), 2)
        self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
        self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
        self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
        self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
        self.assertListEqual(list(doc_ids), [1, 0])
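    # Sanity check on the expected ids: with vectors of size 8, the all-ones query
    # scores <q, d0> = 8 and <q, d1> = 16, so doc "1" wins; the all-minus-ones query
    # scores -8 and -16 respectively, so doc "0" wins.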