Scaffolding

Lysandre 2025-07-01 18:22:55 +02:00
parent 733bcb4fed
commit feb1eb7531


@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import json
 import re
@@ -18,7 +19,7 @@ import time
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass, field
 from threading import Thread
-from typing import Any, Optional
+from typing import Any, Optional, Generator, Literal

 from huggingface_hub import (
     ChatCompletionStreamOutputChoice,
@@ -26,7 +27,7 @@ from huggingface_hub import (
     ChatCompletionStreamOutputDeltaToolCall,
     ChatCompletionStreamOutputFunction,
     ModelInfo,
-    model_info,
+    model_info, ChatCompletionStreamOutput,
 )

 from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@@ -59,11 +60,42 @@ if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available()
         role: str
         content: str

-    class ChatCompletionInput(BaseModel):
-        messages: list[Message]
+    class Prompt(BaseModel):
+        id: str
+        variables: dict
+        version: Optional[str]
+
+    class TextFormatOptions(enum.StrEnum):
+        text = "text"
+        json_schema = "json_schema"
+
+    class TextFormat(BaseModel):
+        type: TextFormatOptions
+
+    class ResponsesInput(BaseModel):
+        input: str | list
+        model: str
+        stream: Optional[bool] = False
+        instructions: Optional[str] = None
+        max_output_tokens: Optional[int] = None
+        max_tool_calls: Optional[int] = None
+        previous_response_id: Optional[str] = None
+        prompt: Optional[Prompt] = None
+        temperature: Optional[float] = None
+        text: Optional[TextFormat] = None
+        tools: any = None
+        top_p: Optional[float] = None
+
+        # Additional options supported by the Responses API
+        # that aren't yet supported here.
+        # top_logprobs
+
+    class ChatCompletionInput(BaseModel):
+        messages: list[Message]
+        model: str

         stream: Optional[bool] = False
-        model: Optional[str] = None
         request_id: Optional[str] = None
         extra_body: Optional[dict] = None
         frequency_penalty: Optional[float] = None
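For orientation, a JSON body accepted by the new ResponsesInput model could look like the sketch below. The field names mirror the model above; the concrete values (model name, prompt text) are invented for illustration and are not part of the commit.

    # Hypothetical payload for the new /v1/responses endpoint; values are illustrative only.
    example_responses_request = {
        "input": "Summarize the latest commit in one sentence.",
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "stream": True,  # the handlers added below reject non-streaming requests
        "instructions": "You are a concise assistant.",
        "max_output_tokens": 128,
        "temperature": 0.7,
        "text": {"type": "text"},  # a TextFormat using TextFormatOptions.text
    }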
@@ -250,7 +282,7 @@ class ServeCommand(BaseTransformersCLICommand):
         cb_logger = logging.get_logger("transformers.generation.continuous_batching")
         cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])

-    def build_chunk(
+    def build_chat_completion_chunk(
         self,
         content: str,
         request_id: str,
@@ -279,34 +311,56 @@ class ServeCommand(BaseTransformersCLICommand):
         }

         return f"data: {json.dumps(payload)}\n\n"

+    def build_responses_chunk(
+        self,
+        content: str,
+        request_id: str,
+        role: Optional[str] = None,
+        finish_reason: Optional[str] = None,
+        tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
+    ) -> str:
+        payload = {
+            "object": "chat.completion.chunk",
+            "id": request_id,
+            "created": int(time.time()),
+            "model": self.loaded_model,
+            "system_fingerprint": "",
+            "choices": [
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(
+                        role=role,
+                        content=content,
+                        tool_calls=tool_calls,
+                    ),
+                    index=0,
+                    logprobs=None,
+                    finish_reason=finish_reason,
+                ),
+            ],
+        }
+
+        return f"data: {json.dumps(payload)}\n\n"
+
     def run(self):
         app = FastAPI()

-        if self.use_continuous_batching:
-            self.continuous_batching(app)
-        else:
-            self.generate(app)
-
-        @functools.lru_cache(maxsize=None)
-        def get_text_gen_models() -> list[ModelInfo]:
-            """
-            This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
-            model working with generate can work.
-
-            This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
-            integrations.
-            """
-            return [
-                model_info("Menlo/Jan-nano"),
-                model_info("Menlo/Jan-nano-128k"),
-                model_info("Qwen/Qwen2.5-0.5B-Instruct"),
-                model_info("Qwen/Qwen2.5-3B-Instruct"),
-                model_info("Qwen/Qwen2.5-7B-Instruct"),
-                model_info("Qwen/Qwen2.5-14B-Instruct"),
-                model_info("meta-llama/Llama-3.1-8B-Instruct"),
-                model_info("meta-llama/Llama-3.2-1B-Instruct"),
-                model_info("meta-llama/Llama-3.3-70B-Instruct"),
-            ]
+        @app.get("/v1/chat/completions")
+        def chat_completion(req: ChatCompletionInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}
+
+            output = self.continuous_batching(req) if self.use_continuous_batching else self.generate(req)
+            return StreamingResponse(output, media_type="text/event-stream")
+
+        @app.get("/v1/responses")
+        def responses(req: ResponsesInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}
+
+            output = self.generate_responses(req)
+            return StreamingResponse(output, media_type="text/event-stream")

         @app.get("/v1/models")
         def get_all_models():
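Both streaming endpoints answer with a text/event-stream body whose events are the "data: <json>" chunks assembled by the builder methods above. The sketch below shows one way a client might drain such a stream; it assumes the OpenAI-style chat.completion.chunk layout used in the payload dicts, and the sample events are hand-written rather than captured from a real server.

    import json

    def iter_stream_content(sse_lines):
        # Walk the "data: ..." events, stop at the [DONE] sentinel, and pull out the
        # delta content of the first choice when present.
        for line in sse_lines:
            line = line.strip()
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            chunk = json.loads(data)
            delta = chunk["choices"][0].get("delta") or {}
            if delta.get("content"):
                yield delta["content"]

    # Hand-written sample events mirroring the chunk structure above:
    sample = [
        'data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "Hel"}}]}',
        'data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "lo"}}]}',
        "data: [DONE]",
    ]
    print("".join(iter_stream_content(sample)))  # prints "Hello"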
@@ -320,79 +374,289 @@ class ServeCommand(BaseTransformersCLICommand):
                             "crated": model.created_at.timestamp(),
                             "owned_by": model.author,
                         }
-                        for model in get_text_gen_models()
+                        for model in self.get_text_gen_models()
                     ],
                 }
             )

         uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)

-    def continuous_batching(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
-            if not req.stream:
-                return {"error": "Only streaming mode is supported."}
-
-            update_model = req.model != self.loaded_model
-            if update_model:
-                self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
-
-            generation_config = create_generation_config_from_req(
-                req,
-                eos_token_id=self.tokenizer.eos_token_id,
-                pad_token_id=self.tokenizer.pad_token_id,
-                use_cache=False,
-                num_blocks=1,
-                block_size=1024,
-                do_sample=False,
-                max_batch_tokens=10,
-                scheduler="fifo",
-            )
-
-            if self.running_continuous_batching_manager is None or update_model:
-                self.running_continuous_batching_manager = self.model.init_continuous_batching(
-                    generation_config=generation_config, streaming=True
-                )
-                self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
-                self.running_continuous_batching_manager.start()
-
-            inputs = self.tokenizer.apply_chat_template(
-                req.messages, return_tensors="pt", add_generation_prompt=True
-            ).to(self.model.device)
-
-            def stream_response(_inputs):
-                try:
-                    max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
-                    request_id = self.running_continuous_batching_manager.add_request(
-                        _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
-                    )
-                    queue_is_flushed = False
-
-                    for result in self.running_continuous_batching_manager:
-                        if result.request_id != request_id:
-                            continue
-                        if req.request_id is not None and not queue_is_flushed:
-                            if result.status == RequestStatus.FINISHED:
-                                continue
-                            else:
-                                queue_is_flushed = True
-
-                        finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
-                        next_token = self.build_chunk(
-                            result.next_token, request_id=request_id, finish_reason=finish_reason
-                        )
-                        yield next_token
-
-                        if result.status == RequestStatus.FINISHED:
-                            break
-
-                    yield "data: [DONE]\n\n"
-                except Exception as e:
-                    logger.error(str(e))
-                    yield f'data: {{"error": "{str(e)}"}}'
-
-            return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
+    @functools.lru_cache(maxsize=None)
+    def get_text_gen_models(self) -> list[ModelInfo]:
+        """
+        This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
+        model working with generate can work.
+
+        This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
+        integrations.
+        """
+        return [
+            model_info("Menlo/Jan-nano"),
+            model_info("Menlo/Jan-nano-128k"),
+            model_info("Qwen/Qwen2.5-0.5B-Instruct"),
+            model_info("Qwen/Qwen2.5-3B-Instruct"),
+            model_info("Qwen/Qwen2.5-7B-Instruct"),
+            model_info("Qwen/Qwen2.5-14B-Instruct"),
+            model_info("meta-llama/Llama-3.1-8B-Instruct"),
+            model_info("meta-llama/Llama-3.2-1B-Instruct"),
+            model_info("meta-llama/Llama-3.3-70B-Instruct"),
+        ]
+
+    def continuous_batching(self, req: ChatCompletionInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        generation_config = create_generation_config_from_req(
+            req,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=False,
+            num_blocks=1,
+            block_size=1024,
+            do_sample=False,
+            max_batch_tokens=10,
+            scheduler="fifo",
+        )
+
+        if self.running_continuous_batching_manager is None or update_model:
+            self.running_continuous_batching_manager = self.model.init_continuous_batching(
+                generation_config=generation_config, streaming=True
+            )
+            self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
+            self.running_continuous_batching_manager.start()
+
+        inputs = self.tokenizer.apply_chat_template(
+            req.messages, return_tensors="pt", add_generation_prompt=True
+        ).to(self.model.device)
+
+        def stream_response(_inputs):
+            try:
+                max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+                request_id = self.running_continuous_batching_manager.add_request(
+                    _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
+                )
+                queue_is_flushed = False

+                for result in self.running_continuous_batching_manager:
+                    if result.request_id != request_id:
+                        continue
+                    if req.request_id is not None and not queue_is_flushed:
+                        if result.status == RequestStatus.FINISHED:
+                            continue
+                        else:
+                            queue_is_flushed = True
+
+                    finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
+                    next_token = self.build_chunk(
+                        result.next_token, request_id=request_id, finish_reason=finish_reason
+                    )
+                    yield next_token
+
+                    if result.status == RequestStatus.FINISHED:
+                        break
+
+                yield "data: [DONE]\n\n"
+            except Exception as e:
+                logger.error(str(e))
+                yield f'data: {{"error": "{str(e)}"}}'
+
+        return stream_response(inputs[0])
+
+    def generate(self, req: ChatCompletionInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
+        # request whose last message is from the assistant.
+        if req.messages[-1].role == "assistant":
+            return
+
+        # ====== TOOL PREPROCESSING LOGIC ======
+        tool_model_family = None
+        for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
+            if supported_model_families in self.model.config.architectures[0].lower():
+                tool_model_family = supported_model_families
+                break
+        # TODO: trigger 2 constrained generations after the tool call start token is emitted:
+        # 1. force generation to pick from the tool names
+        # 2. force generation to pick from that tool's arguments
+        # ====== END OF TOOL PREPROCESSING LOGIC ======
+
+        if tool_model_family is not None:
+            text = self.tokenizer.apply_chat_template(
+                req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
+            )
+        else:
+            text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
+
+        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+        request_id = req.request_id if req.request_id is not None else "req_0"
+
+        generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
+        generation_config = create_generation_config_from_req(req)
+
+        max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+        generation_config.max_new_tokens = max_new_tokens
+
+        last_kv_cache = None
+        if self.is_continuation(req) and not update_model:
+            last_kv_cache = self.last_kv_cache
+
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": torch.ones_like(inputs),
+            "streamer": generation_streamer,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "past_key_values": last_kv_cache,
+        }
+
+        def stream_response(streamer, _request_id):
+            # Thin wrapper to save the KV cache after generation
+            def generate_with_cache(**kwargs):
+                generate_output = self.model.generate(**kwargs)
+                self.last_kv_cache = generate_output.past_key_values
+
+            thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
+
+            try:
+                thread.start()
+                tool_state = ToolState()
+
+                for result in streamer:
+                    # ====== TOOL CALL LOGIC ======
+                    if tool_model_family is not None:
+                        # Start of a tool call: reset state variables, set `inside_tool_call`
+                        if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
+                            tool_state.inside_tool_call = True
+                            continue
+
+                        # End of tool call: reset `inside_tool_call`, emit a `finish_reason`
+                        if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
+                            tool_state.reset()
+                            yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
+                            continue
+
+                        # Inside a tool call
+                        if tool_state.inside_tool_call:
+                            tool_state.buffer += result
+
+                            # First step: extract the tool name (may need several tokens, and we can't emit a delta
+                            # until we have the full name)
+                            if not tool_state.has_tool_name_defined:
+                                tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
+                                if tool_name is None:
+                                    continue
+                                else:
+                                    tool_name = tool_name.group(1)
+                                tool_state.has_tool_name_defined = True
+                                tool = ChatCompletionStreamOutputDeltaToolCall(
+                                    function=ChatCompletionStreamOutputFunction(
+                                        name=tool_name,
+                                        arguments=None,
+                                    ),
+                                    index=0,
+                                    type="function",
+                                    id=_request_id + "_tool_call",  # Only the first tool call delta has an id
+                                )
+
+                            # Second step: extract tool arguments. The tool arguments can be seen as a json string
+                            # within the tool json string. We emit a delta for the arguments.
+                            else:
+                                # Empty text: skip
+                                if result == "":
+                                    continue
+
+                                # Until we see the `"arguments": {` in the buffer, we skip
+                                # TODO: other models will likely need more elaborate processing here
+                                if '"arguments": {' not in tool_state.buffer:
+                                    continue
+
+                                # Handle nesting. We want to exclude the last } from the emitted arguments (it's
+                                # closing the outermost nesting level, outside the arguments block)
+                                tool_state.arg_nesting_level += result.count("{")
+                                tool_state.arg_nesting_level -= result.count("}")
+                                if tool_state.arg_nesting_level < 0:
+                                    result = "".join(result.split("}")[:-2]) + "}"  # e.g. "4}}\n" -> "4}"
+
+                                tool = ChatCompletionStreamOutputDeltaToolCall(
+                                    function=ChatCompletionStreamOutputFunction(
+                                        arguments=result,
+                                    ),
+                                    index=0,
+                                    type="function",
+                                    id=None,
+                                )
+
+                            yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
+                            continue
+                    # ====== END OF TOOL CALL LOGIC ======
+
+                    # All non-tool related tokens are emitted as assistant messages
+                    yield self.build_chunk(result, _request_id, role="assistant")
+                yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
+
+                thread.join()
+            except Exception as e:
+                logger.error(str(e))
+                raise
+                yield f'data: {{"error": "{str(e)}"}}'
+            finally:
+                thread.join()
+
+        return stream_response(generation_streamer, request_id)
+
+    def generate_response(self, req: ResponsesInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
+        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+        request_id = req.request_id if req.request_id is not None else "req_0"
+
+        generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
+        generation_config = create_generation_config_from_req(req)
+
+        max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+        generation_config.max_new_tokens = max_new_tokens
+
+        last_kv_cache = None
+        if self.is_continuation(req) and not update_model:
+            last_kv_cache = self.last_kv_cache
+
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": torch.ones_like(inputs),
+            "streamer": generation_streamer,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "past_key_values": last_kv_cache,
+        }
+
+        def stream_response(streamer, _request_id):
+            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+            try:
+                thread.start()
+                for result in streamer:
+                    yield self.build_chunk(result, _request_id, role="assistant")
+                yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
+
+                thread.join()
+            except Exception as e:
+                logger.error(str(e))
+                raise
+                yield f'data: {{"error": "{str(e)}"}}'
+            finally:
+                thread.join()
+
+        return stream_response(generation_streamer, request_id)

     def is_continuation(self, req: "ChatCompletionInput") -> bool:
         """
@@ -423,156 +687,6 @@ class ServeCommand(BaseTransformersCLICommand):
         self.last_messages = req.messages
         return req_continues_last_messages

-    def generate(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
-            update_model = req.model != self.loaded_model
-            if update_model:
-                self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
-
-            if not req.stream:
-                return {"error": "Only streaming mode is supported."}
-
-            # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
-            # request whose last message is from the assistant.
-            if req.messages[-1].role == "assistant":
-                return
-
-            # ====== TOOL PREPROCESSING LOGIC ======
-            tool_model_family = None
-            for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
-                if supported_model_families in self.model.config.architectures[0].lower():
-                    tool_model_family = supported_model_families
-                    break
-            # TODO: trigger 2 constrained generations after the tool call start token is emitted:
-            # 1. force generation to pick from the tool names
-            # 2. force generation to pick from that tool's arguments
-            # ====== END OF TOOL PREPROCESSING LOGIC ======
-
-            if tool_model_family is not None:
-                text = self.tokenizer.apply_chat_template(
-                    req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
-                )
-            else:
-                text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
-
-            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
-            request_id = req.request_id if req.request_id is not None else "req_0"
-
-            generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
-            generation_config = create_generation_config_from_req(req)
-
-            max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
-            generation_config.max_new_tokens = max_new_tokens
-
-            last_kv_cache = None
-            if self.is_continuation(req) and not update_model:
-                last_kv_cache = self.last_kv_cache
-
-            generation_kwargs = {
-                "inputs": inputs,
-                "attention_mask": torch.ones_like(inputs),
-                "streamer": generation_streamer,
-                "generation_config": generation_config,
-                "return_dict_in_generate": True,
-                "past_key_values": last_kv_cache,
-            }
-
-            def stream_response(streamer, _request_id):
-                # Thin wrapper to save the KV cache after generation
-                def generate_with_cache(**kwargs):
-                    generate_output = self.model.generate(**kwargs)
-                    self.last_kv_cache = generate_output.past_key_values
-
-                thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
-
-                try:
-                    thread.start()
-                    tool_state = ToolState()
-
-                    for result in streamer:
-                        # ====== TOOL CALL LOGIC ======
-                        if tool_model_family is not None:
-                            # Start of a tool call: reset state variables, set `inside_tool_call`
-                            if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
-                                tool_state.inside_tool_call = True
-                                continue
-
-                            # End of tool call: reset `inside_tool_call`, emit a `finish_reason`
-                            if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
-                                tool_state.reset()
-                                yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
-                                continue
-
-                            # Inside a tool call
-                            if tool_state.inside_tool_call:
-                                tool_state.buffer += result
-
-                                # First step: extract the tool name (may need several tokens, and we can't emit a delta
-                                # until we have the full name)
-                                if not tool_state.has_tool_name_defined:
-                                    tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
-                                    if tool_name is None:
-                                        continue
-                                    else:
-                                        tool_name = tool_name.group(1)
-                                    tool_state.has_tool_name_defined = True
-                                    tool = ChatCompletionStreamOutputDeltaToolCall(
-                                        function=ChatCompletionStreamOutputFunction(
-                                            name=tool_name,
-                                            arguments=None,
-                                        ),
-                                        index=0,
-                                        type="function",
-                                        id=_request_id + "_tool_call",  # Only the first tool call delta has an id
-                                    )
-
-                                # Second step: extract tool arguments. The tool arguments can be seen as a json string
-                                # within the tool json string. We emit a delta for the arguments.
-                                else:
-                                    # Empty text: skip
-                                    if result == "":
-                                        continue
-
-                                    # Until we see the `"arguments": {` in the buffer, we skip
-                                    # TODO: other models will likely need more elaborate processing here
-                                    if '"arguments": {' not in tool_state.buffer:
-                                        continue
-
-                                    # Handle nesting. We want to exclude the last } from the emitted arguments (it's
-                                    # closing the outermost nesting level, outside the arguments block)
-                                    tool_state.arg_nesting_level += result.count("{")
-                                    tool_state.arg_nesting_level -= result.count("}")
-                                    if tool_state.arg_nesting_level < 0:
-                                        result = "".join(result.split("}")[:-2]) + "}"  # e.g. "4}}\n" -> "4}"
-
-                                    tool = ChatCompletionStreamOutputDeltaToolCall(
-                                        function=ChatCompletionStreamOutputFunction(
-                                            arguments=result,
-                                        ),
-                                        index=0,
-                                        type="function",
-                                        id=None,
-                                    )
-
-                                yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
-                                continue
-                        # ====== END OF TOOL CALL LOGIC ======
-
-                        # All non-tool related tokens are emitted as assistant messages
-                        yield self.build_chunk(result, _request_id, role="assistant")
-                    yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
-
-                    thread.join()
-                except Exception as e:
-                    logger.error(str(e))
-                    raise
-                    yield f'data: {{"error": "{str(e)}"}}'
-                finally:
-                    thread.join()
-
-            return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
-
     @staticmethod
     def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
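A side note on the tool-call streaming above: while arguments are being streamed, each token is scanned for braces so that the brace closing the outer tool-call JSON (rather than the arguments object) can be trimmed before it is emitted. The standalone sketch below replays that arg_nesting_level bookkeeping on invented token chunks; it mirrors the trimming expression from the diff but is not part of the commit.

    def trim_outer_closing_brace(chunks):
        # Track how deep we are inside the arguments object; when the counter goes
        # negative, the current chunk also closed the surrounding tool-call JSON,
        # so drop that extra brace (mirrors `result.split("}")[:-2]` above).
        nesting = 0
        emitted = []
        for chunk in chunks:
            nesting += chunk.count("{")
            nesting -= chunk.count("}")
            if nesting < 0:
                chunk = "".join(chunk.split("}")[:-2]) + "}"
            emitted.append(chunk)
        return "".join(emitted)

    # '{"location": "Paris"}' plus the outer '}' arriving in the final chunk:
    print(trim_outer_closing_brace(['{"location"', ': "Paris"', '}}\n']))  # prints {"location": "Paris"}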