[BART] Update encoder and decoder on set_input_embedding (#3501)

Co-authored-by: Ioannis Douratsos <ioannisd@amazon.com>
This commit is contained in:
dougian 2020-03-30 17:20:37 +01:00 committed by GitHub
parent cc598b312b
commit 1f72865726
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -805,6 +805,8 @@ class BartModel(PretrainedBartModel):
def set_input_embeddings(self, value):
self.shared = value
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def get_output_embeddings(self):
return _make_linear_from_emb(self.shared) # make it on the fly