transformers/examples/longform-qa/eli5_app.py
Ola Piktus c754c41c61
RAG (#6813)
* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* path fix

* Formatting / renaming prior to actual work

* added rag WIP

* Formatting / renaming prior to actual work

* First commit

* improve comments

* Retrieval evaluation scripts

* refactor to include modeling outputs + MPI retriever

* Fix rag-token model + refactor

* Various fixes + finetuning logic

* use_bos fix

* Retrieval refactor

* Finetuning refactoring and cleanup

* Add documentation and cleanup

* Remove set_up_rag_env.sh file

* Fix retrieval with HF index

* Fix import errors

* Fix quality errors

* Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867

* fix quality

* Fix RAG Sequence generation

* minor cleanup plus initial tests

* fix test

* fix tests 2

* Comments fix

* post-merge fixes

* Improve readme + post-rebase refactor

* Extra dependencies for tests

* Fix tests

* Fix tests 2

* Refactor test requirements

* Fix tests 3

* Post-rebase refactor

* rename nlp->datasets

* RAG integration tests

* add tokenizer to slow integration test and allow retriever to run on cpu

* add tests; fix position ids warning

* change structure

* change structure

* add from encoder generator

* save working solution

* make all integration tests pass

* add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained

* don't save paths

* delete unnecessary imports

* pass config to AutoTokenizer.from_pretrained for Rag tokenizers

* init wiki_dpr only once

* hardcode legacy index and passages paths (todo: add the right urls)

* finalize config

* finalize retriever api and config api

* LegacyIndex index download refactor

* add dpr to autotokenizer

* make from pretrained more flexible

* fix ragfortokengeneration

* small name changes in tokenizer

* add labels to models

* change default index name

* add retrieval tests

* finish token generate

* align test with previous version and make all tests pass

* add tests

* finalize tests

* implement thoms suggestions

* add first version of test

* make first tests work

* make retriever platform agnostic

* naming

* style

* add legacy index URL

* docstrings + simple retrieval test for distributed

* clean model api

* add doc_ids to retriever's outputs

* fix retrieval tests

* finish model outputs

* finalize model api

* fix generate problem for rag

* fix generate for other models

* fix some tests

* save intermediate

* set generate to default

* big refactor generate

* delete rag_api

* correct pip faiss install

* fix auto tokenization test

* fix faiss install

* fix test

* move the distributed logic to examples

* model page

* docs

* finish tests

* fix dependencies

* fix import in __init__

* Refactor eval_rag and finetune scripts

* start docstring

* add psutil to test

* fix tf test

* move require torch to top

* fix retrieval test

* align naming

* finish automodel

* fix repo consistency

* test ragtokenizer save/load

* add rag model output docs

* fix ragtokenizer save/load from pretrained

* fix tokenizer dir

* remove torch in retrieval

* fix docs

* fix finetune scripts

* finish model docs

* finish docs

* remove auto model for now

* add require torch

* remove solved todos

* integrate sylvains suggestions

* sams comments

* correct mistake on purpose

* improve README

* Add generation test cases

* fix rag token

* clean token generate

* fix test

* add note to test

* fix attention mask

* add t5 test for rag

* Fix handling prefix in finetune.py

* don't overwrite index_name

Co-authored-by: Patrick Lewis <plewis@fb.com>
Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair>
Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
2020-09-22 18:29:58 +02:00

347 lines
13 KiB
Python

import datasets
import numpy as np
import streamlit as st
import torch
from elasticsearch import Elasticsearch
import faiss
import transformers
from eli5_utils import (
embed_questions_for_retrieval,
make_qa_s2s_model,
qa_s2s_generate,
query_es_index,
query_qa_dense_index,
)
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
# Which answer generator load_models() builds: "bart" loads the yjernite/bart_eli5
# checkpoint, any other value takes the t5-small path.
MODEL_TYPE = "bart"
# When True, load_models() loads the question encoder and load_indexes() builds the
# dense passage index; when False both return None for those entries.
LOAD_DENSE_INDEX = True
@st.cache(allow_output_mutation=True)
def load_models():
    """Load the question encoder and the answer generator on ``cuda:0``.

    The question encoder (and its tokenizer) is only loaded when
    ``LOAD_DENSE_INDEX`` is true; otherwise both entries are ``None``.
    The generator is the BART ELI5 model when ``MODEL_TYPE == "bart"``
    (with weights overwritten from a local checkpoint file), and a
    ``t5-small`` model built by ``make_qa_s2s_model`` otherwise.

    Returns:
        tuple: ``(qar_tokenizer, qar_model, s2s_tokenizer, s2s_model)``.
    """
    gpu = "cuda:0"
    retriever_name = "yjernite/retribert-base-uncased"
    generator_name = "yjernite/bart_eli5"

    question_tokenizer, question_encoder = None, None
    if LOAD_DENSE_INDEX:
        question_tokenizer = AutoTokenizer.from_pretrained(retriever_name)
        question_encoder = AutoModel.from_pretrained(retriever_name).to(gpu)
        question_encoder.eval()

    if MODEL_TYPE != "bart":
        generator_tokenizer, generator = make_qa_s2s_model(
            model_name="t5-small",
            from_file="seq2seq_models/eli5_t5_model_1024_4.pth",
            device=gpu,
        )
    else:
        generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
        generator = AutoModelForSeq2SeqLM.from_pretrained(generator_name).to(gpu)
        # The hub weights are replaced by the ones stored under the "model" key
        # of the local checkpoint.
        checkpoint = torch.load("seq2seq_models/eli5_bart_model_blm_2.pth")
        generator.load_state_dict(checkpoint["model"])
        generator.eval()

    return (question_tokenizer, question_encoder, generator_tokenizer, generator)
@st.cache(allow_output_mutation=True)
def load_indexes():
    """Build the dense passage index and the Elasticsearch client.

    When ``LOAD_DENSE_INDEX`` is true, the wiki40b snippet dataset is loaded,
    its pre-computed passage representations are memory-mapped from disk and
    added to a flat inner-product FAISS index that is moved to a GPU (device
    argument ``1``). Otherwise the passages and the index are ``None``.
    The Elasticsearch client is always created.

    Returns:
        tuple: ``(wiki40b_passages, wiki40b_gpu_index_flat, es_client)``.
    """
    passage_snippets, dense_gpu_index = None, None
    if LOAD_DENSE_INDEX:
        rep_dim = 128  # width of one stored passage representation
        gpu_resources = faiss.StandardGpuResources()
        passage_snippets = datasets.load_dataset(path="wiki_snippets", name="wiki40b_en_100_0")["train"]
        # One row per passage; opened read-only so the file is not loaded eagerly.
        passage_reps = np.memmap(
            "wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat",
            dtype="float32",
            mode="r",
            shape=(passage_snippets.num_rows, rep_dim),
        )
        cpu_index = faiss.IndexFlatIP(rep_dim)
        dense_gpu_index = faiss.index_cpu_to_gpu(gpu_resources, 1, cpu_index)
        dense_gpu_index.add(passage_reps)  # TODO fix for larger GPU

    sparse_client = Elasticsearch([{"host": "localhost", "port": "9200"}])
    return (passage_snippets, dense_gpu_index, sparse_client)
@st.cache(allow_output_mutation=True)
def load_train_data():
    """Load the ELI5 training split and an index over its question embeddings.

    The question representations are memory-mapped from
    ``eli5_questions_reps.dat`` (one 128-wide float32 row per training
    example) and added to a flat inner-product FAISS index.

    Returns:
        tuple: ``(eli5_train, eli5_train_q_index)``.
    """
    rep_dim = 128
    train_split = datasets.load_dataset("eli5", name="LFQA_reddit")["train_eli5"]
    question_reps = np.memmap(
        "eli5_questions_reps.dat",
        dtype="float32",
        mode="r",
        shape=(train_split.num_rows, rep_dim),
    )
    question_index = faiss.IndexFlatIP(rep_dim)
    question_index.add(question_reps)
    return (train_split, question_index)
# Module-level resources shared by the helpers below. Each loader is wrapped in
# st.cache, so these calls are what the rest of the script reads its globals from.
passages, gpu_dense_index, es_client = load_indexes()
qar_tokenizer, qar_model, s2s_tokenizer, s2s_model = load_models()
eli5_train, eli5_train_q_index = load_train_data()
def find_nearest_training(question, n_results=10):
    """Return the ELI5 training examples whose questions are closest to ``question``.

    The question is embedded with the module-level ``qar_tokenizer`` /
    ``qar_model`` and searched against ``eli5_train_q_index``.

    Args:
        question: the query string.
        n_results: how many neighbours to request from the index.

    Returns:
        list: the matching rows of ``eli5_train``, in the order returned by the search.
    """
    query_rep = embed_questions_for_retrieval([question], qar_tokenizer, qar_model)
    _, neighbor_ids = eli5_train_q_index.search(query_rep, n_results)
    # Only one question was embedded, so row 0 holds all of its neighbours.
    return [eli5_train[int(idx)] for idx in neighbor_ids[0]]
def make_support(question, source="wiki40b", method="dense", n_results=10):
    """Retrieve supporting passages and build the generator input string.

    Args:
        question: the user question.
        source: ``"none"`` skips retrieval entirely; any other value retrieves
            passages with the chosen ``method``.
        method: ``"dense"`` queries the dense index through
            ``query_qa_dense_index``; any other value queries Elasticsearch
            through ``query_es_index``.
        n_results: number of passages to request. It also sizes the empty
            placeholder context used when ``source == "none"``.

    Returns:
        tuple: ``(question_doc, support_list)`` where ``question_doc`` is
        ``"question: <question> context: <support_doc>"`` and ``support_list``
        holds one ``(article_title, section_title, score, passage_text)``
        tuple per retrieved hit (empty when ``source == "none"``).
    """
    if source == "none":
        # No retrieval: emit one empty "<P>" slot per requested passage instead of a
        # hard-coded count, so the placeholder follows n_results (unchanged for the
        # default of 10). Presumably this keeps the input shaped like a retrieved
        # context -- confirm against the generator's training format.
        support_doc = " <P> ".join([""] * (n_results + 1)).strip()
        hit_lst = []
    elif method == "dense":
        support_doc, hit_lst = query_qa_dense_index(
            question, qar_model, qar_tokenizer, passages, gpu_dense_index, n_results
        )
    else:
        support_doc, hit_lst = query_es_index(
            question,
            es_client,
            index_name="english_wiki40b_snippets_100w",
            n_results=n_results,
        )
    support_list = [
        (res["article_title"], res["section_title"].strip(), res["score"], res["passage_text"]) for res in hit_lst
    ]
    question_doc = "question: {} context: {}".format(question, support_doc)
    return question_doc, support_list
@st.cache(
    hash_funcs={
        torch.Tensor: (lambda _: None),
        transformers.tokenization_bart.BartTokenizer: (lambda _: None),
    }
)
def answer_question(
    question_doc, s2s_model, s2s_tokenizer, min_len=64, max_len=256, sampling=False, n_beams=2, top_p=0.95, temp=0.8
):
    """Generate one answer for ``question_doc`` with the seq2seq model.

    Args:
        question_doc: the generator input string (question plus context).
        s2s_model: the sequence-to-sequence model passed to ``qa_s2s_generate``.
        s2s_tokenizer: the tokenizer matching ``s2s_model``.
        min_len: minimum generation length.
        max_len: maximum generation length.
        sampling: forwarded as ``do_sample``.
        n_beams: forwarded as ``num_beams``.
        top_p: nucleus sampling threshold.
        temp: sampling temperature.

    Returns:
        tuple: ``(answer, support_list)``. NOTE: ``support_list`` is not a
        parameter; it is read from module scope at call time, so it must have
        been assigned before this function is called.
    """
    generation_kwargs = dict(
        num_answers=1,
        num_beams=n_beams,
        min_len=min_len,
        max_len=max_len,
        do_sample=sampling,
        temp=temp,
        top_p=top_p,
        top_k=None,
        max_input_length=1024,
        device="cuda:0",
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        candidates = qa_s2s_generate(question_doc, s2s_model, s2s_tokenizer, **generation_kwargs)
    best_answer = candidates[0]
    return (best_answer, support_list)
st.title("Long Form Question Answering with ELI5")

# Start sidebar
# The logo <img> tag is substituted into the padded container defined by the HTML
# template below, then rendered at the top of the sidebar.
header_html = "<img src='https://huggingface.co/front/assets/huggingface_logo.svg'>"
header_full = """
<html>
<head>
<style>
.img-container {
padding-left: 90px;
padding-right: 90px;
padding-top: 50px;
padding-bottom: 50px;
background-color: #f0f3f9;
}
</style>
</head>
<body>
<span class="img-container"> <!-- Inline parent element -->
%s
</span>
</body>
</html>
""" % header_html
st.sidebar.markdown(header_full, unsafe_allow_html=True)

# Long Form QA with ELI5 and Wikipedia
# Short description of the demo shown under the logo.
description = """
This demo presents a model trained to [provide long-form answers to open-domain questions](https://yjernite.github.io/lfqa.html).
First, a document retriever fetches a set of relevant Wikipedia passages given the question from the [Wiki40b](https://research.google/pubs/pub49029/) dataset,
a pre-processed fixed snapshot of Wikipedia.
"""
st.sidebar.markdown(description, unsafe_allow_html=True)
# The position of an entry in this list is the `action` code used further down
# (0: answer, 1: passages only, 2: nearest ELI5 example, 3: everything).
action_list = [
    "Answer the question",
    "View the retrieved document only",
    "View the most similar ELI5 question and answer",
    "Show me everything, please!",
]
demo_options = st.sidebar.checkbox("Demo options")
if not demo_options:
    # Box unticked: show everything, with the full passage text.
    action = 3
    show_passages = True
else:
    action_st = st.sidebar.selectbox("", action_list, index=3)
    action = action_list.index(action_st)
    show_type = st.sidebar.selectbox(
        "",
        ["Show full text of passages", "Show passage section titles"],
        index=0,
    )
    show_passages = show_type == "Show full text of passages"
retrieval_options = st.sidebar.checkbox("Retrieval options")
if not retrieval_options:
    # Box unticked: dense retrieval over wiki40b.
    wiki_source = "wiki40b"
    index_type = "dense"
else:
    retriever_info = """
### Information retriever options
The **sparse** retriever uses ElasticSearch, while the **dense** retriever uses max-inner-product search between a question and passage embedding
trained using the [ELI5](https://arxiv.org/abs/1907.09190) questions-answer pairs.
The answer is then generated by sequence to sequence model which takes the question and retrieved document as input.
"""
    st.sidebar.markdown(retriever_info)
    wiki_source = st.sidebar.selectbox(
        "Which Wikipedia format should the model use?",
        ["wiki40b", "none"],
    )
    index_type = st.sidebar.selectbox(
        "Which Wikipedia indexer should the model use?",
        ["dense", "sparse", "mixed"],
    )
# Generation defaults, used as-is unless the "Generation options" box is ticked.
sampled = "beam"
n_beams = 2
min_len = 64
max_len = 256
top_p = None
temp = None
generate_options = st.sidebar.checkbox("Generation options")
if generate_options:
    generate_info = """
### Answer generation options
The sequence-to-sequence model was initialized with [BART](https://huggingface.co/facebook/bart-large)
weights and fine-tuned on the ELI5 QA pairs and retrieved documents. You can use the model for greedy decoding with
**beam** search, or **sample** from the decoder's output probabilities.
"""
    st.sidebar.markdown(generate_info)
    sampled = st.sidebar.selectbox(
        "Would you like to use beam search or sample an answer?",
        ["beam", "sampled"],
    )
    min_len = st.sidebar.slider(
        "Minimum generation length",
        min_value=8,
        max_value=256,
        value=64,
        step=8,
        format=None,
        key=None,
    )
    max_len = st.sidebar.slider(
        "Maximum generation length",
        min_value=64,
        max_value=512,
        value=256,
        step=16,
        format=None,
        key=None,
    )
    if sampled != "beam":
        # Sampling mode: expose nucleus-p and temperature, and clear the beam count.
        top_p = st.sidebar.slider(
            "Nucleus sampling p",
            min_value=0.1,
            max_value=1.0,
            value=0.95,
            step=0.01,
            format=None,
            key=None,
        )
        temp = st.sidebar.slider(
            "Temperature",
            min_value=0.1,
            max_value=1.0,
            value=0.7,
            step=0.01,
            format=None,
            key=None,
        )
        n_beams = None
    else:
        n_beams = st.sidebar.slider("Beam size", min_value=1, max_value=8, value=2, step=None, format=None, key=None)
# start main text
# Preset questions; the first entry is a sentinel that switches to free-text input.
questions_list = [
    "<MY QUESTION>",
    "How do people make chocolate?",
    "Why do we get a fever when we are sick?",
    "How can different animals perceive different colors?",
    "What is natural language processing?",
    "What's the best way to treat a sunburn?",
    "What exactly are vitamins ?",
    "How does nuclear energy provide electricity?",
    "What's the difference between viruses and bacteria?",
    "Why are flutes classified as woodwinds when most of them are made out of metal ?",
    "Why do people like drinking coffee even though it tastes so bad?",
    "What happens when wine ages? How does it make the wine taste better?",
    "If an animal is an herbivore, where does it get the protein that it needs to survive if it only eats grass?",
    "How can we set a date to the beginning or end of an artistic period? Doesn't the change happen gradually?",
    "How does New Zealand have so many large bird predators?",
]
question_s = st.selectbox(
    "What would you like to ask? ---- select <MY QUESTION> to enter a new query",
    questions_list,
    index=1,
)
# The text box is only created when the sentinel entry is selected.
question = st.text_input("Enter your question here:", "") if question_s == "<MY QUESTION>" else question_s
# Main action: everything below runs only on the script pass where the button was clicked.
# `action` is an index into `action_list`:
#   0 = answer, 1 = retrieved passages only, 2 = nearest ELI5 example, 3 = everything.
if st.button("Show me!"):
    # --- Retrieval (needed for answering and for showing passages) ---
    if action in [0, 1, 3]:
        if index_type == "mixed":
            _, support_list_dense = make_support(question, source=wiki_source, method="dense", n_results=10)
            _, support_list_sparse = make_support(question, source=wiki_source, method="sparse", n_results=10)
            # Interleave dense and sparse hits (dense first at each rank), then append
            # the tail of the longer list so that a short result set from one retriever
            # does not silently discard hits from the other.
            n_paired = min(len(support_list_dense), len(support_list_sparse))
            candidates = [tuple(res) for pair in zip(support_list_dense, support_list_sparse) for res in pair]
            candidates += [tuple(res) for res in support_list_dense[n_paired:]]
            candidates += [tuple(res) for res in support_list_sparse[n_paired:]]
            support_list = []
            seen_passages = set()
            for res in candidates:
                # Identify a passage by (article title, section title, text). The score
                # (res[2]) is produced separately by each retriever, so comparing whole
                # tuples would presumably almost never recognise the same passage twice.
                passage_key = (res[0], res[1], res[-1])
                if passage_key not in seen_passages:
                    seen_passages.add(passage_key)
                    support_list.append(res)
            support_list = support_list[:10]
            support_doc = "<P> " + " <P> ".join([res[-1] for res in support_list])
            # Wrap the merged context in the same "question: ... context: ..." layout
            # that make_support() produces, so the generator also sees the question here.
            question_doc = "question: {} context: {}".format(question, support_doc)
        else:
            question_doc, support_list = make_support(question, source=wiki_source, method=index_type, n_results=10)
    # --- Answer generation ---
    if action in [0, 3]:
        answer, support_list = answer_question(
            question_doc,
            s2s_model,
            s2s_tokenizer,
            min_len=min_len,
            max_len=int(max_len),
            sampling=(sampled == "sampled"),
            n_beams=n_beams,
            top_p=top_p,
            temp=temp,
        )
        st.markdown("### The model generated answer is:")
        st.write(answer)
    # --- Display of the supporting passages ---
    if action in [0, 1, 3] and wiki_source != "none":
        st.markdown("--- \n ### The model is drawing information from the following Wikipedia passages:")
        for i, res in enumerate(support_list):
            # res is (article_title, section_title, score, passage_text).
            wiki_url = "https://en.wikipedia.org/wiki/{}".format(res[0].replace(" ", "_"))
            sec_titles = res[1].strip()
            if sec_titles == "":
                # No section information: link to the article itself.
                sections = "[{}]({})".format(res[0], wiki_url)
            else:
                # Several section titles are separated by " & "; link each to its anchor.
                sec_list = sec_titles.split(" & ")
                sections = " & ".join(
                    ["[{}]({}#{})".format(sec.strip(), wiki_url, sec.strip().replace(" ", "_")) for sec in sec_list]
                )
            st.markdown(
                "{0:02d} - **Article**: {1:<18} <br> _Section_: {2}".format(i + 1, res[0], sections),
                unsafe_allow_html=True,
            )
            if show_passages:
                st.write(
                    '> <span style="font-family:arial; font-size:10pt;">' + res[-1] + "</span>", unsafe_allow_html=True
                )
    # --- Most similar ELI5 training example ---
    if action in [2, 3]:
        nn_train_list = find_nearest_training(question)
        train_exple = nn_train_list[0]
        st.markdown(
            "--- \n ### The most similar question in the ELI5 training set was: \n\n {}".format(train_exple["title"])
        )
        # Always show the first answer; show later ones only when their score exceeds 2.
        # Blank lines inside an answer are dropped and the remaining lines are stripped.
        answers_st = [
            "{}. {}".format(i + 1, " \n".join([line.strip() for line in ans.split("\n") if line.strip() != ""]))
            for i, (ans, sc) in enumerate(zip(train_exple["answers"]["text"], train_exple["answers"]["score"]))
            if i == 0 or sc > 2
        ]
        st.markdown("##### Its answers were: \n\n {}".format("\n".join(answers_st)))
# Disclaimer rendered at the bottom of the sidebar on every script run.
disclaimer = """
---
**Disclaimer**
*The intent of this app is to provide some (hopefully entertaining) insights into the behavior of a current LFQA system.
Evaluating biases of such a model and ensuring factual generations are still very much open research problems.
Therefore, until some significant progress is achieved, we caution against using the generated answers for practical purposes.*
"""
st.sidebar.markdown(disclaimer, unsafe_allow_html=True)