diff --git a/src/transformers/convert_marian_to_pytorch.py b/src/transformers/convert_marian_to_pytorch.py
index ce001166a85..6135f61275c 100644
--- a/src/transformers/convert_marian_to_pytorch.py
+++ b/src/transformers/convert_marian_to_pytorch.py
@@ -40,33 +40,7 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
     """Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
     import pandas as pd
 
-    released_cols = [
-        "url_base",
-        "pair",  # (ISO639-3/ISO639-5 codes),
-        "short_pair",  # (reduced codes),
-        "chrF2_score",
-        "bleu",
-        "brevity_penalty",
-        "ref_len",
-        "src_name",
-        "tgt_name",
-    ]
-
-    released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
-    released.columns = released_cols
-    old_reg = make_registry(repo_path=old_repo_path)
-    old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
-    assert old_reg.id.value_counts().max() == 1
-    old_reg = old_reg.set_index("id")
-
-    released["fname"] = released["url_base"].apply(
-        lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
-    )
-
-    released["2m"] = released.fname.str.startswith("2m")
-    released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
-
-    newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
+    newest_released, old_reg, released = get_released_df(new_repo_path, old_repo_path)
 
     short_to_new_bleu = newest_released.set_index("short_pair").bleu
 
@@ -94,8 +68,46 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
     ).fillna(-1)
 
     dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
+    whitelist_df = cmp_df[cmp_df.old_bleu <= cmp_df.new_bleu]
     blacklist = dominated.long.unique().tolist()  # 3 letter codes
-    return dominated, blacklist
+    return whitelist_df, dominated, blacklist
+
+
+def get_released_df(new_repo_path, old_repo_path):
+    """Load the released-models table and the old-model registry as DataFrames.
+
+    Returns a ``(newest_released, old_reg, released)`` tuple. ``released`` is ``released-models.txt``
+    from ``new_repo_path`` without its last row, plus derived ``fname``, ``2m`` and ``date`` columns.
+    ``newest_released`` keeps one row per ``short_pair`` (the most recent ``date``). ``old_reg`` is the
+    ``make_registry`` output for ``old_repo_path``, indexed by ``id``.
+    """
+    import pandas as pd
+
+    released_cols = [
+        "url_base",
+        "pair",  # (ISO639-3/ISO639-5 codes),
+        "short_pair",  # (reduced codes),
+        "chrF2_score",
+        "bleu",
+        "brevity_penalty",
+        "ref_len",
+        "src_name",
+        "tgt_name",
+    ]
+    released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
+    released.columns = released_cols
+    old_reg = make_registry(repo_path=old_repo_path)
+    old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
+    assert old_reg.id.value_counts().max() == 1
+    old_reg = old_reg.set_index("id")
+    released["fname"] = released["url_base"].apply(
+        lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
+    )
+    released["2m"] = released.fname.str.startswith("2m")
+    released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
+    # Newest first, so keep="first" retains the latest release of each short_pair.
+    newest_released = released.sort_values("date", ascending=False).drop_duplicates(["short_pair"], keep="first")
+    return newest_released, old_reg, released
 
 
 def remove_prefix(text: str, prefix: str):
@@ -323,6 +335,53 @@ def get_clean_model_id_mapping(multiling_model_ids):
     return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
 
 
+def expand_group_to_two_letter_codes(grp_name):
+    """Placeholder: should return the list of two-letter codes in language group ``grp_name``."""
+    raise NotImplementedError()
+
+
+def get_two_letter_code(three_letter_code):
+    """Placeholder: should return the two-letter code for ``three_letter_code``."""
+    raise NotImplementedError()
+    # return two_letter_code
+
+
+def get_tags(code, ref_name):
+    """Return ``(tags, is_multilingual)`` for one side of a language pair.
+
+    A two-letter ``code`` yields ``[code]``; a group (``"languages"`` in ``ref_name``) yields its
+    expanded member codes plus ``code`` itself. Any other code raises ``ValueError``.
+    """
+    if len(code) == 2:
+        assert "languages" not in ref_name, f"{code}: {ref_name}"
+        return [code], False
+    elif "languages" in ref_name:
+        group = expand_group_to_two_letter_codes(code)
+        group.append(code)
+        return group, True
+    else:  # zho-> zh
+        raise ValueError(f"Three letter monolingual code: {code}")
+
+
+def resolve_lang_code(r):
+    """Return the language tags for ``r``, a row with ``short_pair``, ``src_name`` and ``tgt_name``.
+
+    Source tags come first, then target tags; a multilingual side is followed by a
+    ``multilingual_src`` / ``multilingual_tgt`` marker.
+    """
+    short_pair = r.short_pair
+    src, tgt = short_pair.split("-")
+    src_tags, src_multilingual = get_tags(src, r.src_name)
+    assert isinstance(src_tags, list)
+    tgt_tags, tgt_multilingual = get_tags(tgt, r.tgt_name)
+    assert isinstance(tgt_tags, list)
+    if src_multilingual:
+        src_tags.append("multilingual_src")
+    if tgt_multilingual:
+        tgt_tags.append("multilingual_tgt")
+    return src_tags + tgt_tags
+
+
 def make_registry(repo_path="Opus-MT-train/models"):
     if not (Path(repo_path) / "fr-en" / "README.md").exists():
         raise ValueError(
@@ -666,6 +725,7 @@ def convert(source_dir: Path, dest_dir):
     # ^^ Save human readable marian config for debugging
 
     model = opus_state.load_marian_model()
+    model = model.half()
     model.save_pretrained(dest_dir)
     model.from_pretrained(dest_dir)  # sanity check
 