Mention model_info.id instead of model_info.modelId (#32106)

This commit is contained in:
Lucain 2024-07-22 15:14:47 +02:00 committed by GitHub
parent 0fcfc5ccc9
commit f2a1e3ca68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 4 additions and 4 deletions

View File

@ -105,7 +105,7 @@ from huggingface_hub import list_models
model_list = list_models() model_list = list_models()
org = "Helsinki-NLP" org = "Helsinki-NLP"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] model_ids = [x.id for x in model_list if x.id.startswith(org)]
suffix = [x.split("/")[1] for x in model_ids] suffix = [x.split("/")[1] for x in model_ids]
old_style_multi_models = [f"{org}/{s}" for s in suffix if s != s.lower()] old_style_multi_models = [f"{org}/{s}" for s in suffix if s != s.lower()]
``` ```

View File

@ -65,7 +65,7 @@ def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
"""Find models that can accept src_lang as input and return tgt_lang as output.""" """Find models that can accept src_lang as input and return tgt_lang as output."""
prefix = "Helsinki-NLP/opus-mt-" prefix = "Helsinki-NLP/opus-mt-"
model_list = list_models() model_list = list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")] model_ids = [x.id for x in model_list if x.id.startswith("Helsinki-NLP")]
src_and_targ = [ src_and_targ = [
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
] # + cant be loaded. ] # + cant be loaded.

View File

@ -409,7 +409,7 @@ class ModelManagementTests(unittest.TestCase):
@require_torch @require_torch
def test_model_names(self): def test_model_names(self):
model_list = list_models() model_list = list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)] model_ids = [x.id for x in model_list if x.id.startswith(ORG_NAME)]
bad_model_ids = [mid for mid in model_ids if "+" in model_ids] bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
self.assertListEqual([], bad_model_ids) self.assertListEqual([], bad_model_ids)
self.assertGreater(len(model_ids), 500) self.assertGreater(len(model_ids), 500)

View File

@ -94,7 +94,7 @@ def get_tiny_model_summary_from_hub(output_path):
) )
_models = set() _models = set()
for x in models: for x in models:
model = x.modelId model = x.id
org, model = model.split("/") org, model = model.split("/")
if not model.startswith("tiny-random-"): if not model.startswith("tiny-random-"):
continue continue