import functools
import math
import os  # noqa: F401
from random import choice, randint
from time import time

import numpy as np
import torch
import torch.utils.checkpoint as checkpoint
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm

import faiss  # noqa: F401
import nlp  # noqa: F401
import pandas as pd
from elasticsearch import Elasticsearch  # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk  # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup


pd.set_option("display.max_colwidth", None)

# NOTE(review): the "<P>" passage separator used below when building support
# documents was reconstructed from a damaged copy of this file (the literal was
# split across blank lines). Confirm it matches the separator expected by the
# downstream seq2seq inputs before relying on it.


###############
# Sparse index
###############
def make_es_index_snippets(es_client, passages_dset, index_name="english_wiki_kilt_snippets_100w"):
    """Create an Elasticsearch index and bulk-load every passage into it.

    Args:
        es_client: Elasticsearch client used to create the index and to stream
            the documents.
        passages_dset: Iterable of passage records that also exposes
            ``num_rows``; each record is sent as one document.
        index_name: Name of the index to create.
    """
    index_config = {
        "settings": {
            "number_of_shards": 1,
            # The analyzer parameter must be spelled exactly "stopwords"; a
            # leading space in the key makes it a different, meaningless setting.
            "analysis": {"analyzer": {"stop_standard": {"type": "standard", "stopwords": "_english_"}}},
        },
        "mappings": {
            "properties": {
                "article_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
                "section_title": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
                "passage_text": {"type": "text", "analyzer": "standard", "similarity": "BM25"},
            }
        },
    }
    es_client.indices.create(index=index_name, body=index_config)
    successes = 0
    # The context manager closes the progress bar even if indexing fails midway.
    with tqdm(unit="docs", total=passages_dset.num_rows) as progress:
        for ok, _action in streaming_bulk(client=es_client, index=index_name, actions=iter(passages_dset)):
            progress.update(1)
            successes += ok
    print("Indexed %d documents" % (successes,))


def query_es_index(question, es_client, index_name="english_wiki_kilt_snippets_100w", n_results=10, min_length=20):
    """Retrieve passages for a question from the Elasticsearch index.

    The question is lower-cased and stripped of a fixed list of question words
    before being sent as a ``cross_fields`` multi-match query.

    Args:
        question: Natural-language question.
        es_client: Elasticsearch client.
        index_name: Index to search.
        n_results: Maximum number of passages in the returned list
            (``2 * n_results`` hits are requested before filtering).
        min_length: Passages with this many whitespace-separated words or
            fewer are dropped from the returned list.

    Returns:
        ``(support_doc, res_list)``: ``support_doc`` joins the text of *all*
        hits (unfiltered); ``res_list`` holds one dict per kept hit with the
        source fields plus ``passage_id`` and ``score``.
    """
    banned = {"how", "why", "what", "where", "which", "do", "does", "is", "?", "eli5", "eli5:"}
    q = " ".join(w for w in question.lower().split() if w not in banned)
    response = es_client.search(
        index=index_name,
        body={
            "query": {
                "multi_match": {
                    "query": q,
                    "fields": ["article_title", "section_title", "passage_text^2"],
                    "type": "cross_fields",
                }
            },
            "size": 2 * n_results,
        },
    )
    hits = response["hits"]["hits"]
    support_doc = "<P> " + " <P> ".join([hit["_source"]["passage_text"] for hit in hits])
    res_list = [{k: v for k, v in hit["_source"].items() if k != "passage_text"} for hit in hits]
    for r, hit in zip(res_list, hits):
        r["passage_id"] = hit["_id"]
        r["score"] = hit["_score"]
        # re-inserted last so that passage_text is the final key of each result
        r["passage_text"] = hit["_source"]["passage_text"]
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list


###############
# ELI5 retriever training
###############
class ELI5DatasetQARetriver(Dataset):
    """Dataset of (question, answer span) pairs for training the dense retriever."""

    def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
        """Wrap ``examples_array``.

        Args:
            examples_array: Indexable collection of examples exposing ``num_rows``.
            extra_answer_threshold: Stored as ``answer_thres``; not used when
                building examples in this class.
            min_answer_length: Minimum number of words kept after the random
                start offset when sampling a training span (when the answer is
                long enough).
            training: If True, sample a random answer and span; otherwise use
                the first answer in full.
            n_samples: Reported dataset length; defaults to ``num_rows``.
        """
        self.data = examples_array
        self.answer_thres = extra_answer_threshold
        self.min_length = min_answer_length
        self.training = training
        self.n_samples = self.data.num_rows if n_samples is None else n_samples

    def __len__(self):
        return self.n_samples

    def make_example(self, idx):
        """Return ``(question, answer_span)`` for the example at ``idx``."""
        example = self.data[idx]
        question = example["title"]
        if self.training:
            # zip keeps only answers that have a matching score entry
            answers = [a for a, _sc in zip(example["answers"]["text"], example["answers"]["score"])]
            answer_tab = choice(answers).split(" ")
            # random start such that at least min_length words remain when possible
            start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
            answer_span = " ".join(answer_tab[start_idx:])
        else:
            answer_span = example["answers"]["text"][0]
        return (question, answer_span)

    def __getitem__(self, idx):
        # n_samples may exceed the number of rows, hence the wrap-around
        return self.make_example(idx % self.data.num_rows)


class RetrievalQAEmbedder(torch.nn.Module):
    """Shared sentence encoder with separate linear projections for questions and answers."""

    def __init__(self, sent_encoder, dim):
        """
        Args:
            sent_encoder: Encoder module whose second output is a pooled
                sentence representation of size ``dim``.
            dim: Size of the encoder's pooled output.
        """
        super().__init__()
        self.sent_encoder = sent_encoder
        self.output_dim = 128
        self.project_q = torch.nn.Linear(dim, self.output_dim, bias=False)
        self.project_a = torch.nn.Linear(dim, self.output_dim, bias=False)
        self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")

    def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
        """Return pooled encoder outputs, optionally using gradient checkpointing.

        When ``checkpoint_batch_size`` is not positive, or the batch is smaller
        than it, the encoder is called directly. Otherwise the embedding layer
        runs on the full batch and the encoder + pooler run through
        ``torch.utils.checkpoint`` on slices of ``checkpoint_batch_size`` rows.
        """
        # reproduces BERT forward pass with checkpointing
        # (<= 0 rather than < 0: a value of 0 would otherwise divide by zero below)
        if checkpoint_batch_size <= 0 or input_ids.shape[0] < checkpoint_batch_size:
            return self.sent_encoder(input_ids, attention_mask=attention_mask)[1]

        # prepare implicit variables
        device = input_ids.device
        input_shape = input_ids.size()
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
        head_mask = [None] * self.sent_encoder.config.num_hidden_layers
        extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
            attention_mask, input_shape, device
        )

        # define function for checkpointing
        def partial_encode(*inputs):
            encoder_outputs = self.sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask)
            sequence_output = encoder_outputs[0]
            pooled_output = self.sent_encoder.pooler(sequence_output)
            return pooled_output

        # run embedding layer on everything at once
        embedding_output = self.sent_encoder.embeddings(
            input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
        )
        # run encoding and pooling on one mini-batch at a time
        pooled_output_list = []
        for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
            start, end = b * checkpoint_batch_size, (b + 1) * checkpoint_batch_size
            b_embedding_output = embedding_output[start:end]
            b_attention_mask = extended_attention_mask[start:end]
            pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
            pooled_output_list.append(pooled_output)
        return torch.cat(pooled_output_list, dim=0)

    def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
        """Encode questions and apply the question projection."""
        q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
        return self.project_q(q_reps)

    def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
        """Encode answers/passages and apply the answer projection."""
        a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
        return self.project_a(a_reps)

    def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
        """Return the symmetric cross-entropy loss over question/answer dot products.

        The i-th answer in the batch is the target for the i-th question and
        vice versa; the two directional losses are averaged.
        """
        device = q_ids.device
        q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
        a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
        compare_scores = torch.mm(q_reps, a_reps.t())
        loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
        loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
        loss = (loss_qa + loss_aq) / 2
        return loss


def make_qa_retriever_model(model_name="google/bert_uncased_L-8_H-512_A-8", from_file=None, device="cuda:0"):
    """Build the tokenizer and ``RetrievalQAEmbedder`` for ``model_name``.

    Args:
        model_name: Pretrained model identifier passed to ``from_pretrained``.
        from_file: Optional checkpoint path; its ``"model"`` entry is loaded
            into the embedder.
        device: Device the model is moved to.

    Returns:
        ``(tokenizer, qa_embedder)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bert_model = AutoModel.from_pretrained(model_name).to(device)
    # run bert_model on a dummy batch to get output dimension
    d_ids = torch.LongTensor(
        [[bert_model.config.bos_token_id if bert_model.config.bos_token_id is not None else 1]]
    ).to(device)
    d_mask = torch.LongTensor([[1]]).to(device)
    sent_dim = bert_model(d_ids, attention_mask=d_mask)[1].shape[-1]
    qa_embedder = RetrievalQAEmbedder(bert_model, sent_dim).to(device)
    if from_file is not None:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        qa_embedder.load_state_dict(param_dict["model"])
    return tokenizer, qa_embedder


def make_qa_retriever_batch(qa_list, tokenizer, max_len=64, device="cuda:0"):
    """Tokenize a list of ``(question, answer)`` pairs into padded id/mask tensors.

    Returns:
        ``(q_ids, q_mask, a_ids, a_mask)`` as ``LongTensor`` objects on ``device``.
    """
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=max_len, pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    return (q_ids, q_mask, a_ids, a_mask)


def train_qa_retriever_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0):
    """Run one training epoch of the retriever over ``dataset``.

    ``args`` must provide ``max_length``, ``batch_size``,
    ``checkpoint_batch_size`` and ``print_freq``. The running loss is printed
    every ``print_freq`` steps (and at step 1) and then reset.
    """
    model.train()
    # make iterator
    train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch in enumerate(epoch_iterator):
        q_ids, q_mask, a_ids, a_mask = batch
        pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
        loss = pre_loss.sum()
        # optimizer
        loss.backward()
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


def train_qa_retriever_joint_epoch(model, dataset_list, tokenizer, optimizer, scheduler, args, e=0):
    """Run one training epoch drawing one batch from each dataset per step.

    Iteration stops as soon as the shortest data loader is exhausted. ``args``
    must provide ``max_length``, ``batch_size``, ``checkpoint_batch_size`` and
    ``print_freq``.
    """
    model.train()
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    # make iterator
    train_samplers = [RandomSampler(dataset) for dataset in dataset_list]
    data_loaders = [
        DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
        for dataset, train_sampler in zip(dataset_list, train_samplers)
    ]
    iterators = [iter(dloader) for dloader in data_loaders]
    joint_iter = zip(*iterators)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batches in enumerate(joint_iter):
        for batch in batches:
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
            # optimizer
            loss.backward()
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            # some printing within the epoch
            loc_loss += loss.item()
            loc_steps += 1
        if step % args.print_freq == 0:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset_list[0]) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


def evaluate_qa_retriever(model, dataset, tokenizer, args):
    """Return the mean per-batch retriever loss over ``dataset``.

    Returns 0.0 when the data loader yields no batches.
    """
    model.eval()
    # make iterator
    eval_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    tot_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in epoch_iterator:
            q_ids, q_mask, a_ids, a_mask = batch
            loss = model(q_ids, q_mask, a_ids, a_mask)
            tot_loss += loss.item()
            n_batches += 1
    # max(..., 1) avoids a division by zero on an empty dataset
    return tot_loss / max(n_batches, 1)


def train_qa_retriever(qar_model, qar_tokenizer, qar_train_dset, qar_valid_dset, qar_args):
    """Train the retriever for ``qar_args.num_epochs`` epochs.

    After every epoch a checkpoint with model, optimizer and scheduler states
    is written to ``"{model_save_name}_{epoch}.pth"`` and the validation loss
    is printed. ``qar_args`` must also provide ``learning_rate`` and
    ``batch_size`` plus the fields read by ``train_qa_retriever_epoch``.
    """
    qar_optimizer = AdamW(qar_model.parameters(), lr=qar_args.learning_rate, eps=1e-8)
    qar_scheduler = get_linear_schedule_with_warmup(
        qar_optimizer,
        num_warmup_steps=100,
        num_training_steps=(qar_args.num_epochs + 1) * math.ceil(len(qar_train_dset) / qar_args.batch_size),
    )
    for e in range(qar_args.num_epochs):
        train_qa_retriever_epoch(qar_model, qar_train_dset, qar_tokenizer, qar_optimizer, qar_scheduler, qar_args, e)
        m_save_dict = {
            "model": qar_model.state_dict(),
            "optimizer": qar_optimizer.state_dict(),
            "scheduler": qar_scheduler.state_dict(),
        }
        print("Saving model {}".format(qar_args.model_save_name))
        torch.save(m_save_dict, "{}_{}.pth".format(qar_args.model_save_name, e))
        eval_loss = evaluate_qa_retriever(qar_model, qar_valid_dset, qar_tokenizer, qar_args)
        print("Evaluation loss epoch {:4d}: {:.3f}".format(e, eval_loss))


###############
# ELI5 seq2seq model training
###############
class ELI5DatasetS2S(Dataset):
    """Dataset of ``(input string, answer)`` pairs for the seq2seq answer generator."""

    def __init__(
        self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True
    ):
        """
        Args:
            examples_array: Collection of examples exposing ``num_rows``.
            make_doc_fun: Optional callable mapping a question title to a
                support document string.
            extra_answer_threshold: In training mode, answers beyond the first
                are kept only if their score is at least this value.
            document_cache: Optional dict mapping ``q_id`` to a support
                document. At least one of ``make_doc_fun`` and
                ``document_cache`` must be given.
            training: If True, index every qualifying (question, answer) pair;
                otherwise only the first answer of each question.
        """
        self.training = training
        self.data = examples_array
        self.make_doc_function = make_doc_fun
        self.document_cache = {} if document_cache is None else document_cache
        assert not (make_doc_fun is None and document_cache is None)
        # make index of specific question-answer pairs from multi-answers
        if self.training:
            self.qa_id_list = [
                (i, j)
                for i, qa in enumerate(self.data)
                for j, (a, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))
                if j == 0 or sc >= extra_answer_threshold
            ]
        else:
            self.qa_id_list = [(i, 0) for i in range(self.data.num_rows)]

    def __len__(self):
        return len(self.qa_id_list)

    def make_example(self, idx):
        """Return ``(in_st, out_st)`` for the idx-th indexed (question, answer) pair."""
        i, j = self.qa_id_list[idx]
        example = self.data[i]
        question = example["title"] + " " + example["selftext"]
        answer = example["answers"]["text"][j]
        q_id = example["q_id"]
        # Build the document only on a cache miss; dict.get(key, f()) would call
        # f() on every access because the default is evaluated eagerly.
        if self.make_doc_function is not None and q_id not in self.document_cache:
            self.document_cache[q_id] = self.make_doc_function(example["title"])
        document = self.document_cache[q_id]
        in_st = "question: {} context: {}".format(
            question.lower().replace(" --t--", "").strip(), document.lower().strip(),
        )
        out_st = answer
        return (in_st, out_st)

    def __getitem__(self, idx):
        return self.make_example(idx)


def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
    """Build the tokenizer and seq2seq model for ``model_name``.

    Args:
        model_name: Pretrained model identifier passed to ``from_pretrained``.
        from_file: Optional checkpoint path; its ``"model"`` entry is loaded
            into the model.
        device: Device the model is moved to.

    Returns:
        ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if from_file is not None:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        model.load_state_dict(param_dict["model"])
    return tokenizer, model


def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
    """Tokenize ``(input, answer)`` pairs into a dict of seq2seq model inputs.

    Answers are truncated/padded to ``min(max_len, max_a_len)``. Labels are the
    answer ids shifted left by one, with padded positions set to -100.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``decoder_input_ids`` and
        ``lm_labels``.
    """
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    lm_labels = a_ids[:, 1:].contiguous().clone()
    lm_labels[a_mask[:, 1:].contiguous() == 0] = -100
    model_inputs = {
        "input_ids": q_ids,
        "attention_mask": q_mask,
        "decoder_input_ids": a_ids[:, :-1].contiguous(),
        "lm_labels": lm_labels,
    }
    return model_inputs


def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
    """Run one training epoch of the seq2seq model.

    Args:
        args: Must provide ``max_length``, ``batch_size``, ``backward_freq``
            (an optimizer step is taken when ``step % backward_freq == 0``) and
            ``print_freq``.
        e: Epoch number, used only in the printed progress line.
        curriculum: If True, iterate the dataset sequentially instead of
            randomly.
    """
    model.train()
    # make iterator
    if curriculum:
        train_sampler = SequentialSampler(dataset)
    else:
        train_sampler = RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch_inputs in enumerate(epoch_iterator):
        pre_loss = model(**batch_inputs)[0]
        # assumes the model returns one loss value per replica (non-scalar) — TODO confirm
        loss = pre_loss.sum() / pre_loss.shape[0]
        loss.backward()
        # optimizer
        if step % args.backward_freq == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
    """Print the running and total mean loss of the seq2seq model over ``dataset``.

    ``args`` must provide ``max_length``, ``batch_size`` and ``print_freq``.
    Nothing is returned; an empty dataset reports a total loss of 0.
    """
    model.eval()
    # make iterator
    train_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0"
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    with torch.no_grad():
        for step, batch_inputs in enumerate(epoch_iterator):
            pre_loss = model(**batch_inputs)[0]
            loss = pre_loss.sum() / pre_loss.shape[0]
            loc_loss += loss.item()
            loc_steps += 1
            if step % args.print_freq == 0:
                print(
                    "{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                        step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,
                    )
                )
    # max(..., 1) avoids a division by zero on an empty dataset
    print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / max(loc_steps, 1), time() - st_time,))


def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
    """Train the seq2seq model for ``s2s_args.num_epochs`` epochs.

    The first epoch iterates the training set sequentially; later epochs
    shuffle. After each epoch the model is evaluated and a checkpoint with
    model, optimizer and scheduler states is saved to
    ``"{model_save_name}_{epoch}.pth"``.
    """
    s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
    s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
    )
    for e in range(s2s_args.num_epochs):
        train_qa_s2s_epoch(
            qa_s2s_model,
            s2s_train_dset,
            qa_s2s_tokenizer,
            s2s_optimizer,
            s2s_scheduler,
            s2s_args,
            e,
            curriculum=(e == 0),
        )
        m_save_dict = {
            "model": qa_s2s_model.state_dict(),
            "optimizer": s2s_optimizer.state_dict(),
            "scheduler": s2s_scheduler.state_dict(),
        }
        print("Saving model {}".format(s2s_args.model_save_name))
        eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
        torch.save(m_save_dict, "{}_{}.pth".format(s2s_args.model_save_name, e))


# generate answer from input "question: ... context: <P> ..."
def qa_s2s_generate(
    question_doc,
    qa_s2s_model,
    qa_s2s_tokenizer,
    num_answers=1,
    num_beams=None,
    min_len=64,
    max_len=256,
    do_sample=False,
    temp=1.0,
    top_p=None,
    top_k=None,
    max_input_length=512,
    device="cuda:0",
):
    """Generate ``num_answers`` answers for a formatted question + context string.

    Args:
        question_doc: Model input string.
        num_answers: Number of sequences to return.
        num_beams: Beam count; raised to at least ``num_answers``. Ignored
            (a single beam is used) when ``do_sample`` is True.
        min_len / max_len: Generation length bounds.
        do_sample, temp, top_p, top_k: Sampling controls forwarded to
            ``generate``.
        max_input_length: Maximum tokenized input length.
        device: Device for the input tensors.

    Returns:
        List of decoded answer strings with special tokens removed.
    """
    # "A" is a placeholder answer: only the encoder inputs of the batch are used here
    model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, max_input_length, device=device,)
    n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
    generated_ids = qa_s2s_model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        min_length=min_len,
        max_length=max_len,
        do_sample=do_sample,
        early_stopping=True,
        num_beams=1 if do_sample else n_beams,
        temperature=temp,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=qa_s2s_tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        num_return_sequences=num_answers,
        decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
    )
    return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]


###############
# ELI5-trained retrieval model usage
###############
def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):
    """Return answer-side embeddings of ``passages`` as a float32 numpy array."""
    a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True)
    a_ids, a_mask = (
        torch.LongTensor(a_toks["input_ids"]).to(device),
        torch.LongTensor(a_toks["attention_mask"]).to(device),
    )
    with torch.no_grad():
        a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float)
    return a_reps.numpy()


def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):
    """Return question-side embeddings of ``q_ls`` as a float32 numpy array."""
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True)
    q_ids, q_mask = (
        torch.LongTensor(q_toks["input_ids"]).to(device),
        torch.LongTensor(q_toks["attention_mask"]).to(device),
    )
    with torch.no_grad():
        q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float)
    return q_reps.numpy()


def make_qa_dense_index(
    qa_embedder,
    tokenizer,
    passages_dset,
    batch_size=512,
    max_length=128,
    index_name="kilt_passages_reps.dat",
    dtype="float32",
    device="cuda:0",
):
    """Embed every passage and write the vectors to a memory-mapped file.

    The file ``index_name`` is created (or overwritten) with shape
    ``(passages_dset.num_rows, 128)``. Progress is printed every 50 batches.
    """
    st_time = time()
    fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
    n_batches = math.ceil(passages_dset.num_rows / batch_size)
    for i in range(n_batches):
        start, end = i * batch_size, (i + 1) * batch_size
        passages = list(passages_dset[start:end]["passage_text"])
        reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
        fp[start:end] = reps
        if i % 50 == 0:
            print(i, time() - st_time)
    # make sure everything is written to disk before the function returns
    fp.flush()


def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False):
    """Average retrieval score and retrieval time over ``qa_list``.

    Args:
        qa_list: Iterable of ``(question, answer)`` pairs.
        retriever_func: Called as ``retriever_func(question, n_ret)``.
        scoring_func: Called as ``scoring_func(retrieved_passages, answer)``.
        n_ret: Number of passages requested per question.
        verbose: If True, print running averages for the first two examples
            and every 500 examples.

    Returns:
        Dict with ``idf_recall`` and ``retrieval_time`` (both 0.0 for an empty
        ``qa_list``).
    """
    total_retriever_time = 0.0
    total_retriever_score = 0.0
    n_examples = 0
    st_time = time()
    for i, (question, answer) in enumerate(qa_list):
        r_time = time()
        retrieved_passages = retriever_func(question, n_ret)
        total_retriever_time += time() - r_time
        total_retriever_score += scoring_func(retrieved_passages, answer)
        n_examples = i + 1
        if verbose and ((i + 1) % 500 == 0 or i <= 1):
            print(
                "{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format(
                    i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time
                )
            )
    # max(..., 1) avoids a division by zero on an empty qa_list
    denom = max(n_examples, 1)
    return {"idf_recall": total_retriever_score / denom, "retrieval_time": total_retriever_time / denom}


# build a support document for the question out of Wikipedia snippets
def query_qa_dense_index(
    question, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20, device="cuda:0"
):
    """Retrieve passages for ``question`` from the dense index.

    ``2 * n_results`` neighbours are requested; passages with ``min_length``
    words or fewer are dropped and at most ``n_results`` are returned.

    Returns:
        ``(support_doc, res_list)``: ``support_doc`` joins the text of all
        retrieved passages (unfiltered); ``res_list`` holds one dict per kept
        passage with its columns plus ``score``.
    """
    q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device)
    D, I = wiki_index.search(q_rep, 2 * n_results)
    res_passages = [wiki_passages[int(i)] for i in I[0]]
    support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
    # attach scores BEFORE filtering so each score stays with its own passage
    for r, sc in zip(res_list, D[0]):
        r["score"] = float(sc)
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list


def batch_query_qa_dense_index(questions, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
    """Batched version of ``query_qa_dense_index`` without length filtering.

    Returns:
        ``(support_doc_lst, all_res_lists)`` with one entry per question.
    """
    q_rep = embed_questions_for_retrieval(questions, tokenizer, qa_embedder)
    D, I = wiki_index.search(q_rep, n_results)
    res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
    support_doc_lst = [
        "<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
    ]
    all_res_lists = []
    for res_passages, dl in zip(res_passages_lst, D):
        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
        for r, sc in zip(res_list, dl):
            r["score"] = float(sc)
        all_res_lists.append(res_list)
    return support_doc_lst, all_res_lists


# find nearest neighbors of an answer or declarative text in Wikipedia snippets
def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20):
    """Retrieve the passages nearest to ``passage`` using answer-side embeddings.

    ``2 * n_results`` neighbours are requested; passages with ``min_length``
    words or fewer are dropped and at most ``n_results`` are returned.

    Returns:
        ``(support_doc, res_list)``: ``support_doc`` joins the text of all
        retrieved passages (unfiltered); ``res_list`` holds one dict per kept
        passage with its columns plus ``passage_id`` and ``score``.
    """
    a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder)
    D, I = wiki_index.search(a_rep, 2 * n_results)
    res_passages = [wiki_passages[int(i)] for i in I[0]]
    support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
    res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
    # attach ids and scores BEFORE filtering so they stay with their own passage
    for r, sc, i in zip(res_list, D[0], I[0]):
        r["passage_id"] = int(i)
        r["score"] = float(sc)
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    return support_doc, res_list


def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
    """Batched version of ``query_qa_dense_index_nn`` without length filtering.

    Returns:
        ``(support_doc_lst, all_res_lists)`` with one entry per input passage.
    """
    a_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder)
    D, I = wiki_index.search(a_reps, n_results)
    res_passages_lst = [[wiki_passages[int(i)] for i in i_lst] for i_lst in I]
    support_doc_lst = [
        "<P> " + " <P> ".join([p["passage_text"] for p in res_passages]) for res_passages in res_passages_lst
    ]
    all_res_lists = []
    for res_passages, dl, il in zip(res_passages_lst, D, I):
        res_list = [{k: p[k] for k in wiki_passages.column_names} for p in res_passages]
        for r, sc, i in zip(res_list, dl, il):
            r["passage_id"] = int(i)
            r["score"] = float(sc)
        all_res_lists.append(res_list)
    return support_doc_lst, all_res_lists