Scaffolding

Lysandre 2025-07-01 18:22:55 +02:00
parent 733bcb4fed
commit feb1eb7531


@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import json
 import re
@@ -18,7 +19,7 @@ import time
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass, field
 from threading import Thread
-from typing import Any, Optional
+from typing import Any, Optional, Generator, Literal
 
 from huggingface_hub import (
     ChatCompletionStreamOutputChoice,
@@ -26,7 +27,7 @@ from huggingface_hub import (
     ChatCompletionStreamOutputDeltaToolCall,
     ChatCompletionStreamOutputFunction,
     ModelInfo,
-    model_info,
+    model_info, ChatCompletionStreamOutput,
 )
 
 from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@@ -59,11 +60,42 @@ if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available()
         role: str
         content: str
 
+    class Prompt(BaseModel):
+        id: str
+        variables: dict
+        version: Optional[str]
+
+    class TextFormatOptions(enum.StrEnum):
+        text = "text"
+        json_schema = "json_schema"
+
+    class TextFormat(BaseModel):
+        type: TextFormatOptions
+
+    class ResponsesInput(BaseModel):
+        input: str | list
+        model: str
+        stream: Optional[bool] = False
+        instructions: Optional[str] = None
+        max_output_tokens: Optional[int] = None
+        max_tool_calls: Optional[int] = None
+        previous_response_id: Optional[str] = None
+        prompt: Optional[Prompt] = None
+        temperature: Optional[float] = None
+        text: Optional[TextFormat] = None
+        tools: Any = None
+        top_p: Optional[float] = None
+
+        # Additional options supported by the Responses API
+        # that aren't yet supported here.
+        # top_logprobs
+
     class ChatCompletionInput(BaseModel):
         messages: list[Message]
+        model: str
         stream: Optional[bool] = False
-        model: Optional[str] = None
         request_id: Optional[str] = None
         extra_body: Optional[dict] = None
         frequency_penalty: Optional[float] = None
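
For reference, a minimal sketch (not part of the commit) of request bodies that would validate against the two new Pydantic models; the model id and field values are purely illustrative, and `stream` must be true because only streaming mode is supported.

# Illustrative payloads only -- the model id is an arbitrary example.
chat_completion_request = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "stream": True,  # only streaming mode is supported by the server
}

responses_request = {
    "input": "Hello!",  # `input` accepts a string or a list
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "stream": True,
    "instructions": "Answer concisely.",
    "max_output_tokens": 256,
    "text": {"type": "text"},  # coerced into TextFormat / TextFormatOptions.text
}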
@@ -250,7 +282,7 @@ class ServeCommand(BaseTransformersCLICommand):
         cb_logger = logging.get_logger("transformers.generation.continuous_batching")
         cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
 
-    def build_chunk(
+    def build_chat_completion_chunk(
         self,
         content: str,
         request_id: str,
@@ -279,16 +311,78 @@ class ServeCommand(BaseTransformersCLICommand):
         }
         return f"data: {json.dumps(payload)}\n\n"
 
+    def build_responses_chunk(
+        self,
+        content: str,
+        request_id: str,
+        role: Optional[str] = None,
+        finish_reason: Optional[str] = None,
+        tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
+    ) -> str:
+        payload = {
+            "object": "chat.completion.chunk",
+            "id": request_id,
+            "created": int(time.time()),
+            "model": self.loaded_model,
+            "system_fingerprint": "",
+            "choices": [
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(
+                        role=role,
+                        content=content,
+                        tool_calls=tool_calls,
+                    ),
+                    index=0,
+                    logprobs=None,
+                    finish_reason=finish_reason,
+                ),
+            ],
+        }
+        return f"data: {json.dumps(payload)}\n\n"
 
     def run(self):
         app = FastAPI()
 
-        if self.use_continuous_batching:
-            self.continuous_batching(app)
-        else:
-            self.generate(app)
+        @app.get("/v1/chat/completions")
+        def chat_completion(req: ChatCompletionInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}
+
+            output = self.continuous_batching(req) if self.use_continuous_batching else self.generate(req)
+            return StreamingResponse(output, media_type="text/event-stream")
+
+        @app.get("/v1/responses")
+        def responses(req: ResponsesInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}
+
+            output = self.generate_response(req)
+            return StreamingResponse(output, media_type="text/event-stream")
+
+        @app.get("/v1/models")
+        def get_all_models():
+            return JSONResponse(
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "id": model.id,
+                            "object": "model",
+                            "created": model.created_at.timestamp(),
+                            "owned_by": model.author,
+                        }
+                        for model in self.get_text_gen_models()
+                    ],
+                }
+            )
+
+        uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
 
     @functools.lru_cache(maxsize=None)
-    def get_text_gen_models() -> list[ModelInfo]:
+    def get_text_gen_models(self) -> list[ModelInfo]:
         """
         This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
         model working with generate can work.
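
For reference, a minimal client sketch (not part of the commit) against the `/v1/models` route added above. It assumes the server was started with `transformers serve` and is reachable at `http://localhost:8000`; the actual host and port come from `ServeArguments` via `self.args.host` / `self.args.port`.

import requests  # third-party HTTP client, used here for brevity

# Assumed local address; adjust to the host/port the server was started with.
resp = requests.get("http://localhost:8000/v1/models", timeout=10)
resp.raise_for_status()

# Each entry mirrors the payload built in `get_all_models` above.
for model in resp.json()["data"]:
    print(f'{model["id"]} (owned by {model["owned_by"]})')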
@@ -308,31 +402,7 @@ class ServeCommand(BaseTransformersCLICommand):
             model_info("meta-llama/Llama-3.3-70B-Instruct"),
         ]
 
-        @app.get("/v1/models")
-        def get_all_models():
-            return JSONResponse(
-                {
-                    "object": "list",
-                    "data": [
-                        {
-                            "id": model.id,
-                            "object": "model",
-                            "created": model.created_at.timestamp(),
-                            "owned_by": model.author,
-                        }
-                        for model in get_text_gen_models()
-                    ],
-                }
-            )
-
-        uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
-
-    def continuous_batching(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
-            if not req.stream:
-                return {"error": "Only streaming mode is supported."}
-
+    def continuous_batching(self, req: ChatCompletionInput) -> Generator:
         update_model = req.model != self.loaded_model
         if update_model:
             self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
@@ -392,48 +462,13 @@ class ServeCommand(BaseTransformersCLICommand):
                 logger.error(str(e))
                 yield f'data: {{"error": "{str(e)}"}}'
 
-        return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
-
-    def is_continuation(self, req: "ChatCompletionInput") -> bool:
-        """
-        Determines whether the current request is a continuation of the last request. In other words, if it is the
-        same chat session.
-
-        Args:
-            req (`ChatCompletionInput`): The request to check.
-
-        Returns:
-            `True` if the request is a continuation of the last request, `False` otherwise.
-        """
-        req_continues_last_messages = True
-
-        # No cached messages: this is a new request
-        if self.last_messages is None:
-            req_continues_last_messages = False
-
-        # The new request has fewer rounds of conversation: this is a new request
-        elif len(self.last_messages) > len(req.messages):
-            req_continues_last_messages = False
-
-        # Otherwise, check that the last messages are a subset of the new request
-        else:
-            for i in range(len(self.last_messages)):
-                if self.last_messages[i] != req.messages[i]:
-                    req_continues_last_messages = False
-                    break
-
-        self.last_messages = req.messages
-        return req_continues_last_messages
-
-    def generate(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
+        return stream_response(inputs[0])
+
+    def generate(self, req: ChatCompletionInput) -> Generator:
         update_model = req.model != self.loaded_model
         if update_model:
             self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
 
-            if not req.stream:
-                return {"error": "Only streaming mode is supported."}
-
         # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
         # request whose last message is from the assistant.
         if req.messages[-1].role == "assistant":
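
Both `continuous_batching` and `generate` now return plain generators of server-sent-event strings, which `run` wraps in a `StreamingResponse`. As a rough sketch (not part of the commit) of what a client sees on the wire, the snippet below builds and parses two hand-written chunks shaped like the `payload` dict used by the chunk builders; the request id, timestamp, and model name are made up.

import json

def sse_chunk(content, role=None, finish_reason=None):
    # Mirrors the payload layout of build_chat_completion_chunk / build_responses_chunk.
    payload = {
        "object": "chat.completion.chunk",
        "id": "req_0",           # made-up request id
        "created": 1751387000,   # made-up timestamp
        "model": "dummy-model",  # made-up model name
        "system_fingerprint": "",
        "choices": [
            {
                "delta": {"role": role, "content": content, "tool_calls": None},
                "index": 0,
                "logprobs": None,
                "finish_reason": finish_reason,
            }
        ],
    }
    return f"data: {json.dumps(payload)}\n\n"

sample_stream = sse_chunk("Hello", role="assistant") + sse_chunk(None, finish_reason="stop")

for event in sample_stream.split("\n\n"):
    if not event.startswith("data: "):
        continue  # skips the empty trailing split
    choice = json.loads(event[len("data: "):])["choices"][0]
    if choice["finish_reason"] == "stop":
        break
    if choice["delta"]["content"] is not None:
        print(choice["delta"]["content"], end="")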
@@ -572,7 +607,86 @@ class ServeCommand(BaseTransformersCLICommand):
             finally:
                 thread.join()
 
-        return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
+        return stream_response(generation_streamer, request_id)
+
+    def generate_response(self, req: ResponsesInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
+        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+        request_id = req.request_id if req.request_id is not None else "req_0"
+
+        generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
+        generation_config = create_generation_config_from_req(req)
+
+        max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+        generation_config.max_new_tokens = max_new_tokens
+
+        last_kv_cache = None
+        if self.is_continuation(req) and not update_model:
+            last_kv_cache = self.last_kv_cache
+
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": torch.ones_like(inputs),
+            "streamer": generation_streamer,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "past_key_values": last_kv_cache,
+        }
+
+        def stream_response(streamer, _request_id):
+            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+            try:
+                thread.start()
+                for result in streamer:
+                    yield self.build_responses_chunk(result, _request_id, role="assistant")
+                yield self.build_responses_chunk(None, _request_id, role=None, finish_reason="stop")
+                thread.join()
+            except Exception as e:
+                logger.error(str(e))
+                yield f'data: {{"error": "{str(e)}"}}'
+            finally:
+                thread.join()
+
+        return stream_response(generation_streamer, request_id)
+
+    def is_continuation(self, req: "ChatCompletionInput") -> bool:
+        """
+        Determines whether the current request is a continuation of the last request. In other words, if it is the
+        same chat session.
+
+        Args:
+            req (`ChatCompletionInput`): The request to check.
+
+        Returns:
+            `True` if the request is a continuation of the last request, `False` otherwise.
+        """
+        req_continues_last_messages = True
+
+        # No cached messages: this is a new request
+        if self.last_messages is None:
+            req_continues_last_messages = False
+
+        # The new request has fewer rounds of conversation: this is a new request
+        elif len(self.last_messages) > len(req.messages):
+            req_continues_last_messages = False
+
+        # Otherwise, check that the last messages are a subset of the new request
+        else:
+            for i in range(len(self.last_messages)):
+                if self.last_messages[i] != req.messages[i]:
+                    req_continues_last_messages = False
+                    break
+
+        self.last_messages = req.messages
+        return req_continues_last_messages
 
     @staticmethod
     def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
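
As a worked illustration of the continuation check (a standalone re-implementation sketch, not the method above): a request is treated as the same chat session only when the previously cached messages are a prefix of the new message list, which is what allows the KV cache to be reused.

# Standalone sketch of the prefix check performed by is_continuation.
def continues(last_messages, new_messages):
    if last_messages is None:                   # no cached session yet
        return False
    if len(last_messages) > len(new_messages):  # new request has fewer rounds
        return False
    return all(a == b for a, b in zip(last_messages, new_messages))

history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
follow_up = history + [{"role": "user", "content": "Tell me a joke"}]

print(continues(None, history))       # False: first request of a session
print(continues(history, follow_up))  # True: cached messages are a prefix, KV cache can be reused
print(continues(follow_up, history))  # False: the new request is shorter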