Fix rag (#38585)
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 9eac19eb59
commit f9be71b34d
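What the diff below does: the RAG integration and save/load tests stop relying on the hosted `wiki_dpr` dummy index (`index_name="exact"`, `use_dummy_dataset=True`, pinned `dataset_revision="b24a417"`). Instead, each test class gains a `setUpClass` that saves the `hf-internal-testing/wiki_dpr_dummy` train split into a temporary directory, downloads its prebuilt `index.faiss`, and points the retriever at both via `index_name="custom"`. A minimal standalone sketch of the same setup, assuming a working faiss/datasets/torch install (the dataset name, index URL, and the `facebook/rag-sequence-nq` checkpoint are taken from the diff; wiring them together outside the test class is only an illustration, not part of the commit):

    # Hypothetical standalone sketch mirroring the new test fixture -- not committed code.
    import os
    import tempfile

    import requests
    from datasets import load_dataset

    from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

    tmp_dir = tempfile.TemporaryDirectory()
    dataset_path = tmp_dir.name
    index_path = os.path.join(tmp_dir.name, "index.faiss")

    # Materialize the dummy wiki_dpr passages and their prebuilt faiss index locally.
    load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"].save_to_disk(dataset_path)
    index_url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
    with open(index_path, "wb") as fp:
        fp.write(requests.get(index_url).content)

    # Point the retriever at the local copy instead of the hosted dummy "exact" index.
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-nq",
        index_name="custom",
        passages_path=dataset_path,
        index_path=index_path,
    )
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

With `index_name="custom"` the retriever loads the passages from `passages_path` and the faiss index from `index_path`, so the tests no longer depend on the remote `wiki_dpr` dummy configuration or a pinned dataset revision.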
@@ -21,6 +21,7 @@ import unittest
 from unittest.mock import patch
 
 import numpy as np
+import requests
 
 from transformers import BartTokenizer, T5Tokenizer
 from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
@@ -49,7 +50,7 @@ T5_SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 if is_torch_available() and is_datasets_available() and is_faiss_available():
     import faiss
     import torch
-    from datasets import Dataset
+    from datasets import Dataset, load_dataset
 
     from transformers import (
         AutoConfig,
@@ -679,6 +680,24 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
 @require_tokenizers
 @require_torch_non_multi_accelerator
 class RagModelIntegrationTests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.temp_dir = tempfile.TemporaryDirectory()
+        cls.dataset_path = cls.temp_dir.name
+        cls.index_path = os.path.join(cls.temp_dir.name, "index.faiss")
+
+        ds = load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"]
+        ds.save_to_disk(cls.dataset_path)
+
+        url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
+        response = requests.get(url, stream=True)
+        with open(cls.index_path, "wb") as fp:
+            fp.write(response.content)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.temp_dir.cleanup()
+
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
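The `setUpClass` added above requests the index with `stream=True` but then reads `response.content`, which buffers the whole file in memory and does not check the HTTP status. A hedged alternative sketch, not part of the commit, using only standard `requests` API (`raise_for_status`, `iter_content`) to stream the file to disk in chunks:

    # Hypothetical hardening of the fixture download -- not part of the commit.
    import requests

    def download_file(url: str, path: str, chunk_size: int = 1 << 20) -> None:
        # Stream the response to disk so the faiss index never sits fully in memory.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()  # fail test setup early on HTTP errors
            with open(path, "wb") as fp:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    fp.write(chunk)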
@@ -722,8 +741,9 @@ class RagModelIntegrationTests(unittest.TestCase):
             max_combined_length=300,
             dataset="wiki_dpr",
             dataset_split="train",
-            index_name="exact",
-            index_path=None,
+            index_name="custom",
+            passages_path=self.dataset_path,
+            index_path=self.index_path,
             use_dummy_dataset=True,
             retrieval_vector_size=768,
             retrieval_batch_size=8,
@@ -841,8 +861,8 @@ class RagModelIntegrationTests(unittest.TestCase):
         output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
 
         # Expected outputs as given by model at integration time.
-        EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
-        EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"
+        EXPECTED_OUTPUT_TEXT_1 = '"She\'s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Uses"'
+        EXPECTED_OUTPUT_TEXT_2 = '"She\'s My Kind of Girl" was released through Epic Records in Japan in March 1972. The song was a Top 10 hit in the country. It was the first single to be released by ABBA in the UK. The single was followed by "En Carousel" and "Love Has Its Ways"'
 
         self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
         self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
@@ -903,7 +923,10 @@ class RagModelIntegrationTests(unittest.TestCase):
     def test_rag_sequence_generate_batch(self):
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
-            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
+            "facebook/rag-sequence-nq",
+            index_name="custom",
+            passages_path=self.dataset_path,
+            index_path=self.index_path,
         )
         rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
             torch_device
@@ -926,12 +949,13 @@ class RagModelIntegrationTests(unittest.TestCase):
 
         outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
+        # PR #31938 caused the output to change from `june 22, 2018` to `june 22 , 2018`.
         EXPECTED_OUTPUTS = [
             " albert einstein",
-            " june 22, 2018",
+            " june 22 , 2018",
             " amplitude modulation",
             " tim besley ( chairman )",
-            " june 20, 2018",
+            " june 20 , 2018",
             " 1980",
             " 7.0",
             " 8",
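The comment added above attributes the new spacing (`june 22 , 2018`) to PR #31938. One plausible reading, an assumption here rather than something stated in the diff, is that decoding no longer applies `clean_up_tokenization_spaces` by default, so spaces before punctuation survive detokenization. The flag itself is standard tokenizer API; the snippet below only illustrates the spacing difference with the `BartTokenizer` already imported by this test file:

    # Hypothetical illustration of the spacing difference -- not part of the commit.
    from transformers import BartTokenizer

    tok = BartTokenizer.from_pretrained("facebook/bart-large")
    ids = tok(" june 22 , 2018", add_special_tokens=False)["input_ids"]
    print(tok.decode(ids, clean_up_tokenization_spaces=True))   # " june 22, 2018" (space before the comma removed)
    print(tok.decode(ids, clean_up_tokenization_spaces=False))  # " june 22 , 2018" (round-trips unchanged)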
@@ -943,9 +967,9 @@ class RagModelIntegrationTests(unittest.TestCase):
         tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
         retriever = RagRetriever.from_pretrained(
             "facebook/rag-sequence-nq",
-            index_name="exact",
-            use_dummy_dataset=True,
-            dataset_revision="b24a417",
+            index_name="custom",
+            passages_path=self.dataset_path,
+            index_path=self.index_path,
         )
         rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
             torch_device
@@ -981,10 +1005,10 @@ class RagModelIntegrationTests(unittest.TestCase):
 
         EXPECTED_OUTPUTS = [
             " albert einstein",
-            " june 22, 2018",
+            " june 22 , 2018",
             " amplitude modulation",
             " tim besley ( chairman )",
-            " june 20, 2018",
+            " june 20 , 2018",
             " 1980",
             " 7.0",
             " 8",
@@ -995,7 +1019,7 @@ class RagModelIntegrationTests(unittest.TestCase):
     def test_rag_token_generate_batch(self):
         tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
         retriever = RagRetriever.from_pretrained(
-            "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
+            "facebook/rag-token-nq", index_name="custom", passages_path=self.dataset_path, index_path=self.index_path
         )
         rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
             torch_device
@@ -1023,10 +1047,10 @@ class RagModelIntegrationTests(unittest.TestCase):
 
         EXPECTED_OUTPUTS = [
             " albert einstein",
-            " september 22, 2017",
+            " september 22 , 2017",
             " amplitude modulation",
             " stefan persson",
-            " april 20, 2018",
+            " april 20 , 2018",
             " the 1970s",
             " 7.1. 2",
             " 13",
@@ -1037,6 +1061,24 @@ class RagModelIntegrationTests(unittest.TestCase):
 @require_torch
 @require_retrieval
 class RagModelSaveLoadTests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.temp_dir = tempfile.TemporaryDirectory()
+        cls.dataset_path = cls.temp_dir.name
+        cls.index_path = os.path.join(cls.temp_dir.name, "index.faiss")
+
+        ds = load_dataset("hf-internal-testing/wiki_dpr_dummy")["train"]
+        ds.save_to_disk(cls.dataset_path)
+
+        url = "https://huggingface.co/datasets/hf-internal-testing/wiki_dpr_dummy/resolve/main/index.faiss"
+        response = requests.get(url, stream=True)
+        with open(cls.index_path, "wb") as fp:
+            fp.write(response.content)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.temp_dir.cleanup()
+
     def tearDown(self):
         super().tearDown()
         # clean-up as much as possible GPU memory occupied by PyTorch
@@ -1060,8 +1102,9 @@ class RagModelSaveLoadTests(unittest.TestCase):
             max_combined_length=300,
             dataset="wiki_dpr",
             dataset_split="train",
-            index_name="exact",
-            index_path=None,
+            index_name="custom",
+            passages_path=self.dataset_path,
+            index_path=self.index_path,
             use_dummy_dataset=True,
             retrieval_vector_size=768,
             retrieval_batch_size=8,