mirror of https://github.com/huggingface/transformers.git
[RAG] Fix rag from pretrained question encoder generator behavior (#11962)
* fix_torch_device_generate_test
* remove @
* fix rag from pretrained loading
* add test
* upload
* finish
commit 43f46aa7fd
parent 6db3a87de2
@@ -245,7 +245,6 @@ class RagPreTrainedModel(PreTrainedModel):
         question_encoder_pretrained_model_name_or_path: str = None,
         generator_pretrained_model_name_or_path: str = None,
         retriever: RagRetriever = None,
-        *model_args,
         **kwargs
     ) -> PreTrainedModel:
         r"""
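Note: dropping *model_args is part of the fix. Positional model arguments cannot be routed unambiguously to the two sub-models, so everything has to travel through the prefixed keyword arguments handled below; the matching *model_args in the AutoModel.from_pretrained call is removed further down.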
@@ -310,7 +309,7 @@ class RagPreTrainedModel(PreTrainedModel):
         """
 
         kwargs_question_encoder = {
-            argument[len("question_question_encoder_") :]: value
+            argument[len("question_encoder_") :]: value
             for argument, value in kwargs.items()
             if argument.startswith("question_encoder_")
         }
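Note: this was the core bug. The comprehension keeps keys starting with "question_encoder_" but then sliced off the wrong, longer prefix "question_question_encoder_" (26 characters instead of 17), mangling every forwarded argument. A minimal sketch of the routing, with a hypothetical kwarg:

    kwargs = {"question_encoder_max_length": 200, "generator_max_length": 200}

    # strip the "question_encoder_" prefix so the sub-model sees the plain name
    kwargs_question_encoder = {
        argument[len("question_encoder_") :]: value
        for argument, value in kwargs.items()
        if argument.startswith("question_encoder_")
    }
    print(kwargs_question_encoder)  # {'max_length': 200}

    # the old prefix length chopped 26 characters off a 27-character key:
    broken = {
        argument[len("question_question_encoder_") :]: value
        for argument, value in kwargs.items()
        if argument.startswith("question_encoder_")
    }
    print(broken)  # {'h': 200}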
@@ -340,11 +339,15 @@ class RagPreTrainedModel(PreTrainedModel):
             if "config" not in kwargs_question_encoder:
                 from ..auto.configuration_auto import AutoConfig
 
-                question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
+                question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
+                    question_encoder_pretrained_model_name_or_path,
+                    **kwargs_question_encoder,
+                    return_unused_kwargs=True,
+                )
                 kwargs_question_encoder["config"] = question_encoder_config
 
             question_encoder = AutoModel.from_pretrained(
-                question_encoder_pretrained_model_name_or_path, *model_args, **kwargs_question_encoder
+                question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
             )
 
         generator = kwargs_generator.pop("model", None)
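Note: return_unused_kwargs=True makes AutoConfig.from_pretrained absorb any passed kwargs that are config attributes (such as max_length) into the loaded config and hand back only the leftovers, so an override like question_encoder_max_length=200 actually takes effect instead of being dropped. A sketch of the pattern, assuming hub access; the foo kwarg is hypothetical:

    from transformers import AutoConfig

    config, unused = AutoConfig.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base",
        max_length=200,  # a config attribute: absorbed into the config
        foo="bar",       # not a config attribute: returned as unused
        return_unused_kwargs=True,
    )
    print(config.max_length)  # 200
    print(unused)             # {'foo': 'bar'}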
@@ -357,7 +360,10 @@ class RagPreTrainedModel(PreTrainedModel):
             if "config" not in kwargs_generator:
                 from ..auto.configuration_auto import AutoConfig
 
-                generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
+                generator_config, kwargs_generator = AutoConfig.from_pretrained(
+                    generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
+                )
+
                 kwargs_generator["config"] = generator_config
 
             generator = AutoModelForSeq2SeqLM.from_pretrained(
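Note: with both fixes in place, prefixed kwargs land on the matching sub-model config. A usage sketch using the same checkpoints as the test below; the dummy-index retriever setup follows the RagRetriever docstring example:

    from transformers import RagRetriever, RagTokenForGeneration

    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
    )

    model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
        "facebook/dpr-question_encoder-single-nq-base",
        "facebook/bart-large-cnn",
        retriever=retriever,
        question_encoder_max_length=200,  # routed to the DPR question encoder config
        generator_max_length=200,         # routed to the BART generator config
    )
    assert model.question_encoder.config.max_length == 200
    assert model.generator.config.max_length == 200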
@@ -1132,12 +1132,17 @@ class RagModelSaveLoadTests(unittest.TestCase):
                 "facebook/bart-large-cnn",
                 retriever=rag_retriever,
                 config=rag_config,
+                question_encoder_max_length=200,
+                generator_max_length=200,
             ).to(torch_device)
             # check that the from pretrained methods work
             rag_token.save_pretrained(tmp_dirname)
             rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
             rag_token.to(torch_device)
 
+            self.assertTrue(rag_token.question_encoder.config.max_length == 200)
+            self.assertTrue(rag_token.generator.config.max_length == 200)
+
         with torch.no_grad():
             output = rag_token(
                 input_ids,