Merge pull request #3132 from huggingface/hf_api_model_list

[hf_api] Get the public list of all the models on huggingface
This commit is contained in:
Thomas Wolf 2020-03-06 13:05:52 +01:00 committed by GitHub
commit 3e5da38dae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 3 deletions

View File

@ -17,7 +17,7 @@
import io
import os
from os.path import expanduser
from typing import List
from typing import Dict, List, Optional
import requests
from tqdm import tqdm
@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
class S3Obj:
"""
Data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
self.filename = filename
self.LastModified = LastModified
@ -41,6 +45,50 @@ class PresignedUrl:
self.type = type # mime-type to send to S3.
class S3Object:
"""
Data structure that represents a public file accessible on our S3.
"""
def __init__(
self,
key: str, # S3 object key
etag: str,
lastModified: str,
size: int,
rfilename: str, # filename relative to config.json
**kwargs
):
self.key = key
self.etag = etag
self.lastModified = lastModified
self.size = size
self.rfilename = rfilename
class ModelInfo:
"""
Info about a public model accessible from our S3.
"""
def __init__(
self,
modelId: str, # id of model
key: str, # S3 object key of config.json
author: Optional[str] = None,
downloads: Optional[int] = None,
tags: List[str] = [],
siblings: List[Dict] = [], # list of files that constitute the model
**kwargs
):
self.modelId = modelId
self.key = key
self.author = author
self.downloads = downloads
self.tags = tags
self.siblings = [S3Object(**x) for x in siblings]
class HfApi:
def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT
@ -129,6 +177,16 @@ class HfApi:
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
r.raise_for_status()
def model_list(self) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface, including the community models
"""
path = "{}/api/models".format(self.endpoint)
r = requests.get(path)
r.raise_for_status()
d = r.json()
return [ModelInfo(**x) for x in d]
class TqdmProgressFileReader:
"""

View File

@ -21,7 +21,7 @@ import unittest
import requests
from requests.exceptions import HTTPError
from transformers.hf_api import HfApi, HfFolder, PresignedUrl, S3Obj
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
USER = "__DUMMY_TRANSFORMERS_USER__"
@ -36,10 +36,11 @@ FILES = [
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
),
]
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
class HfApiCommonTest(unittest.TestCase):
_api = HfApi(endpoint="https://moon-staging.huggingface.co")
_api = HfApi(endpoint=ENDPOINT_STAGING)
class HfApiLoginTest(HfApiCommonTest):
@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
self.assertIsInstance(o, S3Obj)
class HfApiPublicTest(unittest.TestCase):
def test_staging_model_list(self):
_api = HfApi(endpoint=ENDPOINT_STAGING)
_ = _api.model_list()
def test_model_list(self):
_api = HfApi()
models = _api.model_list()
self.assertGreater(len(models), 100)
self.assertIsInstance(models[0], ModelInfo)
class HfFolderTest(unittest.TestCase):
def test_token_workflow(self):
"""