wip: Code to add lang tags to marian model cards (#6586)

This commit is contained in:
Sam Shleifer 2020-09-23 18:11:06 -04:00 committed by GitHub
parent 129fdae040
commit 38f1703795
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -40,33 +40,7 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
"""Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
import pandas as pd
released_cols = [
"url_base",
"pair", # (ISO639-3/ISO639-5 codes),
"short_pair", # (reduced codes),
"chrF2_score",
"bleu",
"brevity_penalty",
"ref_len",
"src_name",
"tgt_name",
]
released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
released.columns = released_cols
old_reg = make_registry(repo_path=old_repo_path)
old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
assert old_reg.id.value_counts().max() == 1
old_reg = old_reg.set_index("id")
released["fname"] = released["url_base"].apply(
lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
)
released["2m"] = released.fname.str.startswith("2m")
released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
newest_released, old_reg, released = get_released_df(new_repo_path, old_repo_path)
short_to_new_bleu = newest_released.set_index("short_pair").bleu
@ -94,8 +68,38 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
).fillna(-1)
dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
whitelist_df = cmp_df[cmp_df.old_bleu <= cmp_df.new_bleu]
blacklist = dominated.long.unique().tolist() # 3 letter codes
return dominated, blacklist
return whitelist_df, dominated, blacklist
def get_released_df(new_repo_path, old_repo_path):
import pandas as pd
released_cols = [
"url_base",
"pair", # (ISO639-3/ISO639-5 codes),
"short_pair", # (reduced codes),
"chrF2_score",
"bleu",
"brevity_penalty",
"ref_len",
"src_name",
"tgt_name",
]
released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
released.columns = released_cols
old_reg = make_registry(repo_path=old_repo_path)
old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
assert old_reg.id.value_counts().max() == 1
old_reg = old_reg.set_index("id")
released["fname"] = released["url_base"].apply(
lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
)
released["2m"] = released.fname.str.startswith("2m")
released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
return newest_released, old_reg, released
def remove_prefix(text: str, prefix: str):
@ -323,6 +327,44 @@ def get_clean_model_id_mapping(multiling_model_ids):
return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
def expand_group_to_two_letter_codes(grp_name):
raise NotImplementedError()
def get_two_letter_code(three_letter_code):
raise NotImplementedError()
# return two_letter_code
def get_tags(code, ref_name):
if len(code) == 2:
assert "languages" not in ref_name, f"{code}: {ref_name}"
return [code], False
elif "languages" in ref_name:
group = expand_group_to_two_letter_codes(code)
group.append(code)
return group, True
else: # zho-> zh
raise ValueError(f"Three letter monolingual code: {code}")
def resolve_lang_code(r):
"""R is a row in ported"""
short_pair = r.short_pair
src, tgt = short_pair.split("-")
src_tags, src_multilingual = get_tags(src, r.src_name)
assert isinstance(src_tags, list)
tgt_tags, tgt_multilingual = get_tags(src, r.tgt_name)
assert isinstance(tgt_tags, list)
if src_multilingual:
src_tags.append("multilingual_src")
if tgt_multilingual:
tgt_tags.append("multilingual_tgt")
return src_tags + tgt_tags
# process target
def make_registry(repo_path="Opus-MT-train/models"):
if not (Path(repo_path) / "fr-en" / "README.md").exists():
raise ValueError(
@ -666,6 +708,7 @@ def convert(source_dir: Path, dest_dir):
# ^^ Save human readable marian config for debugging
model = opus_state.load_marian_model()
model = model.half()
model.save_pretrained(dest_dir)
model.from_pretrained(dest_dir) # sanity check