WIP serving through HTTP internally using pipelines.

Morgan Funtowicz 2019-12-16 16:38:02 +01:00
parent 43a4e1bbe4
commit a096e2a88b


@@ -1,15 +1,15 @@
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any
-import torch
from fastapi import FastAPI, HTTPException, Body
from logging import getLogger
from pydantic import BaseModel
from uvicorn import run
-from transformers import AutoModel, AutoTokenizer, AutoConfig
+from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
+from transformers.pipelines import SUPPORTED_TASKS, pipeline
def serve_command_factory(args: Namespace):
@@ -17,7 +17,8 @@ def serve_command_factory(args: Namespace):
Factory function used to instantiate serving server from provided command line arguments.
:return: ServeCommand
"""
-return ServeCommand(args.host, args.port, args.model, args.graphql)
+nlp = pipeline(args.task, args.model)
+return ServeCommand(nlp, args.host, args.port, args.model, args.graphql)
class ServeResult(BaseModel):
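With this change the factory builds a Pipeline up front and hands it to ServeCommand, instead of letting the command load a model and tokenizer itself. A minimal sketch of the same wiring outside the CLI, using a task and checkpoint that are only examples and not part of this diff:

    from transformers import pipeline

    # Hypothetical task/checkpoint for illustration; any key of SUPPORTED_TASKS works.
    nlp = pipeline('feature-extraction', 'bert-base-uncased')
    # The serving command then only needs the ready-made pipeline plus network settings:
    # ServeCommand(nlp, host='localhost', port=8888, model='bert-base-uncased', graphql=False)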
@@ -53,8 +54,6 @@ class ServeForwardResult(ServeResult):
"""
Forward result model
"""
-tokens: List[str]
-tokens_ids: List[int]
output: Any
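The forward response model is slimmed down here: tokens and token ids are dropped and only the raw pipeline output is kept. A minimal sketch of the resulting pydantic models, assuming ServeResult declares the model field that the model=... call sites below rely on:

    from typing import Any
    from pydantic import BaseModel

    class ServeResult(BaseModel):
        # Shared base; assumed to carry the model identifier passed as model=... in the handlers.
        model: str

    class ServeForwardResult(ServeResult):
        # After this commit only the raw pipeline output is returned.
        output: Any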
@@ -68,19 +67,18 @@ class ServeCommand(BaseTransformersCLICommand):
:return:
"""
serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
+serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
+serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model to infer from.')
serve_parser.add_argument('--graphql', action='store_true', default=False, help='Enable GraphQL endpoints.')
serve_parser.set_defaults(func=serve_command_factory)
-def __init__(self, host: str, port: int, model: str, graphql: bool):
+def __init__(self, pipeline: Pipeline, host: str, port: int, model: str, graphql: bool):
self._logger = getLogger('transformers-cli/serving')
self._logger.info('Loading model {}'.format(model))
-self._model_name = model
-self._model = AutoModel.from_pretrained(model)
-self._tokenizer = AutoTokenizer.from_pretrained(model)
+self._pipeline = pipeline
self._logger.info('Serving model over {}:{}'.format(host, port))
self._host = host
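The constructor now just keeps the injected pipeline and the network settings; the route registration is not visible in this excerpt. A hedged sketch of how the handlers above are typically attached to the FastAPI app, with endpoint paths that are an assumption rather than part of the diff:

    # Inside ServeCommand.__init__ (sketch; paths and response models assumed):
    self._app = FastAPI()
    self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
    self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
    self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
    self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])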
@@ -97,7 +95,7 @@ class ServeCommand(BaseTransformersCLICommand):
run(self._app, host=self._host, port=self._port)
def model_info(self):
-return ServeModelInfoResult(model=self._model_name, infos=vars(self._model.config))
+return ServeModelInfoResult(model='', infos=vars(self._pipeline.model.config))
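The model metadata now comes straight from the pipeline: vars() over the underlying config returns its plain attribute dictionary, so no separately loaded model is needed. For example, with an assumed BERT pipeline named nlp:

    # nlp is the injected pipeline (assumption); config attributes vary per architecture.
    info = vars(nlp.model.config)
    # e.g. info['vocab_size'] -> 30522, info['hidden_size'] -> 768 for bert-base-uncased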
def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
"""
@@ -106,16 +104,16 @@ class ServeCommand(BaseTransformersCLICommand):
- **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
"""
try:
-tokens_txt = self._tokenizer.tokenize(text_input)
+tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
if return_ids:
-tokens_ids = self._tokenizer.convert_tokens_to_ids(tokens_txt)
-return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt, tokens_ids=tokens_ids)
+tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
+return ServeTokenizeResult(model='', tokens=tokens_txt, tokens_ids=tokens_ids)
else:
-return ServeTokenizeResult(model=self._model_name, tokens=tokens_txt)
+return ServeTokenizeResult(model='', tokens=tokens_txt)
except Exception as e:
-raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})
+raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
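Tokenization now goes through the tokenizer bundled with the pipeline rather than a separately loaded AutoTokenizer. What the endpoint does underneath, as a standalone sketch with an example checkpoint that is not taken from the diff:

    from transformers import pipeline

    nlp = pipeline('feature-extraction', 'bert-base-uncased')  # example task/checkpoint
    tokens = nlp.tokenizer.tokenize('Hello, serving world!')
    ids = nlp.tokenizer.convert_tokens_to_ids(tokens)
    # -> ServeTokenizeResult(model='', tokens=tokens, tokens_ids=ids)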
def detokenize(self, tokens_ids: List[int] = Body(None, embed=True),
skip_special_tokens: bool = Body(False, embed=True),
@@ -127,14 +125,12 @@ class ServeCommand(BaseTransformersCLICommand):
- **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
"""
try:
-decoded_str = self._tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
-return ServeDeTokenizeResult(model=self._model_name, text=decoded_str)
+decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
+return ServeDeTokenizeResult(model='', text=decoded_str)
except Exception as e:
-raise HTTPException(status_code=500, detail={"model": self._model_name, "error": str(e)})
+raise HTTPException(status_code=500, detail={"model": '', "error": str(e)})
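Detokenization mirrors this through the pipeline tokenizer's decode; the handler passes its two flags positionally. The equivalent direct call in keyword form, reusing nlp and ids from the tokenize sketch above:

    text = nlp.tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)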
-def forward(self, inputs: Union[str, List[str], List[int]] = Body(None, embed=True),
-attention_mask: Optional[List[int]] = Body(None, embed=True),
-tokens_type_ids: Optional[List[int]] = Body(None, embed=True)):
+def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
"""
**inputs**:
**attention_mask**:
@@ -143,34 +139,13 @@ class ServeCommand(BaseTransformersCLICommand):
# Check we don't have empty string
if len(inputs) == 0:
-return ServeForwardResult(model=self._model_name, output=[], attention=[])
-if isinstance(inputs, str):
-inputs_tokens = self._tokenizer.tokenize(inputs)
-inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
-elif isinstance(inputs, List):
-if isinstance(inputs[0], str):
-inputs_tokens = inputs
-inputs_ids = self._tokenizer.convert_tokens_to_ids(inputs_tokens)
-elif isinstance(inputs[0], int):
-inputs_tokens = []
-inputs_ids = inputs
-else:
-error_msg = "inputs should be string, [str] of [int] (got {})".format(type(inputs[0]))
-raise HTTPException(423, detail={"error": error_msg})
-else:
-error_msg = "inputs should be string, [str] of [int] (got {})".format(type(inputs))
-raise HTTPException(423, detail={"error": error_msg})
+return ServeForwardResult(model='', output=[], attention=[])
try:
# Forward through the model
-t_input_ids = torch.tensor(inputs_ids).unsqueeze(0)
-output = self._model(t_input_ids, attention_mask, tokens_type_ids)
+output = self._pipeline(inputs)
return ServeForwardResult(
-model=self._model_name, tokens=inputs_tokens,
-tokens_ids=inputs_ids, output=output[0].tolist()
+model='', output=output
)
except Exception as e:
raise HTTPException(500, {"error": str(e)})
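The forward endpoint now defers tokenization, the model call, and post-processing to the pipeline, so a client only sends raw inputs. A hedged end-to-end example, assuming the server was started with the defaults above and that the handler is mounted at /forward (the route table is not part of this diff):

    import requests

    payload = {'inputs': 'Transformers pipelines make serving much simpler.'}
    response = requests.post('http://localhost:8888/forward', json=payload)
    print(response.json())  # {'model': '', 'output': ...} -- shape depends on the chosen --task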