mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix Marian model conversion (#30173)
* fix marian model coversion * uncomment that line * remove unnecessary code * revert tie_weights, doesn't hurt
This commit is contained in:
parent
38a4bf79ad
commit
4bc9cb36b7
@ -34,7 +34,6 @@ from transformers.models.marian.convert_marian_to_pytorch import (
|
||||
|
||||
DEFAULT_REPO = "Tatoeba-Challenge"
|
||||
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
|
||||
LANG_CODE_URL = "https://datahub.io/core/language-codes/r/language-codes-3b2.csv"
|
||||
ISO_URL = "https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv"
|
||||
ISO_PATH = "lang_code_data/iso-639-3.csv"
|
||||
LANG_CODE_PATH = "lang_code_data/language-codes-3b2.csv"
|
||||
@ -277,13 +276,17 @@ class TatoebaConverter:
|
||||
json.dump(metadata, writeobj)
|
||||
|
||||
def download_lang_info(self):
|
||||
global LANG_CODE_PATH
|
||||
Path(LANG_CODE_PATH).parent.mkdir(exist_ok=True)
|
||||
import wget
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
if not os.path.exists(ISO_PATH):
|
||||
wget.download(ISO_URL, ISO_PATH)
|
||||
if not os.path.exists(LANG_CODE_PATH):
|
||||
wget.download(LANG_CODE_URL, LANG_CODE_PATH)
|
||||
LANG_CODE_PATH = hf_hub_download(
|
||||
repo_id="huggingface/language_codes_marianMT", filename="language-codes-3b2.csv", repo_type="dataset"
|
||||
)
|
||||
|
||||
def parse_metadata(self, model_name, repo_path=DEFAULT_MODEL_DIR, method="best"):
|
||||
p = Path(repo_path) / model_name
|
||||
|
@ -622,6 +622,10 @@ class OpusState:
|
||||
bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
|
||||
model.model.decoder.embed_tokens.weight = decoder_wemb_tensor
|
||||
|
||||
# handle tied embeddings, otherwise "from_pretrained" loads them incorrectly
|
||||
if self.cfg["tied-embeddings"]:
|
||||
model.lm_head.weight.data = model.model.decoder.embed_tokens.weight.data.clone()
|
||||
|
||||
model.final_logits_bias = bias_tensor
|
||||
|
||||
if "Wpos" in state_dict:
|
||||
|
Loading…
Reference in New Issue
Block a user