import logging
import os
from typing import List, Tuple

import numpy as np
import psutil
import torch
import torch.distributed as dist

from transformers import RagRetriever


logger = logging.getLogger(__name__)


class RagPyTorchDistributedRetriever(RagRetriever):
    """
    A distributed retriever built on top of the ``torch.distributed`` communication package. During training, all
    workers initialize their own instance of the retriever, but only the main worker loads the index into memory.
    The index is stored in CPU memory. The retriever also works in a non-distributed setup.

    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this retriever is used with. Contains parameters indicating which
            ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question. It is used to decode the question, which is then
            re-encoded with the ``generator_tokenizer``.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
    """

    _init_retrieval = False

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
        super().__init__(
            config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
        )
        self.process_group = None

    def init_retrieval(self, distributed_port: int):
        """
        Retriever initialization function, to be called from the training process. The function sets some common
        parameters and environment variables. On top of that, (only) the main process in the process group loads the
        index into memory.

        Args:
            distributed_port (:obj:`int`):
                The port on which the main communication of the training run is carried out. We set the port for
                retrieval-related communication to ``distributed_port + 1``.
        """

        logger.info("initializing retrieval")

        # initialize a separate process group for retrieval, as the default
        # nccl backend doesn't support gather/scatter operations while gloo
        # is too slow to replace nccl for the core gpu communication
        if dist.is_initialized():
            logger.info("dist initialized")
            # needs to be set manually
            os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname()
            # avoid clash with the NCCL port
            os.environ["MASTER_PORT"] = str(distributed_port + 1)
            self.process_group = dist.new_group(ranks=None, backend="gloo")

        # initialize the retriever only on the main worker
        if not dist.is_initialized() or self._is_main():
            logger.info("dist not initialized / main")
            self.index.init_index()

        # all processes wait until the retriever is initialized by the main process
        if dist.is_initialized():
            torch.distributed.barrier(group=self.process_group)
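
    # Training-side call, a minimal sketch (an assumption for illustration, not
    # part of the original module): every worker launched via torch.distributed
    # invokes init_retrieval with the main training port, and only rank 0 pays
    # the memory cost of loading the index.
    #
    #   retriever.init_retrieval(distributed_port=12345)  # retrieval traffic then uses port 12346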

    def _is_main(self):
        return dist.get_rank(group=self.process_group) == 0

    def _scattered(self, scatter_list, target_shape, target_type=torch.float32):
        # non-main workers pass an empty scatter_list and only provide the output
        # buffer; the main process (src=0) supplies one chunk per worker
        target_tensor = torch.empty(target_shape, dtype=target_type)
        dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group)
        return target_tensor

    def _infer_socket_ifname(self):
        addrs = psutil.net_if_addrs()
        # a hacky way to deal with varying network interface names: pick the
        # first interface whose name starts with "e" (e.g. eth0, ens3, enp0s3)
        ifname = next((addr for addr in addrs if addr.startswith("e")), None)
        return ifname

    def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray, List[dict]]:
        """
        Retrieves documents for the specified ``question_hidden_states``. The main process, which has access to the
        index stored in memory, gathers queries from all the processes in the main training process group, performs
        the retrieval and scatters back the results.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
                The ids of the documents in the index.
            doc_dicts (:obj:`List[dict]`):
                The retrieved document data (e.g. title and text) per query.
        """

        # single GPU training
        if not dist.is_initialized():
            doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
            return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

        # distributed training
        world_size = dist.get_world_size(group=self.process_group)

        # gather logic
        gather_list = None
        if self._is_main():
            gather_list = [torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
        dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group)

        # scatter logic
        n_queries = question_hidden_states.shape[0]
        scatter_ids = []
        scatter_vectors = []
        if self._is_main():
            assert len(gather_list) == world_size
            ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs)
            ids, vectors = torch.tensor(ids), torch.tensor(vectors)
            # _chunk_tensor (inherited from RagRetriever) splits the results back
            # into per-worker chunks of n_queries rows each
            scatter_ids = self._chunk_tensor(ids, n_queries)
            scatter_vectors = self._chunk_tensor(vectors, n_queries)
        doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
        retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]])

        return retrieved_doc_embeds.numpy(), doc_ids.numpy(), self.index.get_doc_dicts(doc_ids)
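
# End-to-end usage, a minimal sketch (an assumption for illustration, not part
# of the original module): once init_retrieval has run on every worker, each
# worker retrieves for its own batch of query vectors; 768 is just an example
# vector size.
#
#   import numpy as np
#
#   question_hidden_states = np.random.randn(2, 768).astype(np.float32)  # (batch_size, vector_size)
#   retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(question_hidden_states, n_docs=5)
#   assert retrieved_doc_embeds.shape == (2, 5, 768)
#   assert doc_ids.shape == (2, 5)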