diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 5eeabef2cde..8caf9ecdd9a 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -245,7 +245,6 @@ class RagPreTrainedModel(PreTrainedModel):
         question_encoder_pretrained_model_name_or_path: str = None,
         generator_pretrained_model_name_or_path: str = None,
         retriever: RagRetriever = None,
-        *model_args,
         **kwargs
     ) -> PreTrainedModel:
         r"""
@@ -310,7 +309,7 @@ class RagPreTrainedModel(PreTrainedModel):
         """
 
         kwargs_question_encoder = {
-            argument[len("question_question_encoder_") :]: value
+            argument[len("question_encoder_") :]: value
             for argument, value in kwargs.items()
             if argument.startswith("question_encoder_")
         }
@@ -340,11 +339,15 @@ class RagPreTrainedModel(PreTrainedModel):
             if "config" not in kwargs_question_encoder:
                 from ..auto.configuration_auto import AutoConfig
 
-                question_encoder_config = AutoConfig.from_pretrained(question_encoder_pretrained_model_name_or_path)
+                question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
+                    question_encoder_pretrained_model_name_or_path,
+                    **kwargs_question_encoder,
+                    return_unused_kwargs=True,
+                )
                 kwargs_question_encoder["config"] = question_encoder_config
 
             question_encoder = AutoModel.from_pretrained(
-                question_encoder_pretrained_model_name_or_path, *model_args, **kwargs_question_encoder
+                question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
             )
 
         generator = kwargs_generator.pop("model", None)
@@ -357,7 +360,10 @@ class RagPreTrainedModel(PreTrainedModel):
             if "config" not in kwargs_generator:
                 from ..auto.configuration_auto import AutoConfig
 
-                generator_config = AutoConfig.from_pretrained(generator_pretrained_model_name_or_path)
+                generator_config, kwargs_generator = AutoConfig.from_pretrained(
+                    generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
+                )
+
                 kwargs_generator["config"] = generator_config
 
             generator = AutoModelForSeq2SeqLM.from_pretrained(
diff --git a/tests/test_modeling_rag.py b/tests/test_modeling_rag.py
index 9ad7ecde0cc..15bbea52373 100644
--- a/tests/test_modeling_rag.py
+++ b/tests/test_modeling_rag.py
@@ -1132,12 +1132,17 @@ class RagModelSaveLoadTests(unittest.TestCase):
                 "facebook/bart-large-cnn",
                 retriever=rag_retriever,
                 config=rag_config,
+                question_encoder_max_length=200,
+                generator_max_length=200,
             ).to(torch_device)
             # check that the from pretrained methods work
             rag_token.save_pretrained(tmp_dirname)
             rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
             rag_token.to(torch_device)
 
+            self.assertTrue(rag_token.question_encoder.config.max_length == 200)
+            self.assertTrue(rag_token.generator.config.max_length == 200)
+
             with torch.no_grad():
                 output = rag_token(
                     input_ids,
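
Not part of the patch, but for context: a minimal usage sketch of what the fix enables. Previously, the `question_question_encoder_` typo meant `question_encoder_`-prefixed kwargs were never stripped correctly, and prefixed kwargs were not routed into the sub-model configs at all. After the patch, `AutoConfig.from_pretrained(..., return_unused_kwargs=True)` consumes config-level kwargs and forwards the rest to the model. The retriever line below follows the standard dummy-index example from the RAG docs (it downloads a small dataset on first use); the model names and kwargs mirror the test above.

```python
from transformers import RagRetriever, RagTokenForGeneration

# Dummy-index retriever, as in the RAG documentation examples.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)

# Prefixed kwargs are stripped ("question_encoder_" / "generator_") and routed
# into the corresponding sub-model's config.
rag_token = RagTokenForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base",
    "facebook/bart-large-cnn",
    retriever=retriever,
    question_encoder_max_length=200,  # becomes question_encoder.config.max_length
    generator_max_length=200,         # becomes generator.config.max_length
)

assert rag_token.question_encoder.config.max_length == 200
assert rag_token.generator.config.max_length == 200
```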