mirror of https://github.com/huggingface/transformers.git
Scaffolding

parent 733bcb4fed
commit feb1eb7531
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import json
 import re
@@ -18,7 +19,7 @@ import time
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass, field
 from threading import Thread
-from typing import Any, Optional
+from typing import Any, Optional, Generator, Literal

 from huggingface_hub import (
     ChatCompletionStreamOutputChoice,
@@ -26,7 +27,7 @@ from huggingface_hub import (
     ChatCompletionStreamOutputDeltaToolCall,
     ChatCompletionStreamOutputFunction,
     ModelInfo,
-    model_info,
+    model_info, ChatCompletionStreamOutput,
 )

 from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@@ -59,11 +60,42 @@ if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available()
         role: str
         content: str

-    class ChatCompletionInput(BaseModel):
-        messages: list[Message]
+    class Prompt(BaseModel):
+        id: str
+        variables: dict
+        version: Optional[str]
+
+    class TextFormatOptions(enum.StrEnum):
+        text = "text"
+        json_schema = "json_schema"
+
+    class TextFormat(BaseModel):
+        type: TextFormatOptions
+
+    class ResponsesInput(BaseModel):
+        input: str | list
+        model: str
+
+        stream: Optional[bool] = False
+        instructions: Optional[str] = None
+        max_output_tokens: Optional[int] = None
+        max_tool_calls: Optional[int] = None
+        previous_response_id: Optional[str] = None
+        prompt: Optional[Prompt] = None
+        temperature: Optional[float] = None
+        text: Optional[TextFormat] = None
+        tools: any = None
+        top_p: Optional[float] = None
+
+        # Additional options supported by the Responses API
+        # that aren't yet supported here.
+        # top_logprobs
+
+    class ChatCompletionInput(BaseModel):
+        messages: list[Message]
+        model: str

         stream: Optional[bool] = False
         model: Optional[str] = None
         request_id: Optional[str] = None
         extra_body: Optional[dict] = None
         frequency_penalty: Optional[float] = None
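A quick illustration of the new request schema scaffolded above. The sketch below is not part of the commit: it copies a trimmed-down `ResponsesInput` and validates a made-up payload against it (assumes pydantic v2 and Python 3.10+ for the `str | list` union).

import json
from typing import Optional

from pydantic import BaseModel


class ResponsesInput(BaseModel):
    # Trimmed copy of the model from the hunk above, for illustration only.
    input: str | list
    model: str
    stream: Optional[bool] = False
    instructions: Optional[str] = None
    max_output_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None


# Made-up payload; any served model id works here, the value is illustrative.
req = ResponsesInput.model_validate(
    {
        "input": "Write a haiku about GPUs.",
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "stream": True,
        "max_output_tokens": 64,
    }
)
print(req.model, req.stream)  # Qwen/Qwen2.5-0.5B-Instruct True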
@@ -250,7 +282,7 @@ class ServeCommand(BaseTransformersCLICommand):
         cb_logger = logging.get_logger("transformers.generation.continuous_batching")
         cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])

-    def build_chunk(
+    def build_chat_completion_chunk(
         self,
         content: str,
         request_id: str,
@@ -279,16 +311,78 @@ class ServeCommand(BaseTransformersCLICommand):
         }
         return f"data: {json.dumps(payload)}\n\n"

+    def build_responses_chunk(
+        self,
+        content: str,
+        request_id: str,
+        role: Optional[str] = None,
+        finish_reason: Optional[str] = None,
+        tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
+    ) -> str:
+        payload = {
+            "object": "chat.completion.chunk",
+            "id": request_id,
+            "created": int(time.time()),
+            "model": self.loaded_model,
+            "system_fingerprint": "",
+            "choices": [
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(
+                        role=role,
+                        content=content,
+                        tool_calls=tool_calls,
+                    ),
+                    index=0,
+                    logprobs=None,
+                    finish_reason=finish_reason,
+                ),
+            ],
+        }
+        return f"data: {json.dumps(payload)}\n\n"
+
     def run(self):
         app = FastAPI()

-        if self.use_continuous_batching:
-            self.continuous_batching(app)
-        else:
-            self.generate(app)
+        @app.get("/v1/chat/completions")
+        def chat_completion(req: ChatCompletionInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}
+
+            output = self.continuous_batching(req) if self.use_continuous_batching else self.generate(req)
+
+            return StreamingResponse(output, media_type="text/event-stream")
+
+        @app.get("/v1/responses")
+        def responses(req: ResponsesInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}
+
+            output = self.generate_responses(req)
+            return StreamingResponse(output, media_type="text/event-stream")
+
+        @app.get("/v1/models")
+        def get_all_models():
+            return JSONResponse(
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "id": model.id,
+                            "object": "model",
+                            "crated": model.created_at.timestamp(),
+                            "owned_by": model.author,
+                        }
+                        for model in self.get_text_gen_models()
+                    ],
+                }
+            )
+
+        uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)

         @functools.lru_cache(maxsize=None)
-        def get_text_gen_models() -> list[ModelInfo]:
+        def get_text_gen_models(self) -> list[ModelInfo]:
             """
             This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
             model working with generate can work.
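Both chunk builders serialize every streamed event as `data: {json}\n\n`, i.e. OpenAI-style server-sent events. A minimal sketch of how a client could decode one such chunk; the sample payload below is made up, not taken from the commit.

import json

raw_chunk = (
    'data: {"object": "chat.completion.chunk", "id": "req_0", '
    '"choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hello"}, '
    '"finish_reason": null}]}\n\n'
)

# Strip the SSE prefix, parse the JSON body, and read the delta of the first choice.
body = json.loads(raw_chunk.removeprefix("data: ").strip())
delta = body["choices"][0]["delta"]
if delta.get("content"):
    print(delta["content"], end="", flush=True)  # prints "Hello"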
@@ -308,31 +402,7 @@ class ServeCommand(BaseTransformersCLICommand):
                 model_info("meta-llama/Llama-3.3-70B-Instruct"),
             ]

-        @app.get("/v1/models")
-        def get_all_models():
-            return JSONResponse(
-                {
-                    "object": "list",
-                    "data": [
-                        {
-                            "id": model.id,
-                            "object": "model",
-                            "crated": model.created_at.timestamp(),
-                            "owned_by": model.author,
-                        }
-                        for model in get_text_gen_models()
-                    ],
-                }
-            )
-
-        uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
-
-    def continuous_batching(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
-            if not req.stream:
-                return {"error": "Only streaming mode is supported."}
-
+    def continuous_batching(self, req: ChatCompletionInput) -> Generator:
         update_model = req.model != self.loaded_model
         if update_model:
             self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
@@ -392,48 +462,13 @@ class ServeCommand(BaseTransformersCLICommand):
                 logger.error(str(e))
                 yield f'data: {{"error": "{str(e)}"}}'

-        return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
+        return stream_response(inputs[0])

-    def is_continuation(self, req: "ChatCompletionInput") -> bool:
-        """
-        Determines whether the current request is a continuation of the last request. In other words, if it is the
-        same chat session.
-
-        Args:
-            req (`ChatCompletionInput`): The request to check.
-
-        Returns:
-            `True` if the request is a continuation of the last request, `False` otherwise.
-        """
-        req_continues_last_messages = True
-
-        # No cached messages: this is a new request
-        if self.last_messages is None:
-            req_continues_last_messages = False
-        # The new request has fewer rounds of conversation: this is a new request
-        elif len(self.last_messages) > len(req.messages):
-            req_continues_last_messages = False
-        # Otherwise, check that the last messages are a subset of the new request
-        else:
-            for i in range(len(self.last_messages)):
-                if self.last_messages[i] != req.messages[i]:
-                    req_continues_last_messages = False
-                    break
-
-        self.last_messages = req.messages
-        return req_continues_last_messages
-
-    def generate(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
+    def generate(self, req: ChatCompletionInput) -> Generator:
         update_model = req.model != self.loaded_model

         if update_model:
             self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)

-        if not req.stream:
-            return {"error": "Only streaming mode is supported."}
-
         # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
         # request whose last message is from the assistant.
         if req.messages[-1].role == "assistant":
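Together with the previous hunk, this turns `continuous_batching` and `generate` from route-registering helpers into plain generators of SSE strings, which `run()` wraps in a StreamingResponse. A standalone sketch of that pattern follows, with a fake token loop standing in for the real generation code; the route, payload shape, and token list are illustrative only.

import json
import time
from typing import Generator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def token_stream() -> Generator[str, None, None]:
    # Stand-in for the real generation loop: yields OpenAI-style SSE chunks.
    for token in ["Hello", ",", " world", "!"]:
        payload = {
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "choices": [{"index": 0, "delta": {"content": token}}],
        }
        yield f"data: {json.dumps(payload)}\n\n"
    # Final chunk signalling the end of the stream.
    stop = {"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
    yield f"data: {json.dumps(stop)}\n\n"


@app.post("/v1/chat/completions")
def chat_completion():
    # The route handler stays thin: it only wraps the generator.
    return StreamingResponse(token_stream(), media_type="text/event-stream")

Run under uvicorn, this endpoint streams the four content chunks followed by a final stop chunk.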
@@ -572,7 +607,86 @@ class ServeCommand(BaseTransformersCLICommand):
            finally:
                thread.join()

-        return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
+        return stream_response(generation_streamer, request_id)
+
+    def generate_response(self, req: ResponsesInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
+
+        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+        request_id = req.request_id if req.request_id is not None else "req_0"
+
+        generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
+
+        generation_config = create_generation_config_from_req(req)
+        max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+        generation_config.max_new_tokens = max_new_tokens
+
+        last_kv_cache = None
+        if self.is_continuation(req) and not update_model:
+            last_kv_cache = self.last_kv_cache
+
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": torch.ones_like(inputs),
+            "streamer": generation_streamer,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "past_key_values": last_kv_cache,
+        }
+
+        def stream_response(streamer, _request_id):
+            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+
+            try:
+                thread.start()
+                for result in streamer:
+                    yield self.build_chunk(result, _request_id, role="assistant")
+                yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
+
+                thread.join()
+            except Exception as e:
+                logger.error(str(e))
+                raise
+                yield f'data: {{"error": "{str(e)}"}}'
+
+            finally:
+                thread.join()
+
+        return stream_response(generation_streamer, request_id)
+
+    def is_continuation(self, req: "ChatCompletionInput") -> bool:
+        """
+        Determines whether the current request is a continuation of the last request. In other words, if it is the
+        same chat session.
+
+        Args:
+            req (`ChatCompletionInput`): The request to check.
+
+        Returns:
+            `True` if the request is a continuation of the last request, `False` otherwise.
+        """
+        req_continues_last_messages = True
+
+        # No cached messages: this is a new request
+        if self.last_messages is None:
+            req_continues_last_messages = False
+        # The new request has fewer rounds of conversation: this is a new request
+        elif len(self.last_messages) > len(req.messages):
+            req_continues_last_messages = False
+        # Otherwise, check that the last messages are a subset of the new request
+        else:
+            for i in range(len(self.last_messages)):
+                if self.last_messages[i] != req.messages[i]:
+                    req_continues_last_messages = False
+                    break
+
+        self.last_messages = req.messages
+        return req_continues_last_messages
+

     @staticmethod
     def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
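The `is_continuation` check moved here only treats a request as the same session when the cached message list is a prefix of the new one, which is what lets the cached `past_key_values` be reused. A tiny standalone illustration of that prefix test; the messages below are made up.

last = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello!"},
]
new = last + [{"role": "user", "content": "how are you?"}]

# Same logic as is_continuation: every cached message must match the new request, in order.
is_prefix = len(last) <= len(new) and all(last[i] == new[i] for i in range(len(last)))
print(is_prefix)  # True -> the previous KV cache can be reused for this request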