mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge pull request #3132 from huggingface/hf_api_model_list
[hf_api] Get the public list of all the models on huggingface
This commit is contained in:
commit
3e5da38dae
@ -17,7 +17,7 @@
|
||||
import io
|
||||
import os
|
||||
from os.path import expanduser
|
||||
from typing import List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
@ -27,6 +27,10 @@ ENDPOINT = "https://huggingface.co"
|
||||
|
||||
|
||||
class S3Obj:
|
||||
"""
|
||||
Data structure that represents a file belonging to the current user.
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
|
||||
self.filename = filename
|
||||
self.LastModified = LastModified
|
||||
@ -41,6 +45,50 @@ class PresignedUrl:
|
||||
self.type = type # mime-type to send to S3.
|
||||
|
||||
|
||||
class S3Object:
|
||||
"""
|
||||
Data structure that represents a public file accessible on our S3.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str, # S3 object key
|
||||
etag: str,
|
||||
lastModified: str,
|
||||
size: int,
|
||||
rfilename: str, # filename relative to config.json
|
||||
**kwargs
|
||||
):
|
||||
self.key = key
|
||||
self.etag = etag
|
||||
self.lastModified = lastModified
|
||||
self.size = size
|
||||
self.rfilename = rfilename
|
||||
|
||||
|
||||
class ModelInfo:
|
||||
"""
|
||||
Info about a public model accessible from our S3.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
modelId: str, # id of model
|
||||
key: str, # S3 object key of config.json
|
||||
author: Optional[str] = None,
|
||||
downloads: Optional[int] = None,
|
||||
tags: List[str] = [],
|
||||
siblings: List[Dict] = [], # list of files that constitute the model
|
||||
**kwargs
|
||||
):
|
||||
self.modelId = modelId
|
||||
self.key = key
|
||||
self.author = author
|
||||
self.downloads = downloads
|
||||
self.tags = tags
|
||||
self.siblings = [S3Object(**x) for x in siblings]
|
||||
|
||||
|
||||
class HfApi:
|
||||
def __init__(self, endpoint=None):
|
||||
self.endpoint = endpoint if endpoint is not None else ENDPOINT
|
||||
@ -129,6 +177,16 @@ class HfApi:
|
||||
r = requests.delete(path, headers={"authorization": "Bearer {}".format(token)}, json={"filename": filename})
|
||||
r.raise_for_status()
|
||||
|
||||
def model_list(self) -> List[ModelInfo]:
|
||||
"""
|
||||
Get the public list of all the models on huggingface, including the community models
|
||||
"""
|
||||
path = "{}/api/models".format(self.endpoint)
|
||||
r = requests.get(path)
|
||||
r.raise_for_status()
|
||||
d = r.json()
|
||||
return [ModelInfo(**x) for x in d]
|
||||
|
||||
|
||||
class TqdmProgressFileReader:
|
||||
"""
|
||||
|
@ -21,7 +21,7 @@ import unittest
|
||||
import requests
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers.hf_api import HfApi, HfFolder, PresignedUrl, S3Obj
|
||||
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
|
||||
|
||||
|
||||
USER = "__DUMMY_TRANSFORMERS_USER__"
|
||||
@ -36,10 +36,11 @@ FILES = [
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
|
||||
),
|
||||
]
|
||||
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
|
||||
|
||||
|
||||
class HfApiCommonTest(unittest.TestCase):
|
||||
_api = HfApi(endpoint="https://moon-staging.huggingface.co")
|
||||
_api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
|
||||
|
||||
class HfApiLoginTest(HfApiCommonTest):
|
||||
@ -92,6 +93,18 @@ class HfApiEndpointsTest(HfApiCommonTest):
|
||||
self.assertIsInstance(o, S3Obj)
|
||||
|
||||
|
||||
class HfApiPublicTest(unittest.TestCase):
|
||||
def test_staging_model_list(self):
|
||||
_api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
_ = _api.model_list()
|
||||
|
||||
def test_model_list(self):
|
||||
_api = HfApi()
|
||||
models = _api.model_list()
|
||||
self.assertGreater(len(models), 100)
|
||||
self.assertIsInstance(models[0], ModelInfo)
|
||||
|
||||
|
||||
class HfFolderTest(unittest.TestCase):
|
||||
def test_token_workflow(self):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user