transformers/examples/rag/distributed_retriever.py
Ola Piktus c754c41c61
RAG (#6813)
* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* First commit

* improve comments

* Retrieval evaluation scripts

* refactor to include modeling outputs + MPI retriever

* Fix rag-token model + refactor

* Various fixes + finetuning logic

* use_bos fix

* Retrieval refactor

* Finetuning refactoring and cleanup

* Add documentation and cleanup

* Remove set_up_rag_env.sh file

* Fix retrieval with HF index

* Fix import errors

* Fix quality errors

* Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867

* fix quality

* Fix RAG Sequence generation

* minor cleanup plus initial tests

* fix test

* fix tests 2

* Comments fix

* post-merge fixes

* Improve readme + post-rebase refactor

* Extra dependencies for tests

* Fix tests

* Fix tests 2

* Refactor test requirements

* Fix tests 3

* Post-rebase refactor

* rename nlp->datasets

* RAG integration tests

* add tokenizer to slow integration test and allow retriever to run on cpu

* add tests; fix position ids warning

* change structure

* add from encoder generator

* save working solution

* make all integration tests pass

* add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained

* don't save paths

* delete unnecessary imports

* pass config to AutoTokenizer.from_pretrained for Rag tokenizers

* init wiki_dpr only once

* hardcode legacy index and passages paths (todo: add the right urls)

* finalize config

* finalize retriever api and config api

* LegacyIndex index download refactor

* add dpr to autotokenizer

* make from pretrained more flexible

* fix RagTokenForGeneration

* small name changes in tokenizer

* add labels to models

* change default index name

* add retrieval tests

* finish token generate

* align test with previous version and make all tests pass

* add tests

* finalize tests

* implement Thom's suggestions

* add first version of test

* make first tests work

* make retriever platform agnostic

* naming

* style

* add legacy index URL

* docstrings + simple retrieval test for distributed

* clean model api

* add doc_ids to retriever's outputs

* fix retrieval tests

* finish model outputs

* finalize model api

* fix generate problem for rag

* fix generate for other models

* fix some tests

* save intermediate

* set generate to default

* big refactor generate

* delete rag_api

* correct pip faiss install

* fix auto tokenization test

* fix faiss install

* fix test

* move the distributed logic to examples

* model page

* docs

* finish tests

* fix dependencies

* fix import in __init__

* Refactor eval_rag and finetune scripts

* start docstring

* add psutil to test

* fix tf test

* move require torch to top

* fix retrieval test

* align naming

* finish automodel

* fix repo consistency

* test ragtokenizer save/load

* add rag model output docs

* fix ragtokenizer save/load from pretrained

* fix tokenizer dir

* remove torch in retrieval

* fix docs

* fix finetune scripts

* finish model docs

* finish docs

* remove auto model for now

* add require torch

* remove solved todos

* integrate Sylvain's suggestions

* Sam's comments

* correct mistake on purpose

* improve README

* Add generation test cases

* fix rag token

* clean token generate

* fix test

* add note to test

* fix attention mask

* add t5 test for rag

* Fix handling prefix in finetune.py

* don't overwrite index_name

Co-authored-by: Patrick Lewis <plewis@fb.com>
Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
2020-09-22 18:29:58 +02:00


import logging
import os
from typing import List, Tuple

import numpy as np
import psutil
import torch
import torch.distributed as dist

from transformers import RagRetriever


logger = logging.getLogger(__name__)


class RagPyTorchDistributedRetriever(RagRetriever):
    """
    A distributed retriever built on top of the ``torch.distributed`` communication package. During training, all
    workers initialize their own instance of the retriever; however, only the main worker loads the index into
    memory. The index is stored in CPU memory, and it also works in a non-distributed setup.

    Args:
        config (:class:`~transformers.RagConfig`):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which
            ``Index`` to build.
        question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer that was used to tokenize the question. It is used to decode the question and then use the
            ``generator_tokenizer``.
        generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
            The tokenizer used for the generator part of the RagModel.
    """

    _init_retrieval = False

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer):
        super().__init__(
            config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer
        )
        self.process_group = None

    def init_retrieval(self, distributed_port: int):
        """
        Retriever initialization function, needs to be called from the training process. The function sets some
        common parameters and environment variables. On top of that, (only) the main process in the process group
        loads the index into memory.

        Args:
            distributed_port (:obj:`int`):
                The port on which the main communication of the training run is carried out. We set the port for
                retrieval-related communication as ``distributed_port + 1``.
        """
logger.info("initializing retrieval")
# initializing a separate process group for retrievel as the default
# nccl backend doesn't support gather/scatter operations while gloo
# is too slow to replace nccl for the core gpu communication
if dist.is_initialized():
logger.info("dist initialized")
# needs to be set manually
os.environ["GLOO_SOCKET_IFNAME"] = self._infer_socket_ifname()
# avoid clash with the NCCL port
os.environ["MASTER_PORT"] = str(distributed_port + 1)
self.process_group = dist.new_group(ranks=None, backend="gloo")
# initialize retriever only on the main worker
if not dist.is_initialized() or self._is_main():
logger.info("dist not initialized / main")
self.index.init_index()
# all processes wait untill the retriever is initialized by the main process
if dist.is_initialized():
torch.distributed.barrier(group=self.process_group)

    def _is_main(self):
        return dist.get_rank(group=self.process_group) == 0
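
    # ``dist.scatter`` contract: rank 0 (``src``) supplies one chunk per rank
    # in ``scatter_list``; every rank, including rank 0, receives its own
    # chunk into the freshly allocated ``target_tensor``.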

    def _scattered(self, scatter_list, target_shape, target_type=torch.float32):
        target_tensor = torch.empty(target_shape, dtype=target_type)
        dist.scatter(target_tensor, src=0, scatter_list=scatter_list, group=self.process_group)
        return target_tensor

    def _infer_socket_ifname(self):
        addrs = psutil.net_if_addrs()
        # a hacky way to deal with varying network interface names
        ifname = next((addr for addr in addrs if addr.startswith("e")), None)
        return ifname

    def retrieve(
        self, question_hidden_states: np.ndarray, n_docs: int
    ) -> Tuple[np.ndarray, np.ndarray, List[dict]]:
        """
        Retrieves documents for the specified ``question_hidden_states``. The main process, which has access to the
        index stored in memory, gathers queries from all the processes in the main training process group, performs
        the retrieval and scatters back the results.

        Args:
            question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (:obj:`int`):
                The number of docs retrieved per query.

        Output:
            retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`):
                The retrieval embeddings of the retrieved docs per query.
            doc_ids (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs)`):
                The ids of the documents in the index.
            doc_dicts (:obj:`List[dict]`):
                The content of the retrieved documents per query (the examples behind ``retrieved_doc_embeds``).
        """
        # single GPU training
        if not dist.is_initialized():
            doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
            return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

        # distributed training
        world_size = dist.get_world_size(group=self.process_group)

        # gather logic
        gather_list = None
        if self._is_main():
            gather_list = [torch.empty(question_hidden_states.shape, dtype=torch.float32) for _ in range(world_size)]
        dist.gather(torch.tensor(question_hidden_states), dst=0, gather_list=gather_list, group=self.process_group)

        # scatter logic
        n_queries = question_hidden_states.shape[0]
        scatter_ids = []
        scatter_vectors = []
        if self._is_main():
            assert len(gather_list) == world_size
            ids, vectors = self._main_retrieve(torch.cat(gather_list).numpy(), n_docs)
            ids, vectors = torch.tensor(ids), torch.tensor(vectors)
            scatter_ids = self._chunk_tensor(ids, n_queries)
            scatter_vectors = self._chunk_tensor(vectors, n_queries)
        doc_ids = self._scattered(scatter_ids, [n_queries, n_docs], target_type=torch.int64)
        retrieved_doc_embeds = self._scattered(scatter_vectors, [n_queries, n_docs, question_hidden_states.shape[1]])

        return retrieved_doc_embeds.numpy(), doc_ids.numpy(), self.index.get_doc_dicts(doc_ids)
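

# A minimal usage sketch: wiring this retriever into a RAG model the way the
# training script does. The checkpoint name (``facebook/rag-sequence-base``)
# and the port are illustrative assumptions; ``examples/rag/finetune.py`` does
# the equivalent wiring for an actual training run.
if __name__ == "__main__":
    from transformers import RagConfig, RagSequenceForGeneration, RagTokenizer

    config = RagConfig.from_pretrained("facebook/rag-sequence-base")  # assumed checkpoint
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
    retriever = RagPyTorchDistributedRetriever(
        config,
        question_encoder_tokenizer=tokenizer.question_encoder,
        generator_tokenizer=tokenizer.generator,
    )
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)

    # In a distributed run, call this after ``dist.init_process_group``; in a
    # single-process run it simply loads the index locally. Note that loading
    # the index may trigger a large download/build.
    retriever.init_retrieval(distributed_port=12345)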