Scaffolding

Lysandre 2025-07-01 18:22:55 +02:00
parent 733bcb4fed
commit feb1eb7531


@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import functools
import json
import re
@@ -18,7 +19,7 @@ import time
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Optional
from typing import Any, Optional, Generator, Literal
from huggingface_hub import (
ChatCompletionStreamOutputChoice,
@@ -26,7 +27,7 @@ from huggingface_hub import (
ChatCompletionStreamOutputDeltaToolCall,
ChatCompletionStreamOutputFunction,
ModelInfo,
model_info,
model_info, ChatCompletionStreamOutput,
)
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@@ -59,11 +60,42 @@ if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available()
role: str
content: str
class ChatCompletionInput(BaseModel):
messages: list[Message]
class Prompt(BaseModel):
id: str
variables: dict
version: Optional[str]
class TextFormatOptions(enum.StrEnum):
text = "text"
json_schema = "json_schema"
class TextFormat(BaseModel):
type: TextFormatOptions
class ResponsesInput(BaseModel):
input: str | list
model: str
stream: Optional[bool] = False
instructions: Optional[str] = None
max_output_tokens: Optional[int] = None
max_tool_calls: Optional[int] = None
previous_response_id: Optional[str] = None
prompt: Optional[Prompt] = None
temperature: Optional[float] = None
text: Optional[TextFormat] = None
tools: Any = None
top_p: Optional[float] = None
# Additional options supported by the Responses API
# that aren't yet supported here.
# top_logprobs
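For orientation, a minimal request body that should validate against the ResponsesInput model above could look like the following sketch; the model name, prompt text, and option values are illustrative assumptions, not part of this commit:

example_responses_request = {
    "input": "Write a haiku about transformers.",
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "stream": True,
    "instructions": "You are a concise assistant.",
    "max_output_tokens": 128,
    "temperature": 0.7,
    "text": {"type": "text"},  # coerced into TextFormat(type=TextFormatOptions.text)
}
# With pydantic available, ResponsesInput(**example_responses_request) parses this dict.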
class ChatCompletionInput(BaseModel):
messages: list[Message]
model: str
stream: Optional[bool] = False
model: Optional[str] = None
request_id: Optional[str] = None
extra_body: Optional[dict] = None
frequency_penalty: Optional[float] = None
@@ -250,7 +282,7 @@ class ServeCommand(BaseTransformersCLICommand):
cb_logger = logging.get_logger("transformers.generation.continuous_batching")
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
def build_chunk(
def build_chat_completion_chunk(
self,
content: str,
request_id: str,
@@ -279,34 +311,56 @@ class ServeCommand(BaseTransformersCLICommand):
}
return f"data: {json.dumps(payload)}\n\n"
def build_responses_chunk(
self,
content: str,
request_id: str,
role: Optional[str] = None,
finish_reason: Optional[str] = None,
tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
) -> str:
payload = {
"object": "chat.completion.chunk",
"id": request_id,
"created": int(time.time()),
"model": self.loaded_model,
"system_fingerprint": "",
"choices": [
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(
role=role,
content=content,
tool_calls=tool_calls,
),
index=0,
logprobs=None,
finish_reason=finish_reason,
),
],
}
return f"data: {json.dumps(payload)}\n\n"
def run(self):
app = FastAPI()
if self.use_continuous_batching:
self.continuous_batching(app)
else:
self.generate(app)
@app.get("/v1/chat/completions")
def chat_completion(req: ChatCompletionInput):
if not req.stream:
return {"error": "Only streaming mode is supported."}
@functools.lru_cache(maxsize=None)
def get_text_gen_models() -> list[ModelInfo]:
"""
This list does not limit which models may be instantiated with `transformers serve`: any chat-based
model that works with generate can be served.
output = self.continuous_batching(req) if self.use_continuous_batching else self.generate(req)
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
integrations.
"""
return [
model_info("Menlo/Jan-nano"),
model_info("Menlo/Jan-nano-128k"),
model_info("Qwen/Qwen2.5-0.5B-Instruct"),
model_info("Qwen/Qwen2.5-3B-Instruct"),
model_info("Qwen/Qwen2.5-7B-Instruct"),
model_info("Qwen/Qwen2.5-14B-Instruct"),
model_info("meta-llama/Llama-3.1-8B-Instruct"),
model_info("meta-llama/Llama-3.2-1B-Instruct"),
model_info("meta-llama/Llama-3.3-70B-Instruct"),
]
return StreamingResponse(output, media_type="text/event-stream")
@app.get("/v1/responses")
def responses(req: ResponsesInput):
if not req.stream:
return {"error": "Only streaming mode is supported."}
output = self.generate_response(req)
return StreamingResponse(output, media_type="text/event-stream")
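As a hedged client-side sketch, the streaming endpoints above can be consumed as server-sent events like this; the host, port, and model name are assumptions, and /v1/responses follows the same pattern with a ResponsesInput-shaped body:

import json
import requests

payload = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "stream": True,
    "messages": [{"role": "user", "content": "Hello!"}],
}
# Stream the chat completion and print content deltas as they arrive.
with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: ") and line != b"data: [DONE]":
            chunk = json.loads(line[len(b"data: "):])
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content") or "", end="", flush=True)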
@app.get("/v1/models")
def get_all_models():
@@ -320,79 +374,289 @@ class ServeCommand(BaseTransformersCLICommand):
"crated": model.created_at.timestamp(),
"owned_by": model.author,
}
for model in get_text_gen_models()
for model in self.get_text_gen_models()
],
}
)
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
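The server itself is bound to the host, port, and log level coming from ServeArguments; assuming the CLI flags mirror those field names, a launch would look roughly like:

transformers serve --host 127.0.0.1 --port 8000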
def continuous_batching(self, app):
@app.post("/v1/chat/completions")
def _serve(req: "ChatCompletionInput"):
if not req.stream:
return {"error": "Only streaming mode is supported."}
@functools.lru_cache(maxsize=None)
def get_text_gen_models(self) -> list[ModelInfo]:
"""
This list does not limit which models may be instantiated with `transformers serve`: any chat-based
model that works with generate can be served.
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
integrations.
"""
return [
model_info("Menlo/Jan-nano"),
model_info("Menlo/Jan-nano-128k"),
model_info("Qwen/Qwen2.5-0.5B-Instruct"),
model_info("Qwen/Qwen2.5-3B-Instruct"),
model_info("Qwen/Qwen2.5-7B-Instruct"),
model_info("Qwen/Qwen2.5-14B-Instruct"),
model_info("meta-llama/Llama-3.1-8B-Instruct"),
model_info("meta-llama/Llama-3.2-1B-Instruct"),
model_info("meta-llama/Llama-3.3-70B-Instruct"),
]
generation_config = create_generation_config_from_req(
req,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=False,
num_blocks=1,
block_size=1024,
do_sample=False,
max_batch_tokens=10,
scheduler="fifo",
def continuous_batching(self, req: ChatCompletionInput) -> Generator:
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
generation_config = create_generation_config_from_req(
req,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=False,
num_blocks=1,
block_size=1024,
do_sample=False,
max_batch_tokens=10,
scheduler="fifo",
)
if self.running_continuous_batching_manager is None or update_model:
self.running_continuous_batching_manager = self.model.init_continuous_batching(
generation_config=generation_config, streaming=True
)
self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
self.running_continuous_batching_manager.start()
if self.running_continuous_batching_manager is None or update_model:
self.running_continuous_batching_manager = self.model.init_continuous_batching(
generation_config=generation_config, streaming=True
inputs = self.tokenizer.apply_chat_template(
req.messages, return_tensors="pt", add_generation_prompt=True
).to(self.model.device)
def stream_response(_inputs):
try:
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
request_id = self.running_continuous_batching_manager.add_request(
_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
)
self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
self.running_continuous_batching_manager.start()
queue_is_flushed = False
inputs = self.tokenizer.apply_chat_template(
req.messages, return_tensors="pt", add_generation_prompt=True
).to(self.model.device)
for result in self.running_continuous_batching_manager:
if result.request_id != request_id:
continue
def stream_response(_inputs):
try:
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
request_id = self.running_continuous_batching_manager.add_request(
_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
if req.request_id is not None and not queue_is_flushed:
if result.status == RequestStatus.FINISHED:
continue
else:
queue_is_flushed = True
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
next_token = self.build_chunk(
result.next_token, request_id=request_id, finish_reason=finish_reason
)
queue_is_flushed = False
yield next_token
for result in self.running_continuous_batching_manager:
if result.request_id != request_id:
if result.status == RequestStatus.FINISHED:
break
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(str(e))
yield f'data: {{"error": "{str(e)}"}}'
return stream_response(inputs[0])
def generate(self, req: ChatCompletionInput) -> Generator:
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
# request whose last message is from the assistant.
if req.messages[-1].role == "assistant":
return
# ====== TOOL PREPROCESSING LOGIC ======
tool_model_family = None
for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
if supported_model_families in self.model.config.architectures[0].lower():
tool_model_family = supported_model_families
break
# TODO: trigger 2 constrained generations after the tool call start token is emitted:
# 1. force generation to pick from the tool names
# 2. force generation to pick from that tool's arguments
# ====== END OF TOOL PREPROCESSING LOGIC ======
if tool_model_family is not None:
text = self.tokenizer.apply_chat_template(
req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
)
else:
text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
request_id = req.request_id if req.request_id is not None else "req_0"
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config = create_generation_config_from_req(req)
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
generation_config.max_new_tokens = max_new_tokens
last_kv_cache = None
if self.is_continuation(req) and not update_model:
last_kv_cache = self.last_kv_cache
generation_kwargs = {
"inputs": inputs,
"attention_mask": torch.ones_like(inputs),
"streamer": generation_streamer,
"generation_config": generation_config,
"return_dict_in_generate": True,
"past_key_values": last_kv_cache,
}
def stream_response(streamer, _request_id):
# Thin wrapper to save the KV cache after generation
def generate_with_cache(**kwargs):
generate_output = self.model.generate(**kwargs)
self.last_kv_cache = generate_output.past_key_values
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
try:
thread.start()
tool_state = ToolState()
for result in streamer:
# ====== TOOL CALL LOGIC ======
if tool_model_family is not None:
# Start of a tool call: reset state variables, set `inside_tool_call`
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
tool_state.inside_tool_call = True
continue
if req.request_id is not None and not queue_is_flushed:
if result.status == RequestStatus.FINISHED:
continue
# End of tool call: reset `inside_tool_call`, emit a `finish_reason`
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
tool_state.reset()
yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
continue
# Inside a tool call
if tool_state.inside_tool_call:
tool_state.buffer += result
# First step: extract the tool name (may need several tokens, and we can't emit a delta
# until we have the full name)
if not tool_state.has_tool_name_defined:
tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
if tool_name is None:
continue
else:
tool_name = tool_name.group(1)
tool_state.has_tool_name_defined = True
tool = ChatCompletionStreamOutputDeltaToolCall(
function=ChatCompletionStreamOutputFunction(
name=tool_name,
arguments=None,
),
index=0,
type="function",
id=_request_id + "_tool_call", # Only the first tool call delta has an id
)
# Second step: extract tool arguments. The tool arguments can be seen as a json string
# within the tool json string. We emit a delta for the arguments.
else:
queue_is_flushed = True
# Empty text: skip
if result == "":
continue
# Until we see the `"arguments": {` in the buffer, we skip
# TODO: other models will likely need more elaborate processing here
if '"arguments": {' not in tool_state.buffer:
continue
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
next_token = self.build_chunk(
result.next_token, request_id=request_id, finish_reason=finish_reason
)
yield next_token
# Handle nesting. We want to exclude the last } from the emitted arguments (it's
# closing the outermost nesting level, outside the arguments block)
tool_state.arg_nesting_level += result.count("{")
tool_state.arg_nesting_level -= result.count("}")
if tool_state.arg_nesting_level < 0:
result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}"
if result.status == RequestStatus.FINISHED:
break
tool = ChatCompletionStreamOutputDeltaToolCall(
function=ChatCompletionStreamOutputFunction(
arguments=result,
),
index=0,
type="function",
id=None,
)
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(str(e))
yield f'data: {{"error": "{str(e)}"}}'
yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
continue
# ====== END OF TOOL CALL LOGIC ======
return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
# All non-tool related tokens are emitted as assistant messages
yield self.build_chunk(result, _request_id, role="assistant")
yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
thread.join()
except Exception as e:
logger.error(str(e))
raise
yield f'data: {{"error": "{str(e)}"}}'
finally:
thread.join()
return stream_response(generation_streamer, request_id)
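To make the tool-call streaming above easier to follow, here is a hedged sketch of the delta sequence a client might observe when the model emits a tool call; the tool name and argument fragments are made up:

# 1. Plain assistant deltas while no tool call is open:
#        delta = {"role": "assistant", "content": "Let me check."}
# 2. Once the tool name has been parsed from the buffer, one tool-call delta carries the id and name:
#        delta = {"tool_calls": [{"id": "req_0_tool_call", "index": 0, "type": "function",
#                                 "function": {"name": "get_weather", "arguments": None}}]}
# 3. Later deltas stream the arguments JSON in fragments (id is None for these):
#        delta = {"tool_calls": [{"index": 0, "type": "function",
#                                 "function": {"arguments": "{\"city\": \"Par"}}]}
# 4. The tool-call end token yields a chunk with finish_reason="tool_calls", and the stream
#    ends with a final chunk carrying finish_reason="stop".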
def generate_response(self, req: ResponsesInput) -> Generator:
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
request_id = req.request_id if req.request_id is not None else "req_0"
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config = create_generation_config_from_req(req)
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
generation_config.max_new_tokens = max_new_tokens
last_kv_cache = None
if self.is_continuation(req) and not update_model:
last_kv_cache = self.last_kv_cache
generation_kwargs = {
"inputs": inputs,
"attention_mask": torch.ones_like(inputs),
"streamer": generation_streamer,
"generation_config": generation_config,
"return_dict_in_generate": True,
"past_key_values": last_kv_cache,
}
def stream_response(streamer, _request_id):
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
try:
thread.start()
for result in streamer:
yield self.build_chunk(result, _request_id, role="assistant")
yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
thread.join()
except Exception as e:
logger.error(str(e))
raise
yield f'data: {{"error": "{str(e)}"}}'
finally:
thread.join()
return stream_response(generation_streamer, request_id)
def is_continuation(self, req: "ChatCompletionInput") -> bool:
"""
@@ -423,156 +687,6 @@ class ServeCommand(BaseTransformersCLICommand):
self.last_messages = req.messages
return req_continues_last_messages
def generate(self, app):
@app.post("/v1/chat/completions")
def _serve(req: "ChatCompletionInput"):
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
if not req.stream:
return {"error": "Only streaming mode is supported."}
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
# request whose last message is from the assistant.
if req.messages[-1].role == "assistant":
return
# ====== TOOL PREPROCESSING LOGIC ======
tool_model_family = None
for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
if supported_model_families in self.model.config.architectures[0].lower():
tool_model_family = supported_model_families
break
# TODO: trigger 2 constrained generations after the tool call start token is emitted:
# 1. force generation to pick from the tool names
# 2. force generation to pick from that tool's arguments
# ====== END OF TOOL PREPROCESSING LOGIC ======
if tool_model_family is not None:
text = self.tokenizer.apply_chat_template(
req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
)
else:
text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
request_id = req.request_id if req.request_id is not None else "req_0"
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config = create_generation_config_from_req(req)
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
generation_config.max_new_tokens = max_new_tokens
last_kv_cache = None
if self.is_continuation(req) and not update_model:
last_kv_cache = self.last_kv_cache
generation_kwargs = {
"inputs": inputs,
"attention_mask": torch.ones_like(inputs),
"streamer": generation_streamer,
"generation_config": generation_config,
"return_dict_in_generate": True,
"past_key_values": last_kv_cache,
}
def stream_response(streamer, _request_id):
# Thin wrapper to save the KV cache after generation
def generate_with_cache(**kwargs):
generate_output = self.model.generate(**kwargs)
self.last_kv_cache = generate_output.past_key_values
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
try:
thread.start()
tool_state = ToolState()
for result in streamer:
# ====== TOOL CALL LOGIC ======
if tool_model_family is not None:
# Start of a tool call: reset state variables, set `inside_tool_call`
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
tool_state.inside_tool_call = True
continue
# End of tool call: reset `inside_tool_call`, emit a `finish_reason`
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
tool_state.reset()
yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
continue
# Inside a tool call
if tool_state.inside_tool_call:
tool_state.buffer += result
# First step: extract the tool name (may need several tokens, and we can't emit a delta
# until we have the full name)
if not tool_state.has_tool_name_defined:
tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
if tool_name is None:
continue
else:
tool_name = tool_name.group(1)
tool_state.has_tool_name_defined = True
tool = ChatCompletionStreamOutputDeltaToolCall(
function=ChatCompletionStreamOutputFunction(
name=tool_name,
arguments=None,
),
index=0,
type="function",
id=_request_id + "_tool_call", # Only the first tool call delta has an id
)
# Second step: extract tool arguments. The tool arguments can be seen as a json string
# within the tool json string. We emit a delta for the arguments.
else:
# Empty text: skip
if result == "":
continue
# Until we see the `"arguments": {` in the buffer, we skip
# TODO: other models will likely need more elaborate processing here
if '"arguments": {' not in tool_state.buffer:
continue
# Handle nesting. We want to exclude the last } from the emitted arguments (it's
# closing the outermost nesting level, outside the arguments block)
tool_state.arg_nesting_level += result.count("{")
tool_state.arg_nesting_level -= result.count("}")
if tool_state.arg_nesting_level < 0:
result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}"
tool = ChatCompletionStreamOutputDeltaToolCall(
function=ChatCompletionStreamOutputFunction(
arguments=result,
),
index=0,
type="function",
id=None,
)
yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
continue
# ====== END OF TOOL CALL LOGIC ======
# All non-tool related tokens are emitted as assistant messages
yield self.build_chunk(result, _request_id, role="assistant")
yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
thread.join()
except Exception as e:
logger.error(str(e))
raise
yield f'data: {{"error": "{str(e)}"}}'
finally:
thread.join()
return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
@staticmethod
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]: