mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
use huggingface_hub.model_inifo() to get pipline_tag (#20077)
This commit is contained in:
parent
3222fc645b
commit
cfaeb1539e
@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
from numpy import isin
|
from numpy import isin
|
||||||
|
|
||||||
from huggingface_hub.file_download import http_get
|
from huggingface_hub import model_info
|
||||||
|
|
||||||
from ..configuration_utils import PretrainedConfig
|
from ..configuration_utils import PretrainedConfig
|
||||||
from ..dynamic_module_utils import get_class_from_dynamic_module
|
from ..dynamic_module_utils import get_class_from_dynamic_module
|
||||||
@ -389,25 +389,17 @@ def get_supported_tasks() -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
|
||||||
tmp = io.BytesIO()
|
|
||||||
headers = {}
|
|
||||||
if use_auth_token:
|
|
||||||
headers["Authorization"] = f"Bearer {use_auth_token}"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
http_get(f"https://huggingface.co/api/models/{model}", tmp, headers=headers)
|
info = model_info(model, token=use_auth_token)
|
||||||
tmp.seek(0)
|
|
||||||
body = tmp.read()
|
|
||||||
data = json.loads(body)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
|
raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
|
||||||
if "pipeline_tag" not in data:
|
if not info.pipeline_tag:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
|
f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
|
||||||
)
|
)
|
||||||
if data.get("library_name", "transformers") != "transformers":
|
if getattr(info, "library_name", "transformers") != "transformers":
|
||||||
raise RuntimeError(f"This model is meant to be used with {data['library_name']} not with transformers")
|
raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers")
|
||||||
task = data["pipeline_tag"]
|
task = info.pipeline_tag
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user