diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 49f969306f9..14865baf308 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import json
 import re
@@ -18,7 +19,7 @@ import time
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass, field
 from threading import Thread
-from typing import Any, Optional
+from typing import Any, Generator, Literal, Optional
 
 from huggingface_hub import (
     ChatCompletionStreamOutputChoice,
@@ -26,7 +27,7 @@ from huggingface_hub import (
     ChatCompletionStreamOutputDeltaToolCall,
     ChatCompletionStreamOutputFunction,
     ModelInfo,
-    model_info,
+    ChatCompletionStreamOutput,
+    model_info,
 )
 
 from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@@ -59,11 +60,42 @@ if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available()
         role: str
         content: str
 
-    class ChatCompletionInput(BaseModel):
-        messages: list[Message]
+    class Prompt(BaseModel):
+        id: str
+        variables: dict
+        version: Optional[str]
+
+    class TextFormatOptions(str, enum.Enum):
+        text = "text"
+        json_schema = "json_schema"
+
+    class TextFormat(BaseModel):
+        type: TextFormatOptions
+
+    class ResponsesInput(BaseModel):
+        input: str | list
+        model: str
+
+        stream: Optional[bool] = False
+        instructions: Optional[str] = None
+        max_output_tokens: Optional[int] = None
+        max_tool_calls: Optional[int] = None
+        previous_response_id: Optional[str] = None
+        prompt: Optional[Prompt] = None
+        temperature: Optional[float] = None
+        text: Optional[TextFormat] = None
+        tools: Any = None
+        top_p: Optional[float] = None
+
+        # Additional options supported by the Responses API
+        # that aren't yet supported here:
+        # top_logprobs
+
+    class ChatCompletionInput(BaseModel):
+        messages: list[Message]
+        model: str
 
         stream: Optional[bool] = False
-        model: Optional[str] = None
         request_id: Optional[str] = None
         extra_body: Optional[dict] = None
         frequency_penalty: Optional[float] = None
@@ -250,7 +282,7 @@ class ServeCommand(BaseTransformersCLICommand):
         cb_logger = logging.get_logger("transformers.generation.continuous_batching")
         cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
 
-    def build_chunk(
+    def build_chat_completion_chunk(
         self,
         content: str,
         request_id: str,
@@ -279,34 +311,56 @@ class ServeCommand(BaseTransformersCLICommand):
         }
         return f"data: {json.dumps(payload)}\n\n"
 
+    def build_responses_chunk(
+        self,
+        content: str,
+        request_id: str,
+        role: Optional[str] = None,
+        finish_reason: Optional[str] = None,
+        tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
+    ) -> str:
+        # TODO: this payload currently mirrors `chat.completion.chunk`; it still needs
+        # to be adapted to the Responses API event schema.
+        payload = {
+            "object": "chat.completion.chunk",
+            "id": request_id,
+            "created": int(time.time()),
+            "model": self.loaded_model,
+            "system_fingerprint": "",
+            "choices": [
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(
+                        role=role,
+                        content=content,
+                        tool_calls=tool_calls,
+                    ),
+                    index=0,
+                    logprobs=None,
+                    finish_reason=finish_reason,
+                ),
+            ],
+        }
+        return f"data: {json.dumps(payload)}\n\n"
+
     def run(self):
         app = FastAPI()
 
-        if self.use_continuous_batching:
-            self.continuous_batching(app)
-        else:
-            self.generate(app)
+        @app.post("/v1/chat/completions")
+        def chat_completion(req: ChatCompletionInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}
 
-        @functools.lru_cache(maxsize=None)
-        def get_text_gen_models() -> list[ModelInfo]:
-            """
-            This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
-            model working with generate can work.
+            output = self.continuous_batching(req) if self.use_continuous_batching else self.generate(req)
 
-            This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
-            integrations.
- """ - return [ - model_info("Menlo/Jan-nano"), - model_info("Menlo/Jan-nano-128k"), - model_info("Qwen/Qwen2.5-0.5B-Instruct"), - model_info("Qwen/Qwen2.5-3B-Instruct"), - model_info("Qwen/Qwen2.5-7B-Instruct"), - model_info("Qwen/Qwen2.5-14B-Instruct"), - model_info("meta-llama/Llama-3.1-8B-Instruct"), - model_info("meta-llama/Llama-3.2-1B-Instruct"), - model_info("meta-llama/Llama-3.3-70B-Instruct"), - ] + return StreamingResponse(output, media_type="text/event-stream") + + @app.get("/v1/responses") + def responses(req: ResponsesInput): + if not req.stream: + return {"error": "Only streaming mode is supported."} + + output = self.generate_responses(req) + return StreamingResponse(output, media_type="text/event-stream") @app.get("/v1/models") def get_all_models(): @@ -320,79 +374,289 @@ class ServeCommand(BaseTransformersCLICommand): "crated": model.created_at.timestamp(), "owned_by": model.author, } - for model in get_text_gen_models() + for model in self.get_text_gen_models() ], } ) uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level) - def continuous_batching(self, app): - @app.post("/v1/chat/completions") - def _serve(req: "ChatCompletionInput"): - if not req.stream: - return {"error": "Only streaming mode is supported."} + @functools.lru_cache(maxsize=None) + def get_text_gen_models(self) -> list[ModelInfo]: + """ + This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based + model working with generate can work. - update_model = req.model != self.loaded_model - if update_model: - self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args) + This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party + integrations. 
+ """ + return [ + model_info("Menlo/Jan-nano"), + model_info("Menlo/Jan-nano-128k"), + model_info("Qwen/Qwen2.5-0.5B-Instruct"), + model_info("Qwen/Qwen2.5-3B-Instruct"), + model_info("Qwen/Qwen2.5-7B-Instruct"), + model_info("Qwen/Qwen2.5-14B-Instruct"), + model_info("meta-llama/Llama-3.1-8B-Instruct"), + model_info("meta-llama/Llama-3.2-1B-Instruct"), + model_info("meta-llama/Llama-3.3-70B-Instruct"), + ] - generation_config = create_generation_config_from_req( - req, - eos_token_id=self.tokenizer.eos_token_id, - pad_token_id=self.tokenizer.pad_token_id, - use_cache=False, - num_blocks=1, - block_size=1024, - do_sample=False, - max_batch_tokens=10, - scheduler="fifo", + def continuous_batching(self, req: ChatCompletionInput) -> Generator: + update_model = req.model != self.loaded_model + if update_model: + self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args) + + generation_config = create_generation_config_from_req( + req, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=False, + num_blocks=1, + block_size=1024, + do_sample=False, + max_batch_tokens=10, + scheduler="fifo", + ) + + if self.running_continuous_batching_manager is None or update_model: + self.running_continuous_batching_manager = self.model.init_continuous_batching( + generation_config=generation_config, streaming=True ) + self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() + self.running_continuous_batching_manager.start() - if self.running_continuous_batching_manager is None or update_model: - self.running_continuous_batching_manager = self.model.init_continuous_batching( - generation_config=generation_config, streaming=True + inputs = self.tokenizer.apply_chat_template( + req.messages, return_tensors="pt", add_generation_prompt=True + ).to(self.model.device) + + def stream_response(_inputs): + try: + max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256 + request_id = self.running_continuous_batching_manager.add_request( + _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens ) - self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() - self.running_continuous_batching_manager.start() + queue_is_flushed = False - inputs = self.tokenizer.apply_chat_template( - req.messages, return_tensors="pt", add_generation_prompt=True - ).to(self.model.device) + for result in self.running_continuous_batching_manager: + if result.request_id != request_id: + continue - def stream_response(_inputs): - try: - max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256 - request_id = self.running_continuous_batching_manager.add_request( - _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens + if req.request_id is not None and not queue_is_flushed: + if result.status == RequestStatus.FINISHED: + continue + else: + queue_is_flushed = True + + finish_reason = "stop" if result.status == RequestStatus.FINISHED else None + next_token = self.build_chunk( + result.next_token, request_id=request_id, finish_reason=finish_reason ) - queue_is_flushed = False + yield next_token - for result in self.running_continuous_batching_manager: - if result.request_id != request_id: + if result.status == RequestStatus.FINISHED: + break + + yield "data: [DONE]\n\n" + except Exception as e: + logger.error(str(e)) + yield f'data: {{"error": "{str(e)}"}}' + + return stream_response(inputs[0]) + + def generate(self, req: ChatCompletionInput) -> Generator: + update_model = 
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
+        # request whose last message is from the assistant.
+        if req.messages[-1].role == "assistant":
+            return
+
+        # ====== TOOL PREPROCESSING LOGIC ======
+        tool_model_family = None
+        for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
+            if supported_model_families in self.model.config.architectures[0].lower():
+                tool_model_family = supported_model_families
+                break
+        # TODO: trigger 2 constrained generations after the tool call start token is emitted:
+        # 1. force generation to pick from the tool names
+        # 2. force generation to pick from that tool's arguments
+        # ====== END OF TOOL PREPROCESSING LOGIC ======
+
+        if tool_model_family is not None:
+            text = self.tokenizer.apply_chat_template(
+                req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
+            )
+        else:
+            text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
+
+        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+        request_id = req.request_id if req.request_id is not None else "req_0"
+
+        generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
+
+        generation_config = create_generation_config_from_req(req)
+        max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+        generation_config.max_new_tokens = max_new_tokens
+
+        last_kv_cache = None
+        if self.is_continuation(req) and not update_model:
+            last_kv_cache = self.last_kv_cache
+
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": torch.ones_like(inputs),
+            "streamer": generation_streamer,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "past_key_values": last_kv_cache,
+        }
+
+        def stream_response(streamer, _request_id):
+            # Thin wrapper to save the KV cache after generation
+            def generate_with_cache(**kwargs):
+                generate_output = self.model.generate(**kwargs)
+                self.last_kv_cache = generate_output.past_key_values
+
+            thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
+
+            try:
+                thread.start()
+                tool_state = ToolState()
+
+                for result in streamer:
+                    # ====== TOOL CALL LOGIC ======
+                    if tool_model_family is not None:
+                        # Start of a tool call: reset state variables, set `inside_tool_call`
+                        if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
+                            tool_state.inside_tool_call = True
+                            continue
+
+                        # End of tool call: reset `inside_tool_call`, emit a `finish_reason`
+                        if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
+                            tool_state.reset()
+                            yield self.build_chat_completion_chunk(
+                                "", _request_id, role=None, finish_reason="tool_calls"
+                            )
+                            continue
+
+                        # Inside a tool call
+                        if tool_state.inside_tool_call:
+                            tool_state.buffer += result
+
+                            # First step: extract the tool name (may need several tokens, and we can't emit a delta
+                            # until we have the full name)
+                            if not tool_state.has_tool_name_defined:
+                                tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
+                                if tool_name is None:
+                                    continue
+                                else:
+                                    tool_name = tool_name.group(1)
+                                tool_state.has_tool_name_defined = True
+                                tool = ChatCompletionStreamOutputDeltaToolCall(
+                                    function=ChatCompletionStreamOutputFunction(
+                                        name=tool_name,
+                                        arguments=None,
+                                    ),
+                                    index=0,
+                                    type="function",
+                                    id=_request_id + "_tool_call",  # Only the first tool call delta has an id
+                                )
+
+                            # Second step: extract tool arguments. The tool arguments can be seen as a json string
+                            # within the tool json string. We emit a delta for the arguments.
+                            else:
+                                # Empty text: skip
+                                if result == "":
+                                    continue
+                                # Until we see the `"arguments": {` in the buffer, we skip
+                                # TODO: other models will likely need more elaborate processing here
+                                if '"arguments": {' not in tool_state.buffer:
+                                    continue
+
+                                # Handle nesting. We want to exclude the last } from the emitted arguments (it's
+                                # closing the outermost nesting level, outside the arguments block)
+                                tool_state.arg_nesting_level += result.count("{")
+                                tool_state.arg_nesting_level -= result.count("}")
+                                if tool_state.arg_nesting_level < 0:
+                                    result = "".join(result.split("}")[:-2]) + "}"  # e.g. "4}}\n" -> "4}"
+
+                                tool = ChatCompletionStreamOutputDeltaToolCall(
+                                    function=ChatCompletionStreamOutputFunction(
+                                        arguments=result,
+                                    ),
+                                    index=0,
+                                    type="function",
+                                    id=None,
+                                )
+
+                            yield self.build_chat_completion_chunk(None, _request_id, role=None, tool_calls=[tool])
+                            continue
+                    # ====== END OF TOOL CALL LOGIC ======
+
+                    # All non-tool related tokens are emitted as assistant messages
+                    yield self.build_chat_completion_chunk(result, _request_id, role="assistant")
+                yield self.build_chat_completion_chunk(None, _request_id, role=None, finish_reason="stop")
+
+                thread.join()
+            except Exception as e:
+                logger.error(str(e))
+                yield f'data: {{"error": "{str(e)}"}}'
+
+            finally:
+                thread.join()
+
+        return stream_response(generation_streamer, request_id)
+
+    def generate_response(self, req: ResponsesInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        # The Responses API accepts either a plain string or a list of messages as input
+        if isinstance(req.input, str):
+            messages = [{"role": "user", "content": req.input}]
+        else:
+            messages = req.input
+        if req.instructions is not None:
+            messages = [{"role": "system", "content": req.instructions}] + messages
+
+        text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+
+        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+        request_id = "req_0"
+
+        generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
+
+        generation_config = create_generation_config_from_req(req)
+        max_new_tokens = req.max_output_tokens or generation_config.max_new_tokens or 256
+        generation_config.max_new_tokens = max_new_tokens
+
+        # KV cache reuse across requests (see `is_continuation`) is not implemented for the Responses API yet
+        last_kv_cache = None
+
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": torch.ones_like(inputs),
+            "streamer": generation_streamer,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "past_key_values": last_kv_cache,
+        }
+
+        def stream_response(streamer, _request_id):
+            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+
+            try:
+                thread.start()
+                for result in streamer:
+                    yield self.build_responses_chunk(result, _request_id, role="assistant")
+                yield self.build_responses_chunk(None, _request_id, role=None, finish_reason="stop")
+
+                thread.join()
+            except Exception as e:
+                logger.error(str(e))
+                yield f'data: {{"error": "{str(e)}"}}'
+
+            finally:
+                thread.join()
+
+        return stream_response(generation_streamer, request_id)
 
     def is_continuation(self, req: "ChatCompletionInput") -> bool:
         """
@@ -423,156 +687,6 @@ class ServeCommand(BaseTransformersCLICommand):
         self.last_messages = req.messages
         return req_continues_last_messages
 
-    def generate(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
-            update_model = req.model != self.loaded_model
-
-            if update_model:
-                self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
-
-            if not req.stream:
-                return {"error": "Only streaming mode is supported."}
-
-            # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
-            # request whose last message is from the assistant.
-            if req.messages[-1].role == "assistant":
-                return
-
-            # ====== TOOL PREPROCESSING LOGIC ======
-            tool_model_family = None
-            for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
-                if supported_model_families in self.model.config.architectures[0].lower():
-                    tool_model_family = supported_model_families
-                    break
-            # TODO: trigger 2 constrained generations after the tool call start token is emitted:
-            # 1. force generation to pick from the tool names
-            # 2. force generation to pick from that tool's arguments
-            # ====== END OF TOOL PREPROCESSING LOGIC ======
-
-            if tool_model_family is not None:
-                text = self.tokenizer.apply_chat_template(
-                    req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
-                )
-            else:
-                text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
-
-            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
-            request_id = req.request_id if req.request_id is not None else "req_0"
-
-            generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
-
-            generation_config = create_generation_config_from_req(req)
-            max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
-            generation_config.max_new_tokens = max_new_tokens
-
-            last_kv_cache = None
-            if self.is_continuation(req) and not update_model:
-                last_kv_cache = self.last_kv_cache
-
-            generation_kwargs = {
-                "inputs": inputs,
-                "attention_mask": torch.ones_like(inputs),
-                "streamer": generation_streamer,
-                "generation_config": generation_config,
-                "return_dict_in_generate": True,
-                "past_key_values": last_kv_cache,
-            }
-
-            def stream_response(streamer, _request_id):
-                # Thin wrapper to save the KV cache after generation
-                def generate_with_cache(**kwargs):
-                    generate_output = self.model.generate(**kwargs)
-                    self.last_kv_cache = generate_output.past_key_values
-
-                thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
-
-                try:
-                    thread.start()
-                    tool_state = ToolState()
-
-                    for result in streamer:
-                        # ====== TOOL CALL LOGIC ======
-                        if tool_model_family is not None:
-                            # Start of a tool call: reset state variables, set `inside_tool_call`
-                            if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
-                                tool_state.inside_tool_call = True
-                                continue
-
-                            # End of tool call: reset `inside_tool_call`, emit a `finish_reason`
-                            if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
-                                tool_state.reset()
-                                yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
-                                continue
-
-                            # Inside a tool call
-                            if tool_state.inside_tool_call:
-                                tool_state.buffer += result
-
-                                # First step: extract the tool name (may need several tokens, and we can't emit a delta
-                                # until we have the full name)
-                                if not tool_state.has_tool_name_defined:
-                                    tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
-                                    if tool_name is None:
-                                        continue
-                                    else:
-                                        tool_name = tool_name.group(1)
-                                    tool_state.has_tool_name_defined = True
-                                    tool = ChatCompletionStreamOutputDeltaToolCall(
-                                        function=ChatCompletionStreamOutputFunction(
-                                            name=tool_name,
-                                            arguments=None,
-                                        ),
-                                        index=0,
-                                        type="function",
-                                        id=_request_id + "_tool_call",  # Only the first tool call delta has an id
-                                    )
-
-                                # Second step: extract tool arguments. The tool arguments can be seen as a json string
-                                # within the tool json string. We emit a delta for the arguments.
-                                else:
-                                    # Empty text: skip
-                                    if result == "":
-                                        continue
-                                    # Until we see the `"arguments": {` in the buffer, we skip
-                                    # TODO: other models will likely need more elaborate processing here
-                                    if '"arguments": {' not in tool_state.buffer:
-                                        continue
-
-                                    # Handle nesting. We want to exclude the last } from the emitted arguments (it's
-                                    # closing the outermost nesting level, outside the arguments block)
-                                    tool_state.arg_nesting_level += result.count("{")
-                                    tool_state.arg_nesting_level -= result.count("}")
-                                    if tool_state.arg_nesting_level < 0:
-                                        result = "".join(result.split("}")[:-2]) + "}"  # e.g. "4}}\n" -> "4}"
-
-                                    tool = ChatCompletionStreamOutputDeltaToolCall(
-                                        function=ChatCompletionStreamOutputFunction(
-                                            arguments=result,
-                                        ),
-                                        index=0,
-                                        type="function",
-                                        id=None,
-                                    )
-
-                                yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
-                                continue
-                        # ====== END OF TOOL CALL LOGIC ======
-
-                        # All non-tool related tokens are emitted as assistant messages
-                        yield self.build_chunk(result, _request_id, role="assistant")
-                    yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
-
-                    thread.join()
-                except Exception as e:
-                    logger.error(str(e))
-                    raise
-                    yield f'data: {{"error": "{str(e)}"}}'
-
-                finally:
-                    thread.join()
-
-            return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
 
     @staticmethod
     def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
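
For reviewers, a minimal client-side sketch of consuming the chunks built by `build_chat_completion_chunk`. This is illustrative and not part of the patch: it assumes the server is reachable at localhost:8000 (set by the CLI's host/port arguments), that the third-party `requests` package is installed, and the model name is just an example.

    # Hypothetical client for the streaming /v1/chat/completions endpoint.
    import json

    import requests

    payload = {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # any chat model that works with generate
        "stream": True,  # the endpoint rejects non-streaming requests
        "messages": [{"role": "user", "content": "Hello!"}],
    }

    with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True) as response:
        for line in response.iter_lines():
            if not line:
                continue  # skip the blank lines separating SSE events
            data = line.decode("utf-8").removeprefix("data: ")
            if data == "[DONE]":  # emitted by the continuous batching path when the request finishes
                break
            chunk = json.loads(data)
            delta = chunk["choices"][0]["delta"]
            if delta.get("content"):
                print(delta["content"], end="", flush=True)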
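
Similarly, a hedged sketch of a request against the new /v1/responses endpoint. The field names come from the `ResponsesInput` model above; the URL and model name are again assumptions, and the streamed chunks currently mirror the chat.completion.chunk payload (see `build_responses_chunk`).

    # Hypothetical call to the /v1/responses endpoint (same assumptions as above).
    import requests

    payload = {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "stream": True,  # non-streaming requests return an error here as well
        "input": "Write a haiku about code review.",  # `input` accepts a string or a list of messages
        "instructions": "You are a terse assistant.",
        "max_output_tokens": 64,
    }

    with requests.post("http://localhost:8000/v1/responses", json=payload, stream=True) as response:
        for line in response.iter_lines():
            if line:
                print(line.decode("utf-8"))  # one "data: {...}" payload per SSE event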
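
Finally, the tool-call branch in `generate` emits deltas where only the first carries the call id and function name, while later deltas carry JSON argument fragments, so a client has to accumulate them. A minimal sketch under the same assumptions, where `chunks` is an iterable of parsed chunk dicts as produced by the server above:

    # Reassemble one streamed tool call from chat.completion.chunk deltas (hypothetical helper).
    def collect_tool_call(chunks):
        name, fragments = None, []
        for chunk in chunks:
            choice = chunk["choices"][0]
            delta = choice.get("delta") or {}
            for tool_call in delta.get("tool_calls") or []:
                function = tool_call["function"]
                if function.get("name"):  # first delta: carries the tool name (and the call id)
                    name = function["name"]
                if function.get("arguments"):  # later deltas: JSON argument fragments
                    fragments.append(function["arguments"])
            if choice.get("finish_reason") == "tool_calls":
                break
        return name, "".join(fragments)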