Mirror of https://github.com/huggingface/transformers.git
wip: Code to add lang tags to marian model cards (#6586)
parent 129fdae040
commit 38f1703795
@@ -40,33 +40,7 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
     """Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
     import pandas as pd
 
-    released_cols = [
-        "url_base",
-        "pair",  # (ISO639-3/ISO639-5 codes),
-        "short_pair",  # (reduced codes),
-        "chrF2_score",
-        "bleu",
-        "brevity_penalty",
-        "ref_len",
-        "src_name",
-        "tgt_name",
-    ]
-
-    released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
-    released.columns = released_cols
-    old_reg = make_registry(repo_path=old_repo_path)
-    old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
-    assert old_reg.id.value_counts().max() == 1
-    old_reg = old_reg.set_index("id")
-
-    released["fname"] = released["url_base"].apply(
-        lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
-    )
-
-    released["2m"] = released.fname.str.startswith("2m")
-    released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
-
-    newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
+    newest_released, old_reg, released = get_released_df(new_repo_path, old_repo_path)
 
     short_to_new_bleu = newest_released.set_index("short_pair").bleu
 
@@ -94,8 +68,38 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
     ).fillna(-1)
 
     dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
+    whitelist_df = cmp_df[cmp_df.old_bleu <= cmp_df.new_bleu]
     blacklist = dominated.long.unique().tolist()  # 3 letter codes
-    return dominated, blacklist
+    return whitelist_df, dominated, blacklist
 
 
+def get_released_df(new_repo_path, old_repo_path):
+    import pandas as pd
+
+    released_cols = [
+        "url_base",
+        "pair",  # (ISO639-3/ISO639-5 codes),
+        "short_pair",  # (reduced codes),
+        "chrF2_score",
+        "bleu",
+        "brevity_penalty",
+        "ref_len",
+        "src_name",
+        "tgt_name",
+    ]
+    released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
+    released.columns = released_cols
+    old_reg = make_registry(repo_path=old_repo_path)
+    old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
+    assert old_reg.id.value_counts().max() == 1
+    old_reg = old_reg.set_index("id")
+    released["fname"] = released["url_base"].apply(
+        lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
+    )
+    released["2m"] = released.fname.str.startswith("2m")
+    released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
+    newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
+    return newest_released, old_reg, released
+
+
 def remove_prefix(text: str, prefix: str):
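
For orientation, a minimal usage sketch of the refactored helper (not part of the commit; only the old_repo_path default is visible in the signature above, so the Tatoeba-Challenge clone location is an assumption for illustration):

    # assumes local clones of Helsinki-NLP/OPUS-MT-train and Helsinki-NLP/Tatoeba-Challenge
    whitelist_df, dominated, blacklist = check_if_models_are_dominated(
        old_repo_path="OPUS-MT-train/models",  # default shown in the signature
        new_repo_path="Tatoeba-Challenge",     # hypothetical clone path
    )
    print(f"{len(blacklist)} already-ported pairs are dominated by a newer Tatoeba release")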
@@ -323,6 +327,44 @@ def get_clean_model_id_mapping(multiling_model_ids):
     return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
 
 
+def expand_group_to_two_letter_codes(grp_name):
+    raise NotImplementedError()
+
+
+def get_two_letter_code(three_letter_code):
+    raise NotImplementedError()
+    # return two_letter_code
+
+
+def get_tags(code, ref_name):
+    if len(code) == 2:
+        assert "languages" not in ref_name, f"{code}: {ref_name}"
+        return [code], False
+    elif "languages" in ref_name:
+        group = expand_group_to_two_letter_codes(code)
+        group.append(code)
+        return group, True
+    else:  # zho-> zh
+        raise ValueError(f"Three letter monolingual code: {code}")
+
+
+def resolve_lang_code(r):
+    """R is a row in ported"""
+    short_pair = r.short_pair
+    src, tgt = short_pair.split("-")
+    src_tags, src_multilingual = get_tags(src, r.src_name)
+    assert isinstance(src_tags, list)
+    tgt_tags, tgt_multilingual = get_tags(src, r.tgt_name)
+    assert isinstance(tgt_tags, list)
+    if src_multilingual:
+        src_tags.append("multilingual_src")
+    if tgt_multilingual:
+        tgt_tags.append("multilingual_tgt")
+    return src_tags + tgt_tags
+
+    # process target
+
+
 def make_registry(repo_path="Opus-MT-train/models"):
     if not (Path(repo_path) / "fr-en" / "README.md").exists():
         raise ValueError(
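
To show what the new helpers are aiming at, a sketch that is not part of the commit: get_tags maps a language code plus its reference name (the src_name / tgt_name columns of the released-models table) to a list of candidate language tags, and resolve_lang_code combines the source and target sides of a row. For a plain two-letter pair, where the group-expansion stubs above are never hit, the language names used here are illustrative values:

    src_tags, src_multilingual = get_tags("fr", "French")    # -> (["fr"], False)
    tgt_tags, tgt_multilingual = get_tags("en", "English")   # -> (["en"], False)
    card_tags = src_tags + tgt_tags                          # candidate model-card language tags: ["fr", "en"]
    # group codes (ref_name containing "languages") still hit the NotImplementedError stub in this WIP state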
@@ -666,6 +708,7 @@ def convert(source_dir: Path, dest_dir):
     # ^^ Save human readable marian config for debugging
 
     model = opus_state.load_marian_model()
+    model = model.half()
     model.save_pretrained(dest_dir)
     model.from_pretrained(dest_dir)  # sanity check
 
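
The one-line change in convert() casts the converted weights to fp16 before saving, which roughly halves the checkpoint size written to dest_dir. A quick inspection sketch (not in the commit; it assumes the standard pytorch_model.bin weights file name of this era):

    import torch

    state_dict = torch.load(f"{dest_dir}/pytorch_model.bin", map_location="cpu")
    # floating-point tensors should now be fp16 after the half() call above
    print({t.dtype for t in state_dict.values() if t.is_floating_point()})  # expected: {torch.float16}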