mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
use huggingface_hub.model_inifo() to get pipline_tag (#20077)
This commit is contained in:
parent
3222fc645b
commit
cfaeb1539e
@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from numpy import isin
|
||||
|
||||
from huggingface_hub.file_download import http_get
|
||||
from huggingface_hub import model_info
|
||||
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..dynamic_module_utils import get_class_from_dynamic_module
|
||||
@ -389,25 +389,17 @@ def get_supported_tasks() -> List[str]:
|
||||
|
||||
|
||||
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
||||
tmp = io.BytesIO()
|
||||
headers = {}
|
||||
if use_auth_token:
|
||||
headers["Authorization"] = f"Bearer {use_auth_token}"
|
||||
|
||||
try:
|
||||
http_get(f"https://huggingface.co/api/models/{model}", tmp, headers=headers)
|
||||
tmp.seek(0)
|
||||
body = tmp.read()
|
||||
data = json.loads(body)
|
||||
info = model_info(model, token=use_auth_token)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
|
||||
if "pipeline_tag" not in data:
|
||||
if not info.pipeline_tag:
|
||||
raise RuntimeError(
|
||||
f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
|
||||
)
|
||||
if data.get("library_name", "transformers") != "transformers":
|
||||
raise RuntimeError(f"This model is meant to be used with {data['library_name']} not with transformers")
|
||||
task = data["pipeline_tag"]
|
||||
if getattr(info, "library_name", "transformers") != "transformers":
|
||||
raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers")
|
||||
task = info.pipeline_tag
|
||||
return task
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user