Mention model_info.id instead of model_info.modelId (#32106)

This commit is contained in:
Lucain 2024-07-22 15:14:47 +02:00 committed by GitHub
parent 0fcfc5ccc9
commit f2a1e3ca68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 4 additions and 4 deletions

View File

@ -105,7 +105,7 @@ from huggingface_hub import list_models
model_list = list_models()
org = "Helsinki-NLP"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
model_ids = [x.id for x in model_list if x.id.startswith(org)]
suffix = [x.split("/")[1] for x in model_ids]
old_style_multi_models = [f"{org}/{s}" for s in suffix if s != s.lower()]
```

View File

@ -65,7 +65,7 @@ def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
"""Find models that can accept src_lang as input and return tgt_lang as output."""
prefix = "Helsinki-NLP/opus-mt-"
model_list = list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
model_ids = [x.id for x in model_list if x.id.startswith("Helsinki-NLP")]
src_and_targ = [
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
] # + cant be loaded.

View File

@ -409,7 +409,7 @@ class ModelManagementTests(unittest.TestCase):
@require_torch
def test_model_names(self):
model_list = list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
model_ids = [x.id for x in model_list if x.id.startswith(ORG_NAME)]
bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
self.assertListEqual([], bad_model_ids)
self.assertGreater(len(model_ids), 500)

View File

@ -94,7 +94,7 @@ def get_tiny_model_summary_from_hub(output_path):
)
_models = set()
for x in models:
model = x.modelId
model = x.id
org, model = model.split("/")
if not model.startswith("tiny-random-"):
continue