Update convert_marian_to_pytorch.py (#16124)

Configuration `tied-embeddings-all` implies `tied-embeddings-src`
This commit is contained in:
tiedemann 2022-03-14 13:15:38 +02:00 committed by GitHub
parent 2de99e6c43
commit 9e9f6b8a45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -480,6 +480,8 @@ class OpusState:
if "Wpos" in self.state_dict:
raise ValueError("Wpos key in state dictionary")
self.state_dict = dict(self.state_dict)
if cfg["tied-embeddings-all"]:
cfg["tied-embeddings-src"] = True
self.share_encoder_decoder_embeddings = cfg["tied-embeddings-src"]
# create the tokenizer here because we need to know the eos_token_id