use huggingface_hub.model_inifo() to get pipline_tag (#20077)

This commit is contained in:
TAGAMI Yukihiro 2022-11-08 00:07:59 +09:00 committed by GitHub
parent 3222fc645b
commit cfaeb1539e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from numpy import isin
from huggingface_hub.file_download import http_get
from huggingface_hub import model_info
from ..configuration_utils import PretrainedConfig
from ..dynamic_module_utils import get_class_from_dynamic_module
@ -389,25 +389,17 @@ def get_supported_tasks() -> List[str]:
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
tmp = io.BytesIO()
headers = {}
if use_auth_token:
headers["Authorization"] = f"Bearer {use_auth_token}"
try:
http_get(f"https://huggingface.co/api/models/{model}", tmp, headers=headers)
tmp.seek(0)
body = tmp.read()
data = json.loads(body)
info = model_info(model, token=use_auth_token)
except Exception as e:
raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
if "pipeline_tag" not in data:
if not info.pipeline_tag:
raise RuntimeError(
f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
)
if data.get("library_name", "transformers") != "transformers":
raise RuntimeError(f"This model is meant to be used with {data['library_name']} not with transformers")
task = data["pipeline_tag"]
if getattr(info, "library_name", "transformers") != "transformers":
raise RuntimeError(f"This model is meant to be used with {info.library_name} not with transformers")
task = info.pipeline_tag
return task