mirror of https://github.com/huggingface/transformers.git
[marian] Automate Tatoeba-Challenge conversion (#7709)
commit 9c2b2db2cd (parent aacac8f708)
.gitignore (vendored, 1 addition)

@@ -12,6 +12,7 @@ __pycache__/
 tests/fixtures
 logs/
 lightning_logs/
+lang_code_data/

 # Distribution / packaging
 .Python
examples/seq2seq/test_tatoeba_conversion.py (new file, 22 lines)

import tempfile
import unittest

from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter
from transformers.file_utils import cached_property
from transformers.testing_utils import slow


class TatoebaConversionTester(unittest.TestCase):
    @cached_property
    def resolver(self):
        tmp_dir = tempfile.mkdtemp()
        return TatoebaConverter(save_dir=tmp_dir)

    @slow
    def test_resolver(self):
        self.resolver.convert_models(["heb-eng"])

    @slow
    def test_model_card(self):
        content, mmeta = self.resolver.write_model_card("opus-mt-he-en", dry_run=True)
        assert mmeta["long_pair"] == "heb-eng"
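Both tests are decorated with `@slow`, so they are skipped in the default test run. A hedged sketch of driving them programmatically; the `RUN_SLOW` environment flag is an assumption about how the repo gates slow tests, not part of this diff:

```python
# Hedged sketch: enable and run the new slow tests outside the normal CI gate.
# RUN_SLOW is assumed to be the flag transformers' @slow decorator checks.
import os

import pytest

os.environ["RUN_SLOW"] = "1"
pytest.main(["-q", "examples/seq2seq/test_tatoeba_conversion.py"])
```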
scripts/tatoeba/README.md (new file, 44 lines)

Set up transformers following the instructions in its README.md (I would fork first).
```bash
git clone git@github.com:huggingface/transformers.git
cd transformers
pip install -e .
pip install pandas
```

Get the required metadata:
```bash
curl https://cdn-datasets.huggingface.co/language_codes/language-codes-3b2.csv > language-codes-3b2.csv
curl https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv > iso-639-3.csv
```

Clone the Tatoeba-Challenge repo inside transformers:
```bash
git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git
```

To convert a few models, call the conversion script from the command line:
```bash
python src/transformers/convert_marian_tatoeba_to_pytorch.py --models heb-eng eng-heb --save_dir converted
```

To convert many models, pass your list of Tatoeba model names to `resolver.convert_models` in a Python client or script:

```python
from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter
resolver = TatoebaConverter(save_dir='converted')
resolver.convert_models(['heb-eng', 'eng-heb'])
```


### Upload converted models
```bash
cd converted
transformers-cli login
for FILE in *; do transformers-cli upload $FILE; done
```


### Modifications
- To change the naming logic, change the code near `os.rename` (see the sketch below); the model card creation code may also need to change.
- To change the model card content, modify `TatoebaConverter.write_model_card`.
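A minimal sketch of the kind of naming hook the Modifications note points at, assuming the converter writes one directory per model under `save_dir`; `TatoebaConverter` and `convert_models` come from the new script, while the lower-casing rename is purely illustrative:

```python
from pathlib import Path

from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter

resolver = TatoebaConverter(save_dir="converted")
resolver.convert_models(["heb-eng"])

# Illustrative post-conversion renaming pass (assumption: one sub-directory per
# converted model under save_dir). Real naming changes belong near os.rename
# inside the converter itself, as the README notes.
for model_dir in Path("converted").iterdir():
    if model_dir.is_dir():
        model_dir.rename(model_dir.with_name(model_dir.name.lower()))
```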
src/transformers/convert_marian_tatoeba_to_pytorch.py (new file, 1249 lines)

File diff suppressed because it is too large.

src/transformers/convert_marian_to_pytorch.py
@@ -1,12 +1,11 @@
 import argparse
 import json
 import os
-import shutil
 import socket
 import time
 import warnings
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Union
 from zipfile import ZipFile

 import numpy as np
@@ -23,85 +22,6 @@ def remove_suffix(text: str, suffix: str):
     return text  # or whatever
-
-
-def _process_benchmark_table_row(x):
-    fields = lmap(str.strip, x.replace("\t", "").split("|")[1:-1])
-    assert len(fields) == 3
-    return (fields[0], float(fields[1]), float(fields[2]))
-
-
-def process_last_benchmark_table(readme_path) -> List[Tuple[str, float, float]]:
-    md_content = Path(readme_path).open().read()
-    entries = md_content.split("## Benchmarks")[-1].strip().split("\n")[2:]
-    data = lmap(_process_benchmark_table_row, entries)
-    return data
-
-
-def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo_path="Tatoeba-Challenge/models/"):
-    """Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
-    import pandas as pd
-
-    newest_released, old_reg, released = get_released_df(new_repo_path, old_repo_path)
-
-    short_to_new_bleu = newest_released.set_index("short_pair").bleu
-
-    assert released.groupby("short_pair").pair.nunique().max() == 1
-
-    short_to_long = released.groupby("short_pair").pair.first().to_dict()
-
-    overlap_short = old_reg.index.intersection(released.short_pair.unique())
-    overlap_long = [short_to_long[o] for o in overlap_short]
-    new_reported_bleu = [short_to_new_bleu[o] for o in overlap_short]
-
-    def get_old_bleu(o) -> float:
-        pat = old_repo_path + "/{}/README.md"
-        bm_data = process_last_benchmark_table(pat.format(o))
-        tab = pd.DataFrame(bm_data, columns=["testset", "bleu", "chr-f"])
-        tato_bleu = tab.loc[lambda x: x.testset.str.startswith("Tato")].bleu
-        if tato_bleu.shape[0] > 0:
-            return tato_bleu.iloc[0]
-        else:
-            return np.nan
-
-    old_bleu = [get_old_bleu(o) for o in overlap_short]
-    cmp_df = pd.DataFrame(
-        dict(short=overlap_short, long=overlap_long, old_bleu=old_bleu, new_bleu=new_reported_bleu)
-    ).fillna(-1)
-
-    dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
-    whitelist_df = cmp_df[cmp_df.old_bleu <= cmp_df.new_bleu]
-    blacklist = dominated.long.unique().tolist()  # 3 letter codes
-    return whitelist_df, dominated, blacklist
-
-
-def get_released_df(new_repo_path, old_repo_path):
-    import pandas as pd
-
-    released_cols = [
-        "url_base",
-        "pair",  # (ISO639-3/ISO639-5 codes),
-        "short_pair",  # (reduced codes),
-        "chrF2_score",
-        "bleu",
-        "brevity_penalty",
-        "ref_len",
-        "src_name",
-        "tgt_name",
-    ]
-    released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
-    released.columns = released_cols
-    old_reg = make_registry(repo_path=old_repo_path)
-    old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
-    assert old_reg.id.value_counts().max() == 1
-    old_reg = old_reg.set_index("id")
-    released["fname"] = released["url_base"].apply(
-        lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
-    )
-    released["2m"] = released.fname.str.startswith("2m")
-    released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
-    newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
-    return newest_released, old_reg, released
-
-
 def remove_prefix(text: str, prefix: str):
     if text.startswith(prefix):
         return text[len(prefix) :]
@@ -183,7 +103,11 @@ def find_model_file(dest_dir):  # this one better


 # Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE
-ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
+ROM_GROUP = (
+    "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT"
+    "+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co"
+    "+nap+scn+vec+sc+ro+la"
+)
 GROUPS = [
     ("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
     (ROM_GROUP, "ROMANCE"),
@@ -221,13 +145,15 @@ ORG_NAME = "Helsinki-NLP/"


 def convert_opus_name_to_hf_name(x):
+    """For OPUS-MT-Train/ DEPRECATED"""
     for substr, grp_name in GROUPS:
         x = x.replace(substr, grp_name)
     return x.replace("+", "_")


 def convert_hf_name_to_opus_name(hf_model_name):
-    """Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME."""
+    """Relies on the assumption that there are no language codes like pt_br in models that are not in
+    GROUP_TO_OPUS_NAME."""
     hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
     if hf_model_name in GROUP_TO_OPUS_NAME:
         opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]
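A quick illustrative check of the shortening behaviour, using only the two `GROUPS` entries visible in the hunk above; the import path follows the module named later in this diff's docstring and is an assumption, not part of the change:

```python
# Illustrative only: exercises convert_opus_name_to_hf_name with the ZH and
# ROMANCE group entries shown in this hunk.
from transformers.convert_marian_to_pytorch import ROM_GROUP, convert_opus_name_to_hf_name

zh_codes = "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh"

assert convert_opus_name_to_hf_name(zh_codes + "-en") == "ZH-en"
assert convert_opus_name_to_hf_name("en-" + ROM_GROUP) == "en-ROMANCE"
```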
@@ -247,8 +173,9 @@ def get_system_metadata(repo_root):
     )


-front_matter = """---
-language: {}
+FRONT_MATTER_TEMPLATE = """---
+language:
+{}
 tags:
 - translation

@@ -256,11 +183,13 @@ license: apache-2.0
 ---

 """
+DEFAULT_REPO = "Tatoeba-Challenge"
+DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")


 def write_model_card(
     hf_model_name: str,
-    repo_root="OPUS-MT-train",
+    repo_root=DEFAULT_REPO,
     save_dir=Path("marian_converted"),
     dry_run=False,
     extra_metadata={},
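Moving the `language` placeholder onto its own line lets a card list several language codes instead of a single scalar. A small illustration of the difference, using abbreviated stand-ins rather than the exact template strings from the file:

```python
# Abbreviated stand-ins, only to show the formatting difference; the real
# templates also carry tags, license and the closing front-matter fence.
old_template = "---\nlanguage: {}\n---\n"
new_template = "---\nlanguage:\n{}\n---\n"

print(old_template.format("he"))          # language: he   (single scalar)
print(new_template.format("- he\n- en"))  # language:      (YAML list follows)
```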
@@ -294,7 +223,10 @@ def write_model_card(

     # combine with opus markdown

-    extra_markdown = f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: {metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
+    extra_markdown = (
+        f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: "
+        f"{metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
+    )

     content = opus_readme_path.open().read()
     content = content.split("\n# ")[-1]  # Get the lowest level 1 header in the README -- the most recent model.
@@ -302,7 +234,7 @@ def write_model_card(
     print(splat[3])
     content = "*".join(splat)
     content = (
-        front_matter.format(metadata["src_alpha2"])
+        FRONT_MATTER_TEMPLATE.format(metadata["src_alpha2"])
         + extra_markdown
         + "\n* "
         + content.replace("download", "download original weights")
@@ -323,48 +255,6 @@ def write_model_card(
     return content, metadata


-def get_clean_model_id_mapping(multiling_model_ids):
-    return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
-
-
-def expand_group_to_two_letter_codes(grp_name):
-    raise NotImplementedError()
-
-
-def get_two_letter_code(three_letter_code):
-    raise NotImplementedError()
-    # return two_letter_code
-
-
-def get_tags(code, ref_name):
-    if len(code) == 2:
-        assert "languages" not in ref_name, f"{code}: {ref_name}"
-        return [code], False
-    elif "languages" in ref_name:
-        group = expand_group_to_two_letter_codes(code)
-        group.append(code)
-        return group, True
-    else:  # zho-> zh
-        raise ValueError(f"Three letter monolingual code: {code}")
-
-
-def resolve_lang_code(r):
-    """R is a row in ported"""
-    short_pair = r.short_pair
-    src, tgt = short_pair.split("-")
-    src_tags, src_multilingual = get_tags(src, r.src_name)
-    assert isinstance(src_tags, list)
-    tgt_tags, tgt_multilingual = get_tags(src, r.tgt_name)
-    assert isinstance(tgt_tags, list)
-    if src_multilingual:
-        src_tags.append("multilingual_src")
-    if tgt_multilingual:
-        tgt_tags.append("multilingual_tgt")
-    return src_tags + tgt_tags
-
-
-# process target
-
-
 def make_registry(repo_path="Opus-MT-train/models"):
     if not (Path(repo_path) / "fr-en" / "README.md").exists():
         raise ValueError(
@@ -382,36 +272,25 @@ def make_registry(repo_path="Opus-MT-train/models"):
     return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]


-def make_tatoeba_registry(repo_path="Tatoeba-Challenge/models"):
-    if not (Path(repo_path) / "zho-eng" / "README.md").exists():
-        raise ValueError(
-            f"repo_path:{repo_path} does not exist: "
-            "You must run: git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git before calling."
-        )
-    results = {}
-    for p in Path(repo_path).iterdir():
-        if len(p.name) != 7:
-            continue
-        lns = list(open(p / "README.md").readlines())
-        results[p.name] = _parse_readme(lns)
-    return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
-
-
-def convert_all_sentencepiece_models(model_list=None, repo_path=None):
+def convert_all_sentencepiece_models(model_list=None, repo_path=None, dest_dir=Path("marian_converted")):
     """Requires 300GB"""
     save_dir = Path("marian_ckpt")
-    dest_dir = Path("marian_converted")
+    dest_dir = Path(dest_dir)
     dest_dir.mkdir(exist_ok=True)
+    save_paths = []
     if model_list is None:
         model_list: list = make_registry(repo_path=repo_path)
     for k, prepro, download, test_set_url in tqdm(model_list):
         if "SentencePiece" not in prepro:  # dont convert BPE models.
             continue
-        if not os.path.exists(save_dir / k / "pytorch_model.bin"):
+        if not os.path.exists(save_dir / k):
             download_and_unzip(download, save_dir / k)
         pair_name = convert_opus_name_to_hf_name(k)
         convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")

+        save_paths.append(dest_dir / f"opus-mt-{pair_name}")
+    return save_paths


 def lmap(f, x) -> List:
     return list(map(f, x))
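With the new `dest_dir` argument and the returned `save_paths`, bulk conversion can feed directly into a follow-up step. A hedged usage sketch; the module path comes from the old docstring removed further down, and the registry path assumes Opus-MT-train has been cloned locally:

```python
from transformers.convert_marian_to_pytorch import (
    convert_all_sentencepiece_models,
    make_registry,
)

# Assumption: Opus-MT-train is cloned next to the working directory.
reg = make_registry(repo_path="Opus-MT-train/models")
saved = convert_all_sentencepiece_models(model_list=reg, dest_dir="marian_converted")
for path in saved:
    print(path)  # one directory per converted opus-mt-* model
```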
@@ -493,15 +372,6 @@ def add_special_tokens_to_vocab(model_dir: Path) -> None:
     save_tokenizer_config(model_dir)


-def save_tokenizer(self, save_directory):
-    dest = Path(save_directory)
-    src_path = Path(self.init_kwargs["source_spm"])
-
-    for dest_name in {"source.spm", "target.spm", "tokenizer_config.json"}:
-        shutil.copyfile(src_path.parent / dest_name, dest / dest_name)
-    save_json(self.encoder, dest / "vocab.json")
-
-
 def check_equal(marian_cfg, k1, k2):
     v1, v2 = marian_cfg[k1], marian_cfg[k2]
     assert v1 == v2, f"hparams {k1},{k2} differ: {v1} != {v2}"
@@ -698,14 +568,14 @@ def convert(source_dir: Path, dest_dir):

     add_special_tokens_to_vocab(source_dir)
     tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
-    save_tokenizer(tokenizer, dest_dir)
+    tokenizer.save_pretrained(dest_dir)

     opus_state = OpusState(source_dir)
     assert opus_state.cfg["vocab_size"] == len(
         tokenizer.encoder
     ), f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
     # save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
-    # ^^ Save human readable marian config for debugging
+    # ^^ Uncomment to save human readable marian config for debugging

     model = opus_state.load_marian_model()
     model = model.half()
@@ -732,15 +602,11 @@ def unzip(zip_path: str, dest_dir: str) -> None:

 if __name__ == "__main__":
     """
-    To bulk convert, run
-    >>> from transformers.convert_marian_to_pytorch import make_tatoeba_registry, convert_all_sentencepiece_models
-    >>> reg = make_tatoeba_registry()
-    >>> convert_all_sentencepiece_models(model_list=reg)  # saves to marian_converted
-    (bash) aws s3 sync marian_converted s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
+    Tatoeba conversion instructions in scripts/tatoeba/README.md
     """
     parser = argparse.ArgumentParser()
     # Required parameters
-    parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
+    parser.add_argument("--src", type=str, help="path to marian model sub dir", default="en-de")
     parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
     args = parser.parse_args()
