Scaffolding

Lysandre 2025-07-01 18:22:55 +02:00
parent 733bcb4fed
commit feb1eb7531


@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import functools
import json
import re
@@ -18,7 +19,7 @@ import time
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Optional
from typing import Any, Optional, Generator, Literal
from huggingface_hub import (
ChatCompletionStreamOutputChoice,
@@ -26,7 +27,7 @@ from huggingface_hub import (
ChatCompletionStreamOutputDeltaToolCall,
ChatCompletionStreamOutputFunction,
ModelInfo,
model_info,
model_info, ChatCompletionStreamOutput,
)
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@@ -59,11 +60,42 @@ if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available()
role: str
content: str
class ChatCompletionInput(BaseModel):
messages: list[Message]
class Prompt(BaseModel):
id: str
variables: dict
version: Optional[str]
class TextFormatOptions(enum.StrEnum):
text = "text"
json_schema = "json_schema"
class TextFormat(BaseModel):
type: TextFormatOptions
class ResponsesInput(BaseModel):
input: str | list
model: str
stream: Optional[bool] = False
instructions: Optional[str] = None
max_output_tokens: Optional[int] = None
max_tool_calls: Optional[int] = None
previous_response_id: Optional[str] = None
prompt: Optional[Prompt] = None
temperature: Optional[float] = None
text: Optional[TextFormat] = None
tools: Any = None
top_p: Optional[float] = None
# Additional options supported by the Responses API
# that aren't yet supported here.
# top_logprobs
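For illustration, a minimal request body that this model accepts might look as follows; the prompt and model name are hypothetical, only the field names come from the class above.
example_request = ResponsesInput(
    input="Write a haiku about GPUs.",  # `input` accepts a plain string or a list of messages
    model="Qwen/Qwen2.5-0.5B-Instruct",  # assumed model name
    stream=True,  # the server below only supports streaming
)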
class ChatCompletionInput(BaseModel):
messages: list[Message]
model: str
stream: Optional[bool] = False
model: Optional[str] = None
request_id: Optional[str] = None
extra_body: Optional[dict] = None
frequency_penalty: Optional[float] = None
@@ -250,7 +282,7 @@ class ServeCommand(BaseTransformersCLICommand):
cb_logger = logging.get_logger("transformers.generation.continuous_batching")
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
def build_chunk(
def build_chat_completion_chunk(
self,
content: str,
request_id: str,
@@ -279,16 +311,78 @@ class ServeCommand(BaseTransformersCLICommand):
}
return f"data: {json.dumps(payload)}\n\n"
def build_responses_chunk(
self,
content: str,
request_id: str,
role: Optional[str] = None,
finish_reason: Optional[str] = None,
tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
) -> str:
payload = {
"object": "chat.completion.chunk",
"id": request_id,
"created": int(time.time()),
"model": self.loaded_model,
"system_fingerprint": "",
"choices": [
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(
role=role,
content=content,
tool_calls=tool_calls,
),
index=0,
logprobs=None,
finish_reason=finish_reason,
),
],
}
return f"data: {json.dumps(payload)}\n\n"
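For reference, each helper above serializes one OpenAI-style server-sent event per call. Roughly, an emitted line looks like the following; the values are illustrative, the shape follows the payload dict and choice fields built above.
data: {"object": "chat.completion.chunk", "id": "req_0", "created": 1719847375, "model": "Qwen/Qwen2.5-0.5B-Instruct", "system_fingerprint": "", "choices": [{"delta": {"role": "assistant", "content": "Hello"}, "index": 0, "logprobs": null, "finish_reason": null}]}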
def run(self):
app = FastAPI()
if self.use_continuous_batching:
self.continuous_batching(app)
else:
self.generate(app)
@app.get("/v1/chat/completions")
def chat_completion(req: ChatCompletionInput):
if not req.stream:
return {"error": "Only streaming mode is supported."}
output = self.continuous_batching(req) if self.use_continuous_batching else self.generate(req)
return StreamingResponse(output, media_type="text/event-stream")
@app.get("/v1/responses")
def responses(req: ResponsesInput):
if not req.stream:
return {"error": "Only streaming mode is supported."}
output = self.generate_response(req)
return StreamingResponse(output, media_type="text/event-stream")
@app.get("/v1/models")
def get_all_models():
return JSONResponse(
{
"object": "list",
"data": [
{
"id": model.id,
"object": "model",
"crated": model.created_at.timestamp(),
"owned_by": model.author,
}
for model in self.get_text_gen_models()
],
}
)
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
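A hedged client-side sketch for the endpoints registered above, assuming the server is reachable on http://localhost:8000 and that the chosen model name is available (host, port, and model name are assumptions):
# Hypothetical client: stream a chat completion over SSE using `requests`.
import json
import requests

payload = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",  # assumed model name
    "stream": True,
    "messages": [{"role": "user", "content": "Hello!"}],
}
with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            chunk = json.loads(line[len(b"data: "):])
            if "choices" in chunk:
                delta = chunk["choices"][0]["delta"]
                print(delta.get("content") or "", end="", flush=True)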
@functools.lru_cache(maxsize=None)
def get_text_gen_models() -> list[ModelInfo]:
def get_text_gen_models(self) -> list[ModelInfo]:
"""
This is by no means a limit on which models may be used with `transformers serve`: any chat-based
model that works with `generate` can be served.
@@ -308,31 +402,7 @@ class ServeCommand(BaseTransformersCLICommand):
model_info("meta-llama/Llama-3.3-70B-Instruct"),
]
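As a quick check of the listing endpoint wired up in run() above, a hedged sketch (host and port are assumptions; the response shape follows the JSONResponse built there):
# Hypothetical client call against the /v1/models endpoint
import requests
models = requests.get("http://localhost:8000/v1/models").json()
print([m["id"] for m in models["data"]])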
@app.get("/v1/models")
def get_all_models():
return JSONResponse(
{
"object": "list",
"data": [
{
"id": model.id,
"object": "model",
"crated": model.created_at.timestamp(),
"owned_by": model.author,
}
for model in get_text_gen_models()
],
}
)
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
def continuous_batching(self, app):
@app.post("/v1/chat/completions")
def _serve(req: "ChatCompletionInput"):
if not req.stream:
return {"error": "Only streaming mode is supported."}
def continuous_batching(self, req: ChatCompletionInput) -> Generator:
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
@@ -392,48 +462,13 @@ class ServeCommand(BaseTransformersCLICommand):
logger.error(str(e))
yield f'data: {{"error": "{str(e)}"}}'
return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
return stream_response(inputs[0])
def is_continuation(self, req: "ChatCompletionInput") -> bool:
"""
Determines whether the current request is a continuation of the last request. In other words, if it is the
same chat session.
Args:
req (`ChatCompletionInput`): The request to check.
Returns:
`True` if the request is a continuation of the last request, `False` otherwise.
"""
req_continues_last_messages = True
# No cached messages: this is a new request
if self.last_messages is None:
req_continues_last_messages = False
# The new request has fewer rounds of conversation: this is a new request
elif len(self.last_messages) > len(req.messages):
req_continues_last_messages = False
# Otherwise, check that the last messages are a subset of the new request
else:
for i in range(len(self.last_messages)):
if self.last_messages[i] != req.messages[i]:
req_continues_last_messages = False
break
self.last_messages = req.messages
return req_continues_last_messages
def generate(self, app):
@app.post("/v1/chat/completions")
def _serve(req: "ChatCompletionInput"):
def generate(self, req: ChatCompletionInput) -> Generator:
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
if not req.stream:
return {"error": "Only streaming mode is supported."}
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
# request whose last message is from the assistant.
if req.messages[-1].role == "assistant":
@@ -572,7 +607,86 @@ class ServeCommand(BaseTransformersCLICommand):
finally:
thread.join()
return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
return stream_response(generation_streamer, request_id)
def generate_response(self, req: ResponsesInput) -> Generator:
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
# ResponsesInput carries `input` (a string or a list of messages) rather than `messages`
messages = req.input if isinstance(req.input, list) else [{"role": "user", "content": req.input}]
text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
request_id = req.request_id if req.request_id is not None else "req_0"
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config = create_generation_config_from_req(req)
max_new_tokens = req.max_output_tokens or generation_config.max_new_tokens or 256
generation_config.max_new_tokens = max_new_tokens
last_kv_cache = None
if self.is_continuation(req) and not update_model:
last_kv_cache = self.last_kv_cache
generation_kwargs = {
"inputs": inputs,
"attention_mask": torch.ones_like(inputs),
"streamer": generation_streamer,
"generation_config": generation_config,
"return_dict_in_generate": True,
"past_key_values": last_kv_cache,
}
def stream_response(streamer, _request_id):
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
try:
thread.start()
for result in streamer:
yield self.build_responses_chunk(result, _request_id, role="assistant")
yield self.build_responses_chunk(None, _request_id, role=None, finish_reason="stop")
thread.join()
except Exception as e:
logger.error(str(e))
yield f'data: {{"error": "{str(e)}"}}'
finally:
thread.join()
return stream_response(generation_streamer, request_id)
def is_continuation(self, req: "ChatCompletionInput") -> bool:
"""
Determines whether the current request is a continuation of the last request. In other words, if it is the
same chat session.
Args:
req (`ChatCompletionInput`): The request to check.
Returns:
`True` if the request is a continuation of the last request, `False` otherwise.
"""
req_continues_last_messages = True
# No cached messages: this is a new request
if self.last_messages is None:
req_continues_last_messages = False
# The new request has fewer rounds of conversation: this is a new request
elif len(self.last_messages) > len(req.messages):
req_continues_last_messages = False
# Otherwise, check that the last messages are a subset of the new request
else:
for i in range(len(self.last_messages)):
if self.last_messages[i] != req.messages[i]:
req_continues_last_messages = False
break
self.last_messages = req.messages
return req_continues_last_messages
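To make the prefix check above concrete, a small standalone sketch; the message values are hypothetical.
# Hypothetical illustration: the cached messages must be a prefix of the new request's messages
last_messages = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
new_messages = last_messages + [{"role": "user", "content": "How are you?"}]
is_same_session = len(last_messages) <= len(new_messages) and all(
    old == new for old, new in zip(last_messages, new_messages)
)
assert is_same_session  # a brand-new conversation would fail this check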
@staticmethod
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]: