Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
WIP serving through HTTP internally using pipelines.
This commit is contained in:
parent 43a4e1bbe4
commit a096e2a88b
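
In short: instead of loading a model and tokenizer inside the serving command, the command now receives a ready-made Pipeline. A minimal sketch of the new wiring (task and model names below are illustrative, not part of the commit):

    from transformers.pipelines import pipeline

    # The factory builds the pipeline once from the CLI arguments...
    nlp = pipeline('feature-extraction', 'bert-base-uncased')
    # ...and injects it as the first constructor argument:
    # ServeCommand(nlp, 'localhost', 8888, 'bert-base-uncased', False)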
@@ -1,15 +1,15 @@
 from argparse import ArgumentParser, Namespace
 from typing import List, Optional, Union, Any

-import torch
 from fastapi import FastAPI, HTTPException, Body
 from logging import getLogger

 from pydantic import BaseModel
 from uvicorn import run

-from transformers import AutoModel, AutoTokenizer, AutoConfig
+from transformers import Pipeline
 from transformers.commands import BaseTransformersCLICommand
+from transformers.pipelines import SUPPORTED_TASKS, pipeline


 def serve_command_factory(args: Namespace):
@@ -17,7 +17,8 @@ def serve_command_factory(args: Namespace):
     Factory function used to instantiate serving server from provided command line arguments.
     :return: ServeCommand
     """
-    return ServeCommand(args.host, args.port, args.model, args.graphql)
+    nlp = pipeline(args.task, args.model)
+    return ServeCommand(nlp, args.host, args.port, args.model, args.graphql)


 class ServeResult(BaseModel):
@@ -53,8 +54,6 @@ class ServeForwardResult(ServeResult):
     """
     Forward result model
     """
-    tokens: List[str]
-    tokens_ids: List[int]
     output: Any


@@ -68,19 +67,18 @@ class ServeCommand(BaseTransformersCLICommand):
         :return:
         """
         serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
+        serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
+        serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
         serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
         serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model to infer from.')
         serve_parser.add_argument('--graphql', action='store_true', default=False, help='Enable GraphQL endpoints.')
         serve_parser.set_defaults(func=serve_command_factory)

-    def __init__(self, host: str, port: int, model: str, graphql: bool):
+    def __init__(self, pipeline: Pipeline, host: str, port: int, model: str, graphql: bool):
         self._logger = getLogger('transformers-cli/serving')

-        self._logger.info('Loading model {}'.format(model))
-        self._model_name = model
-        self._model = AutoModel.from_pretrained(model)
-        self._tokenizer = AutoTokenizer.from_pretrained(model)
+        self._pipeline = pipeline

         self._logger.info('Serving model over {}:{}'.format(host, port))
         self._host = host
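
For reference, a hypothetical smoke test of the new constructor path, driving the factory with a hand-built Namespace. The argument names come from the parser above; the module path and the 'ner' task name are assumptions:

    from argparse import Namespace
    from transformers.commands.serving import serve_command_factory  # assumed module path

    args = Namespace(task='ner', model='bert-base-cased', host='localhost',
                     port=8888, graphql=False, device=-1)  # --device is parsed but not yet used by the factory in this WIP
    serve = serve_command_factory(args)  # builds the pipeline, then ServeCommand
    serve.run()  # assumed entry point; the uvicorn run() call appears in the next hunk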
@@ -97,7 +95,7 @@ class ServeCommand(BaseTransformersCLICommand):
         run(self._app, host=self._host, port=self._port)

     def model_info(self):
-        return ServeModelInfoResult(model=self._model_name, infos=vars(self._model.config))
+        return ServeModelInfoResult(model='', infos=vars(self._pipeline.model.config))

     def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
         """
@@ -106,16 +104,16 @@ class ServeCommand(BaseTransformersCLICommand):
         - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
         """
         try:
-            tokens_txt = self._tokenizer.tokenize(text_input)
+            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

             if return_ids:
-                tokens_ids = self._tokenizer.convert_tokens_to_ids(tokens_txt)
-                return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt, tokens_ids=tokens_ids)
+                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
+                return ServeTokenizeResult(model='', tokens=tokens_txt, tokens_ids=tokens_ids)
             else:
-                return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt)
+                return ServeTokenizeResult(model='', tokens=tokens_txt)

         except Exception as e:
-            raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})
+            raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})

     def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
                    skip_special_tokens: bool = Body(False, embed=True),
@@ -127,14 +125,12 @@ class ServeCommand(BaseTransformersCLICommand):
         - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
         """
         try:
-            decoded_str = self._tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
-            return ServeDeTokenizeResult(model=self._model_name, text=decoded_str)
+            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
+            return ServeDeTokenizeResult(model='', text=decoded_str)
         except Exception as e:
-            raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})
+            raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})

-    def forward(self, inputs: Union[str, List[str], List[int]] = Body(None, embed=True),
-                attention_mask: Optional[List[int]] = Body(None, embed=True),
-                tokens_type_ids: Optional[List[int]] = Body(None, embed=True)):
+    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
         """
         **inputs**:
         **attention_mask**:
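
With tokenization owned by the pipeline, forward drops attention_mask and tokens_type_ids entirely: the raw request body can be fed straight to the pipeline, which encodes and forwards internally, as the next hunk shows. Illustratively (model and task names are placeholders, not part of the diff):

    from transformers.pipelines import pipeline

    nlp = pipeline('sentiment-analysis', 'distilbert-base-uncased-finetuned-sst-2-english')
    result = nlp('This movie was great!')  # tokenization, forward pass and post-processing in one call
    print(result)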
@@ -143,34 +139,13 @@ class ServeCommand(BaseTransformersCLICommand):
         """

         # Check we don't have empty string
         if len(inputs) == 0:
-            return ServeForwardResult(model=self._model_name, output=[], attention=[])
-
-        if isinstance(inputs, str):
-            inputs_tokens = self._tokenizer.tokenize(inputs)
-            inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
-
-        elif isinstance(inputs, List):
-            if isinstance(inputs[0], str):
-                inputs_tokens = inputs
-                inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
-            elif isinstance(inputs[0], int):
-                inputs_tokens = []
-                inputs_ids = inputs
-            else:
-                error_msg = "inputs should be string, [str] of [int] (got {})".format(type(inputs[0]))
-                raise HTTPException(423, detail={"error": error_msg})
-        else:
-            error_msg = "inputs should be string, [str] of [int] (got {})".format(type(inputs))
-            raise HTTPException(423, detail={"error": error_msg})
+            return ServeForwardResult(model='', output=[], attention=[])

         try:
             # Forward through the model
-            t_input_ids = torch.tensor(inputs_ids).unsqueeze(0)
-            output = self._model(t_input_ids, attention_mask, tokens_type_ids)
-
+            output = self._pipeline(inputs)
             return ServeForwardResult(
-                model=self._model_name, tokens=inputs_tokens,
-                tokens_ids=inputs_ids, output=output[0].tolist()
+                model='', output=output
             )
         except Exception as e:
             raise HTTPException(500, {"error": str(e)})
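
Once the server is up, the endpoints can be exercised over plain HTTP. A hedged example against the tokenize endpoint (the '/tokenize' route path is an assumption, since the FastAPI route registration is outside this diff; the body keys match the Body(..., embed=True) parameters above):

    import requests

    resp = requests.post('http://localhost:8888/tokenize',
                         json={'text_input': 'Hello world', 'return_ids': True})
    print(resp.json())  # expected shape: {'model': '', 'tokens': [...], 'tokens_ids': [...]}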