use return dict for rag encoder (#9363)

This commit is contained in:
Derrick Blakely 2021-01-02 04:39:14 -07:00 committed by GitHub
parent ae333d04b2
commit 5f7a07c0c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1437,7 +1437,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
batch_size = context_input_ids.shape[0] // n_docs
encoder = self.rag.generator.get_encoder()
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask)
encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
input_ids = torch.full(
(batch_size * num_beams, 1),