[RAG] Fix rag from pretrained question encoder generator behavior (#11962)

* fix_torch_device_generate_test

* remove @

* fix rag from pretrained loading

* add test

* upload

* finish
Patrick von Platen 2021-06-02 09:17:14 +01:00 committed by GitHub
parent 6db3a87de2
commit 43f46aa7fd
2 changed files with 16 additions and 5 deletions


@@ -245,7 +245,6 @@ class RagPreTrainedModel(PreTrainedModel):
         question_encoder_pretrained_model_name_or_path: str = None,
         generator_pretrained_model_name_or_path: str = None,
         retriever: RagRetriever = None,
-        *model_args,
         **kwargs
     ) -> PreTrainedModel:
         r"""
@@ -310,7 +309,7 @@ class RagPreTrainedModel(PreTrainedModel):
         """
 
         kwargs_question_encoder = {
-            argument[len("question_question_encoder_") :]: value
+            argument[len("question_encoder_") :]: value
             for argument, value in kwargs.items()
             if argument.startswith("question_encoder_")
         }
@@ -340,11 +339,15 @@ class RagPreTrainedModel(PreTrainedModel):
             if "config" not in kwargs_question_encoder:
                 from ..auto.configuration_auto import AutoConfig
 
-                question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
+                question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
+                    question_encoder_pretrained_model_name_or_path,
+                    **kwargs_question_encoder,
+                    return_unused_kwargs=True,
+                )
                 kwargs_question_encoder["config"] = question_encoder_config
 
             question_encoder = AutoModel.from_pretrained(
-                question_encoder_pretrained_model_name_or_path, *model_args, **kwargs_question_encoder
+                question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
             )
 
         generator = kwargs_generator.pop("model", None)
@@ -357,7 +360,10 @@ class RagPreTrainedModel(PreTrainedModel):
             if "config" not in kwargs_generator:
                 from ..auto.configuration_auto import AutoConfig
 
-                generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
+                generator_config, kwargs_generator = AutoConfig.from_pretrained(
+                    generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
+                )
+
                 kwargs_generator["config"] = generator_config
 
             generator = AutoModelForSeq2SeqLM.from_pretrained(

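The change above means any kwarg prefixed with question_encoder_ or generator_ now has its prefix stripped and is passed through AutoConfig.from_pretrained with return_unused_kwargs=True, so options the config recognizes land on that component's config while leftovers can still be forwarded to the component's from_pretrained call. A minimal sketch of that routing pattern follows; the helper name split_prefixed_kwargs is made up for illustration, and the checkpoint is the generator used in the test below:

from transformers import AutoConfig

def split_prefixed_kwargs(prefix, kwargs):
    # Strip the component prefix so the remaining keys match plain config fields,
    # e.g. "generator_max_length" -> "max_length".
    return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

kwargs = {"question_encoder_max_length": 200, "generator_max_length": 200}
kwargs_question_encoder = split_prefixed_kwargs("question_encoder_", kwargs)
kwargs_generator = split_prefixed_kwargs("generator_", kwargs)

# return_unused_kwargs=True lets the config absorb the options it recognizes
# (here max_length) and returns whatever is left over, so the remainder can
# still be forwarded to the sub-model's from_pretrained call.
generator_config, kwargs_generator = AutoConfig.from_pretrained(
    "facebook/bart-large-cnn", **kwargs_generator, return_unused_kwargs=True
)
assert generator_config.max_length == 200  # the prefixed kwarg reached the config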

@@ -1132,12 +1132,17 @@ class RagModelSaveLoadTests(unittest.TestCase):
                 "facebook/bart-large-cnn",
                 retriever=rag_retriever,
                 config=rag_config,
+                question_encoder_max_length=200,
+                generator_max_length=200,
             ).to(torch_device)
             # check that the from pretrained methods work
             rag_token.save_pretrained(tmp_dirname)
             rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
             rag_token.to(torch_device)
 
+            self.assertTrue(rag_token.question_encoder.config.max_length == 200)
+            self.assertTrue(rag_token.generator.config.max_length == 200)
+
         with torch.no_grad():
             output = rag_token(
                 input_ids,
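Mirroring the updated test, here is a hedged end-to-end sketch of the fixed behavior. It assumes the test's rag_token is a RagTokenForGeneration (only the variable name is visible in the excerpt), the question-encoder checkpoint name is illustrative, and retriever/config are omitted for brevity even though the test above supplies both:

from transformers import RagTokenForGeneration

# retriever and config are omitted for brevity; the test above passes both.
# The question-encoder checkpoint below is illustrative only.
rag_token = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base",
    "facebook/bart-large-cnn",
    question_encoder_max_length=200,  # routed to the question encoder's config
    generator_max_length=200,         # routed to the generator's config
)

# Before the fix these prefixed kwargs were dropped; now they reach the sub-configs.
assert rag_token.question_encoder.config.max_length == 200
assert rag_token.generator.config.max_length == 200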