mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[BART] Update encoder and decoder on set_input_embedding (#3501)
Co-authored-by: Ioannis Douratsos <ioannisd@amazon.com>
This commit is contained in:
parent
cc598b312b
commit
1f72865726
@ -805,6 +805,8 @@ class BartModel(PretrainedBartModel):
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.shared = value
|
||||
self.encoder.embed_tokens = self.shared
|
||||
self.decoder.embed_tokens = self.shared
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return _make_linear_from_emb(self.shared) # make it on the fly
|
||||
|
Loading…
Reference in New Issue
Block a user