mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[marian] Automate Tatoeba-Challenge conversion (#7709)
This commit is contained in:
parent
aacac8f708
commit
9c2b2db2cd
1
.gitignore
vendored
1
.gitignore
vendored
@ -12,6 +12,7 @@ __pycache__/
|
||||
tests/fixtures
|
||||
logs/
|
||||
lightning_logs/
|
||||
lang_code_data/
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
|
22
examples/seq2seq/test_tatoeba_conversion.py
Normal file
22
examples/seq2seq/test_tatoeba_conversion.py
Normal file
@ -0,0 +1,22 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
|
||||
class TatoebaConversionTester(unittest.TestCase):
|
||||
@cached_property
|
||||
def resolver(self):
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
return TatoebaConverter(save_dir=tmp_dir)
|
||||
|
||||
@slow
|
||||
def test_resolver(self):
|
||||
self.resolver.convert_models(["heb-eng"])
|
||||
|
||||
@slow
|
||||
def test_model_card(self):
|
||||
content, mmeta = self.resolver.write_model_card("opus-mt-he-en", dry_run=True)
|
||||
assert mmeta["long_pair"] == "heb-eng"
|
44
scripts/tatoeba/README.md
Normal file
44
scripts/tatoeba/README.md
Normal file
@ -0,0 +1,44 @@
|
||||
Setup transformers following instructions in README.md, (I would fork first).
|
||||
```bash
|
||||
git clone git@github.com:huggingface/transformers.git
|
||||
cd transformers
|
||||
pip install -e .
|
||||
pip install pandas
|
||||
```
|
||||
|
||||
Get required metadata
|
||||
```
|
||||
curl https://cdn-datasets.huggingface.co/language_codes/language-codes-3b2.csv > language-codes-3b2.csv
|
||||
curl https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv > iso-639-3.csv
|
||||
```
|
||||
|
||||
Install Tatoeba-Challenge repo inside transformers
|
||||
```bash
|
||||
git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git
|
||||
```
|
||||
|
||||
To convert a few models, call the conversion script from command line:
|
||||
```bash
|
||||
python src/transformers/convert_marian_tatoeba_to_pytorch.py --models heb-eng eng-heb --save_dir converted
|
||||
```
|
||||
|
||||
To convert lots of models you can pass your list of Tatoeba model names to `resolver.convert_models` in a python client or script.
|
||||
|
||||
```python
|
||||
from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter
|
||||
resolver = TatoebaConverter(save_dir='converted')
|
||||
resolver.convert_models(['heb-eng', 'eng-heb'])
|
||||
```
|
||||
|
||||
|
||||
### Upload converted models
|
||||
```bash
|
||||
cd converted
|
||||
transformers-cli login
|
||||
for FILE in *; do transformers-cli upload $FILE; done
|
||||
```
|
||||
|
||||
|
||||
### Modifications
|
||||
- To change naming logic, change the code near `os.rename`. The model card creation code may also need to change.
|
||||
- To change model card content, you must modify `TatoebaCodeResolver.write_model_card`
|
1249
src/transformers/convert_marian_tatoeba_to_pytorch.py
Normal file
1249
src/transformers/convert_marian_tatoeba_to_pytorch.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,12 +1,11 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Dict, List, Union
|
||||
from zipfile import ZipFile
|
||||
|
||||
import numpy as np
|
||||
@ -23,85 +22,6 @@ def remove_suffix(text: str, suffix: str):
|
||||
return text # or whatever
|
||||
|
||||
|
||||
def _process_benchmark_table_row(x):
|
||||
fields = lmap(str.strip, x.replace("\t", "").split("|")[1:-1])
|
||||
assert len(fields) == 3
|
||||
return (fields[0], float(fields[1]), float(fields[2]))
|
||||
|
||||
|
||||
def process_last_benchmark_table(readme_path) -> List[Tuple[str, float, float]]:
|
||||
md_content = Path(readme_path).open().read()
|
||||
entries = md_content.split("## Benchmarks")[-1].strip().split("\n")[2:]
|
||||
data = lmap(_process_benchmark_table_row, entries)
|
||||
return data
|
||||
|
||||
|
||||
def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo_path="Tatoeba-Challenge/models/"):
|
||||
"""Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
|
||||
import pandas as pd
|
||||
|
||||
newest_released, old_reg, released = get_released_df(new_repo_path, old_repo_path)
|
||||
|
||||
short_to_new_bleu = newest_released.set_index("short_pair").bleu
|
||||
|
||||
assert released.groupby("short_pair").pair.nunique().max() == 1
|
||||
|
||||
short_to_long = released.groupby("short_pair").pair.first().to_dict()
|
||||
|
||||
overlap_short = old_reg.index.intersection(released.short_pair.unique())
|
||||
overlap_long = [short_to_long[o] for o in overlap_short]
|
||||
new_reported_bleu = [short_to_new_bleu[o] for o in overlap_short]
|
||||
|
||||
def get_old_bleu(o) -> float:
|
||||
pat = old_repo_path + "/{}/README.md"
|
||||
bm_data = process_last_benchmark_table(pat.format(o))
|
||||
tab = pd.DataFrame(bm_data, columns=["testset", "bleu", "chr-f"])
|
||||
tato_bleu = tab.loc[lambda x: x.testset.str.startswith("Tato")].bleu
|
||||
if tato_bleu.shape[0] > 0:
|
||||
return tato_bleu.iloc[0]
|
||||
else:
|
||||
return np.nan
|
||||
|
||||
old_bleu = [get_old_bleu(o) for o in overlap_short]
|
||||
cmp_df = pd.DataFrame(
|
||||
dict(short=overlap_short, long=overlap_long, old_bleu=old_bleu, new_bleu=new_reported_bleu)
|
||||
).fillna(-1)
|
||||
|
||||
dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
|
||||
whitelist_df = cmp_df[cmp_df.old_bleu <= cmp_df.new_bleu]
|
||||
blacklist = dominated.long.unique().tolist() # 3 letter codes
|
||||
return whitelist_df, dominated, blacklist
|
||||
|
||||
|
||||
def get_released_df(new_repo_path, old_repo_path):
|
||||
import pandas as pd
|
||||
|
||||
released_cols = [
|
||||
"url_base",
|
||||
"pair", # (ISO639-3/ISO639-5 codes),
|
||||
"short_pair", # (reduced codes),
|
||||
"chrF2_score",
|
||||
"bleu",
|
||||
"brevity_penalty",
|
||||
"ref_len",
|
||||
"src_name",
|
||||
"tgt_name",
|
||||
]
|
||||
released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
|
||||
released.columns = released_cols
|
||||
old_reg = make_registry(repo_path=old_repo_path)
|
||||
old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
|
||||
assert old_reg.id.value_counts().max() == 1
|
||||
old_reg = old_reg.set_index("id")
|
||||
released["fname"] = released["url_base"].apply(
|
||||
lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
|
||||
)
|
||||
released["2m"] = released.fname.str.startswith("2m")
|
||||
released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
|
||||
newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
|
||||
return newest_released, old_reg, released
|
||||
|
||||
|
||||
def remove_prefix(text: str, prefix: str):
|
||||
if text.startswith(prefix):
|
||||
return text[len(prefix) :]
|
||||
@ -183,7 +103,11 @@ def find_model_file(dest_dir): # this one better
|
||||
|
||||
|
||||
# Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE
|
||||
ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
|
||||
ROM_GROUP = (
|
||||
"fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT"
|
||||
"+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co"
|
||||
"+nap+scn+vec+sc+ro+la"
|
||||
)
|
||||
GROUPS = [
|
||||
("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
|
||||
(ROM_GROUP, "ROMANCE"),
|
||||
@ -221,13 +145,15 @@ ORG_NAME = "Helsinki-NLP/"
|
||||
|
||||
|
||||
def convert_opus_name_to_hf_name(x):
|
||||
"""For OPUS-MT-Train/ DEPRECATED"""
|
||||
for substr, grp_name in GROUPS:
|
||||
x = x.replace(substr, grp_name)
|
||||
return x.replace("+", "_")
|
||||
|
||||
|
||||
def convert_hf_name_to_opus_name(hf_model_name):
|
||||
"""Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME."""
|
||||
"""Relies on the assumption that there are no language codes like pt_br in models that are not in
|
||||
GROUP_TO_OPUS_NAME."""
|
||||
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
|
||||
if hf_model_name in GROUP_TO_OPUS_NAME:
|
||||
opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]
|
||||
@ -247,8 +173,9 @@ def get_system_metadata(repo_root):
|
||||
)
|
||||
|
||||
|
||||
front_matter = """---
|
||||
language: {}
|
||||
FRONT_MATTER_TEMPLATE = """---
|
||||
language:
|
||||
{}
|
||||
tags:
|
||||
- translation
|
||||
|
||||
@ -256,11 +183,13 @@ license: apache-2.0
|
||||
---
|
||||
|
||||
"""
|
||||
DEFAULT_REPO = "Tatoeba-Challenge"
|
||||
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
|
||||
|
||||
|
||||
def write_model_card(
|
||||
hf_model_name: str,
|
||||
repo_root="OPUS-MT-train",
|
||||
repo_root=DEFAULT_REPO,
|
||||
save_dir=Path("marian_converted"),
|
||||
dry_run=False,
|
||||
extra_metadata={},
|
||||
@ -294,7 +223,10 @@ def write_model_card(
|
||||
|
||||
# combine with opus markdown
|
||||
|
||||
extra_markdown = f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: {metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
|
||||
extra_markdown = (
|
||||
f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: "
|
||||
f"{metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
|
||||
)
|
||||
|
||||
content = opus_readme_path.open().read()
|
||||
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
|
||||
@ -302,7 +234,7 @@ def write_model_card(
|
||||
print(splat[3])
|
||||
content = "*".join(splat)
|
||||
content = (
|
||||
front_matter.format(metadata["src_alpha2"])
|
||||
FRONT_MATTER_TEMPLATE.format(metadata["src_alpha2"])
|
||||
+ extra_markdown
|
||||
+ "\n* "
|
||||
+ content.replace("download", "download original weights")
|
||||
@ -323,48 +255,6 @@ def write_model_card(
|
||||
return content, metadata
|
||||
|
||||
|
||||
def get_clean_model_id_mapping(multiling_model_ids):
|
||||
return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
|
||||
|
||||
|
||||
def expand_group_to_two_letter_codes(grp_name):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_two_letter_code(three_letter_code):
|
||||
raise NotImplementedError()
|
||||
# return two_letter_code
|
||||
|
||||
|
||||
def get_tags(code, ref_name):
|
||||
if len(code) == 2:
|
||||
assert "languages" not in ref_name, f"{code}: {ref_name}"
|
||||
return [code], False
|
||||
elif "languages" in ref_name:
|
||||
group = expand_group_to_two_letter_codes(code)
|
||||
group.append(code)
|
||||
return group, True
|
||||
else: # zho-> zh
|
||||
raise ValueError(f"Three letter monolingual code: {code}")
|
||||
|
||||
|
||||
def resolve_lang_code(r):
|
||||
"""R is a row in ported"""
|
||||
short_pair = r.short_pair
|
||||
src, tgt = short_pair.split("-")
|
||||
src_tags, src_multilingual = get_tags(src, r.src_name)
|
||||
assert isinstance(src_tags, list)
|
||||
tgt_tags, tgt_multilingual = get_tags(src, r.tgt_name)
|
||||
assert isinstance(tgt_tags, list)
|
||||
if src_multilingual:
|
||||
src_tags.append("multilingual_src")
|
||||
if tgt_multilingual:
|
||||
tgt_tags.append("multilingual_tgt")
|
||||
return src_tags + tgt_tags
|
||||
|
||||
# process target
|
||||
|
||||
|
||||
def make_registry(repo_path="Opus-MT-train/models"):
|
||||
if not (Path(repo_path) / "fr-en" / "README.md").exists():
|
||||
raise ValueError(
|
||||
@ -382,36 +272,25 @@ def make_registry(repo_path="Opus-MT-train/models"):
|
||||
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
|
||||
|
||||
|
||||
def make_tatoeba_registry(repo_path="Tatoeba-Challenge/models"):
|
||||
if not (Path(repo_path) / "zho-eng" / "README.md").exists():
|
||||
raise ValueError(
|
||||
f"repo_path:{repo_path} does not exist: "
|
||||
"You must run: git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git before calling."
|
||||
)
|
||||
results = {}
|
||||
for p in Path(repo_path).iterdir():
|
||||
if len(p.name) != 7:
|
||||
continue
|
||||
lns = list(open(p / "README.md").readlines())
|
||||
results[p.name] = _parse_readme(lns)
|
||||
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
|
||||
|
||||
|
||||
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
|
||||
def convert_all_sentencepiece_models(model_list=None, repo_path=None, dest_dir=Path("marian_converted")):
|
||||
"""Requires 300GB"""
|
||||
save_dir = Path("marian_ckpt")
|
||||
dest_dir = Path("marian_converted")
|
||||
dest_dir = Path(dest_dir)
|
||||
dest_dir.mkdir(exist_ok=True)
|
||||
save_paths = []
|
||||
if model_list is None:
|
||||
model_list: list = make_registry(repo_path=repo_path)
|
||||
for k, prepro, download, test_set_url in tqdm(model_list):
|
||||
if "SentencePiece" not in prepro: # dont convert BPE models.
|
||||
continue
|
||||
if not os.path.exists(save_dir / k / "pytorch_model.bin"):
|
||||
if not os.path.exists(save_dir / k):
|
||||
download_and_unzip(download, save_dir / k)
|
||||
pair_name = convert_opus_name_to_hf_name(k)
|
||||
convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")
|
||||
|
||||
save_paths.append(dest_dir / f"opus-mt-{pair_name}")
|
||||
return save_paths
|
||||
|
||||
|
||||
def lmap(f, x) -> List:
|
||||
return list(map(f, x))
|
||||
@ -493,15 +372,6 @@ def add_special_tokens_to_vocab(model_dir: Path) -> None:
|
||||
save_tokenizer_config(model_dir)
|
||||
|
||||
|
||||
def save_tokenizer(self, save_directory):
|
||||
dest = Path(save_directory)
|
||||
src_path = Path(self.init_kwargs["source_spm"])
|
||||
|
||||
for dest_name in {"source.spm", "target.spm", "tokenizer_config.json"}:
|
||||
shutil.copyfile(src_path.parent / dest_name, dest / dest_name)
|
||||
save_json(self.encoder, dest / "vocab.json")
|
||||
|
||||
|
||||
def check_equal(marian_cfg, k1, k2):
|
||||
v1, v2 = marian_cfg[k1], marian_cfg[k2]
|
||||
assert v1 == v2, f"hparams {k1},{k2} differ: {v1} != {v2}"
|
||||
@ -698,14 +568,14 @@ def convert(source_dir: Path, dest_dir):
|
||||
|
||||
add_special_tokens_to_vocab(source_dir)
|
||||
tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
|
||||
save_tokenizer(tokenizer, dest_dir)
|
||||
tokenizer.save_pretrained(dest_dir)
|
||||
|
||||
opus_state = OpusState(source_dir)
|
||||
assert opus_state.cfg["vocab_size"] == len(
|
||||
tokenizer.encoder
|
||||
), f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
|
||||
# save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
|
||||
# ^^ Save human readable marian config for debugging
|
||||
# ^^ Uncomment to save human readable marian config for debugging
|
||||
|
||||
model = opus_state.load_marian_model()
|
||||
model = model.half()
|
||||
@ -732,15 +602,11 @@ def unzip(zip_path: str, dest_dir: str) -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
To bulk convert, run
|
||||
>>> from transformers.convert_marian_to_pytorch import make_tatoeba_registry, convert_all_sentencepiece_models
|
||||
>>> reg = make_tatoeba_registry()
|
||||
>>> convert_all_sentencepiece_models(model_list=reg) # saves to marian_converted
|
||||
(bash) aws s3 sync marian_converted s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
|
||||
Tatoeba conversion instructions in scripts/tatoeba/README.md
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
|
||||
parser.add_argument("--src", type=str, help="path to marian model sub dir", default="en-de")
|
||||
parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user