Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Scaffolding

parent 733bcb4fed
commit feb1eb7531
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import functools
 import json
 import re
@@ -18,7 +19,7 @@ import time
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass, field
 from threading import Thread
-from typing import Any, Optional
+from typing import Any, Optional, Generator, Literal

 from huggingface_hub import (
     ChatCompletionStreamOutputChoice,
@@ -26,7 +27,7 @@ from huggingface_hub import (
     ChatCompletionStreamOutputDeltaToolCall,
     ChatCompletionStreamOutputFunction,
     ModelInfo,
-    model_info,
+    model_info, ChatCompletionStreamOutput,
 )

 from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@@ -59,11 +60,42 @@ if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available()
         role: str
         content: str

-    class ChatCompletionInput(BaseModel):
-        messages: list[Message]
+    class Prompt(BaseModel):
+        id: str
+        variables: dict
+        version: Optional[str]
+
+    class TextFormatOptions(enum.StrEnum):
+        text = "text"
+        json_schema = "json_schema"
+
+    class TextFormat(BaseModel):
+        type: TextFormatOptions
+
+    class ResponsesInput(BaseModel):
+        input: str | list
+        model: str
+
+        stream: Optional[bool] = False
+        instructions: Optional[str] = None
+        max_output_tokens: Optional[int] = None
+        max_tool_calls: Optional[int] = None
+        previous_response_id: Optional[str] = None
+        prompt: Optional[Prompt] = None
+        temperature: Optional[float] = None
+        text: Optional[TextFormat] = None
+        tools: any = None
+        top_p: Optional[float] = None
+
+        # Additional options supported by the Responses API
+        # that aren't yet supported here.
+        # top_logprobs
+
+    class ChatCompletionInput(BaseModel):
+        messages: list[Message]
+        model: str

         stream: Optional[bool] = False
-        model: Optional[str] = None
         request_id: Optional[str] = None
         extra_body: Optional[dict] = None
         frequency_penalty: Optional[float] = None
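The new ResponsesInput model above is only scaffolding, but it already fixes the request shape for the upcoming /v1/responses endpoint. As a rough illustration (not part of this commit), a streaming request body matching those fields could look like the sketch below; the model name and the field values are placeholders.

import json

# Hypothetical /v1/responses request body mirroring the ResponsesInput fields
# added in this hunk. Only streaming is supported by the scaffolding, so
# "stream" is set to True. Values are assumptions for illustration.
responses_request = {
    "input": "Write a haiku about continuous batching.",
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "stream": True,
    "instructions": "You are a terse assistant.",
    "max_output_tokens": 128,
    "temperature": 0.7,
    "top_p": 0.9,
}

print(json.dumps(responses_request, indent=2))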
@@ -250,7 +282,7 @@ class ServeCommand(BaseTransformersCLICommand):
         cb_logger = logging.get_logger("transformers.generation.continuous_batching")
         cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])

-    def build_chunk(
+    def build_chat_completion_chunk(
         self,
         content: str,
         request_id: str,
@@ -279,34 +311,56 @@ class ServeCommand(BaseTransformersCLICommand):
         }
         return f"data: {json.dumps(payload)}\n\n"

+    def build_responses_chunk(
+        self,
+        content: str,
+        request_id: str,
+        role: Optional[str] = None,
+        finish_reason: Optional[str] = None,
+        tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
+    ) -> str:
+        payload = {
+            "object": "chat.completion.chunk",
+            "id": request_id,
+            "created": int(time.time()),
+            "model": self.loaded_model,
+            "system_fingerprint": "",
+            "choices": [
+                ChatCompletionStreamOutputChoice(
+                    delta=ChatCompletionStreamOutputDelta(
+                        role=role,
+                        content=content,
+                        tool_calls=tool_calls,
+                    ),
+                    index=0,
+                    logprobs=None,
+                    finish_reason=finish_reason,
+                ),
+            ],
+        }
+        return f"data: {json.dumps(payload)}\n\n"
+
+
+
     def run(self):
         app = FastAPI()

-        if self.use_continuous_batching:
-            self.continuous_batching(app)
-        else:
-            self.generate(app)
+        @app.get("/v1/chat/completions")
+        def chat_completion(req: ChatCompletionInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}

-        @functools.lru_cache(maxsize=None)
-        def get_text_gen_models() -> list[ModelInfo]:
-            """
-            This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
-            model working with generate can work.
+            output = self.continuous_batching(req) if self.use_continuous_batching else self.generate(req)

-            This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
-            integrations.
-            """
-            return [
-                model_info("Menlo/Jan-nano"),
-                model_info("Menlo/Jan-nano-128k"),
-                model_info("Qwen/Qwen2.5-0.5B-Instruct"),
-                model_info("Qwen/Qwen2.5-3B-Instruct"),
-                model_info("Qwen/Qwen2.5-7B-Instruct"),
-                model_info("Qwen/Qwen2.5-14B-Instruct"),
-                model_info("meta-llama/Llama-3.1-8B-Instruct"),
-                model_info("meta-llama/Llama-3.2-1B-Instruct"),
-                model_info("meta-llama/Llama-3.3-70B-Instruct"),
-            ]
+            return StreamingResponse(output, media_type="text/event-stream")
+
+        @app.get("/v1/responses")
+        def responses(req: ResponsesInput):
+            if not req.stream:
+                return {"error": "Only streaming mode is supported."}
+
+            output = self.generate_responses(req)
+            return StreamingResponse(output, media_type="text/event-stream")

         @app.get("/v1/models")
         def get_all_models():
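Both chunk builders serialize every event as a `data: {json}\n\n` line, and the continuous-batching path terminates the stream with `data: [DONE]`. The sketch below is an illustrative client-side parser for that wire format, not code from this commit; it assumes the stream has already been split into SSE lines.

import json
from typing import Iterable, Iterator


def iter_sse_payloads(lines: Iterable[str]) -> Iterator[dict]:
    # Yield the JSON payload of each "data: {...}" line; stop at "data: [DONE]".
    for line in lines:
        line = line.strip()
        if not line.startswith("data: "):
            continue  # skip empty keep-alive lines
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        yield json.loads(data)


# Example with a fake stream shaped like the server output:
fake_stream = [
    'data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "Hel"}}]}',
    'data: {"object": "chat.completion.chunk", "choices": [{"delta": {"content": "lo"}}]}',
    "data: [DONE]",
]
text = "".join(p["choices"][0]["delta"]["content"] for p in iter_sse_payloads(fake_stream))
print(text)  # -> "Hello"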
@@ -320,79 +374,289 @@ class ServeCommand(BaseTransformersCLICommand):
                             "crated": model.created_at.timestamp(),
                             "owned_by": model.author,
                         }
-                        for model in get_text_gen_models()
+                        for model in self.get_text_gen_models()
                     ],
                 }
             )

         uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)

-    def continuous_batching(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
-            if not req.stream:
-                return {"error": "Only streaming mode is supported."}
+    @functools.lru_cache(maxsize=None)
+    def get_text_gen_models(self) -> list[ModelInfo]:
+        """
+        This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
+        model working with generate can work.

-            update_model = req.model != self.loaded_model
-            if update_model:
-                self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+        This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
+        integrations.
+        """
+        return [
+            model_info("Menlo/Jan-nano"),
+            model_info("Menlo/Jan-nano-128k"),
+            model_info("Qwen/Qwen2.5-0.5B-Instruct"),
+            model_info("Qwen/Qwen2.5-3B-Instruct"),
+            model_info("Qwen/Qwen2.5-7B-Instruct"),
+            model_info("Qwen/Qwen2.5-14B-Instruct"),
+            model_info("meta-llama/Llama-3.1-8B-Instruct"),
+            model_info("meta-llama/Llama-3.2-1B-Instruct"),
+            model_info("meta-llama/Llama-3.3-70B-Instruct"),
+        ]

-            generation_config = create_generation_config_from_req(
-                req,
-                eos_token_id=self.tokenizer.eos_token_id,
-                pad_token_id=self.tokenizer.pad_token_id,
-                use_cache=False,
-                num_blocks=1,
-                block_size=1024,
-                do_sample=False,
-                max_batch_tokens=10,
-                scheduler="fifo",
+    def continuous_batching(self, req: ChatCompletionInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        generation_config = create_generation_config_from_req(
+            req,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=False,
+            num_blocks=1,
+            block_size=1024,
+            do_sample=False,
+            max_batch_tokens=10,
+            scheduler="fifo",
+        )
+
+        if self.running_continuous_batching_manager is None or update_model:
+            self.running_continuous_batching_manager = self.model.init_continuous_batching(
+                generation_config=generation_config, streaming=True
             )
+            self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
+            self.running_continuous_batching_manager.start()

-            if self.running_continuous_batching_manager is None or update_model:
-                self.running_continuous_batching_manager = self.model.init_continuous_batching(
-                    generation_config=generation_config, streaming=True
+        inputs = self.tokenizer.apply_chat_template(
+            req.messages, return_tensors="pt", add_generation_prompt=True
+        ).to(self.model.device)
+
+        def stream_response(_inputs):
+            try:
+                max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+                request_id = self.running_continuous_batching_manager.add_request(
+                    _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
                 )
-                self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
-                self.running_continuous_batching_manager.start()
+                queue_is_flushed = False

-            inputs = self.tokenizer.apply_chat_template(
-                req.messages, return_tensors="pt", add_generation_prompt=True
-            ).to(self.model.device)
+                for result in self.running_continuous_batching_manager:
+                    if result.request_id != request_id:
+                        continue

-            def stream_response(_inputs):
-                try:
-                    max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
-                    request_id = self.running_continuous_batching_manager.add_request(
-                        _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
+                    if req.request_id is not None and not queue_is_flushed:
+                        if result.status == RequestStatus.FINISHED:
+                            continue
+                        else:
+                            queue_is_flushed = True
+
+                    finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
+                    next_token = self.build_chunk(
+                        result.next_token, request_id=request_id, finish_reason=finish_reason
                     )
-                    queue_is_flushed = False
+                    yield next_token

-                    for result in self.running_continuous_batching_manager:
-                        if result.request_id != request_id:
+                    if result.status == RequestStatus.FINISHED:
+                        break
+
+                yield "data: [DONE]\n\n"
+            except Exception as e:
+                logger.error(str(e))
+                yield f'data: {{"error": "{str(e)}"}}'
+
+        return stream_response(inputs[0])
+
+    def generate(self, req: ChatCompletionInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
+        # request whose last message is from the assistant.
+        if req.messages[-1].role == "assistant":
+            return
+
+        # ====== TOOL PREPROCESSING LOGIC ======
+        tool_model_family = None
+        for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
+            if supported_model_families in self.model.config.architectures[0].lower():
+                tool_model_family = supported_model_families
+                break
+        # TODO: trigger 2 constrained generations after the tool call start token is emitted:
+        # 1. force generation to pick from the tool names
+        # 2. force generation to pick from that tool's arguments
+        # ====== END OF TOOL PREPROCESSING LOGIC ======
+
+        if tool_model_family is not None:
+            text = self.tokenizer.apply_chat_template(
+                req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
+            )
+        else:
+            text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
+
+        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+        request_id = req.request_id if req.request_id is not None else "req_0"
+
+        generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
+
+        generation_config = create_generation_config_from_req(req)
+        max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+        generation_config.max_new_tokens = max_new_tokens
+
+        last_kv_cache = None
+        if self.is_continuation(req) and not update_model:
+            last_kv_cache = self.last_kv_cache
+
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": torch.ones_like(inputs),
+            "streamer": generation_streamer,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "past_key_values": last_kv_cache,
+        }
+
+        def stream_response(streamer, _request_id):
+            # Thin wrapper to save the KV cache after generation
+            def generate_with_cache(**kwargs):
+                generate_output = self.model.generate(**kwargs)
+                self.last_kv_cache = generate_output.past_key_values
+
+            thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
+
+            try:
+                thread.start()
+                tool_state = ToolState()
+
+                for result in streamer:
+                    # ====== TOOL CALL LOGIC ======
+                    if tool_model_family is not None:
+                        # Start of a tool call: reset state variables, set `inside_tool_call`
+                        if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
+                            tool_state.inside_tool_call = True
                             continue

-                        if req.request_id is not None and not queue_is_flushed:
-                            if result.status == RequestStatus.FINISHED:
-                                continue
+                            # End of tool call: reset `inside_tool_call`, emit a `finish_reason`
+                            if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
+                                tool_state.reset()
+                                yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
+                                continue
+
+                            # Inside a tool call
+                            if tool_state.inside_tool_call:
+                                tool_state.buffer += result
+
+                                # First step: extract the tool name (may need several tokens, and we can't emit a delta
+                                # until we have the full name)
+                                if not tool_state.has_tool_name_defined:
+                                    tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
+                                    if tool_name is None:
+                                        continue
+                                    else:
+                                        tool_name = tool_name.group(1)
+                                    tool_state.has_tool_name_defined = True
+                                    tool = ChatCompletionStreamOutputDeltaToolCall(
+                                        function=ChatCompletionStreamOutputFunction(
+                                            name=tool_name,
+                                            arguments=None,
+                                        ),
+                                        index=0,
+                                        type="function",
+                                        id=_request_id + "_tool_call",  # Only the first tool call delta has an id
+                                    )
+
+                                # Second step: extract tool arguments. The tool arguments can be seen as a json string
+                                # within the tool json string. We emit a delta for the arguments.
                             else:
-                                queue_is_flushed = True
+                                # Empty text: skip
+                                if result == "":
+                                    continue
+                                # Until we see the `"arguments": {` in the buffer, we skip
+                                # TODO: other models will likely need more elaborate processing here
+                                if '"arguments": {' not in tool_state.buffer:
+                                    continue

-                        finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
-                        next_token = self.build_chunk(
-                            result.next_token, request_id=request_id, finish_reason=finish_reason
-                        )
-                        yield next_token
+                                # Handle nesting. We want to exclude the last } from the emitted arguments (it's
+                                # closing the outermost nesting level, outside the arguments block)
+                                tool_state.arg_nesting_level += result.count("{")
+                                tool_state.arg_nesting_level -= result.count("}")
+                                if tool_state.arg_nesting_level < 0:
+                                    result = "".join(result.split("}")[:-2]) + "}"  # e.g. "4}}\n" -> "4}"

-                        if result.status == RequestStatus.FINISHED:
-                            break
+                                tool = ChatCompletionStreamOutputDeltaToolCall(
+                                    function=ChatCompletionStreamOutputFunction(
+                                        arguments=result,
+                                    ),
+                                    index=0,
+                                    type="function",
+                                    id=None,
+                                )

-                    yield "data: [DONE]\n\n"
-                except Exception as e:
-                    logger.error(str(e))
-                    yield f'data: {{"error": "{str(e)}"}}'
+                            yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
+                            continue
+                    # ====== END OF TOOL CALL LOGIC ======

-            return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
+                    # All non-tool related tokens are emitted as assistant messages
+                    yield self.build_chunk(result, _request_id, role="assistant")
+                yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
+
+                thread.join()
+            except Exception as e:
+                logger.error(str(e))
+                raise
+                yield f'data: {{"error": "{str(e)}"}}'
+
+            finally:
+                thread.join()
+
+        return stream_response(generation_streamer, request_id)
+
+    def generate_response(self, req: ResponsesInput) -> Generator:
+        update_model = req.model != self.loaded_model
+        if update_model:
+            self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
+
+        text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
+
+        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+        request_id = req.request_id if req.request_id is not None else "req_0"
+
+        generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
+
+        generation_config = create_generation_config_from_req(req)
+        max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
+        generation_config.max_new_tokens = max_new_tokens
+
+        last_kv_cache = None
+        if self.is_continuation(req) and not update_model:
+            last_kv_cache = self.last_kv_cache
+
+        generation_kwargs = {
+            "inputs": inputs,
+            "attention_mask": torch.ones_like(inputs),
+            "streamer": generation_streamer,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "past_key_values": last_kv_cache,
+        }
+
+        def stream_response(streamer, _request_id):
+            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+
+            try:
+                thread.start()
+                for result in streamer:
+                    yield self.build_chunk(result, _request_id, role="assistant")
+                yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
+
+                thread.join()
+            except Exception as e:
+                logger.error(str(e))
+                raise
+                yield f'data: {{"error": "{str(e)}"}}'
+
+            finally:
+                thread.join()
+
+        return stream_response(generation_streamer, request_id)

     def is_continuation(self, req: "ChatCompletionInput") -> bool:
         """
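The net effect of this hunk is that `continuous_batching`, `generate`, and the new `generate_response` no longer register FastAPI routes themselves; they return generators that the routes in `run()` wrap in a `StreamingResponse`. A standalone sketch of that pattern follows (route path, names, and the toy token source are illustrative assumptions, not the transformers implementation).

import json
import time

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def fake_generate(prompt: str):
    # Toy stand-in for a generation method that yields SSE-formatted chunks.
    for token in ["Hello", ", ", "world", "!"]:
        payload = {"object": "chat.completion.chunk", "created": int(time.time()), "delta": token}
        yield f"data: {json.dumps(payload)}\n\n"
        time.sleep(0.05)
    yield "data: [DONE]\n\n"


@app.post("/v1/chat/completions")
def chat_completion(body: dict):
    # The route only decides how to ship the stream to the client; the
    # generator itself knows nothing about FastAPI, mirroring the refactor above.
    return StreamingResponse(fake_generate(body.get("input", "")), media_type="text/event-stream")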
@@ -423,156 +687,6 @@ class ServeCommand(BaseTransformersCLICommand):
         self.last_messages = req.messages
         return req_continues_last_messages

-    def generate(self, app):
-        @app.post("/v1/chat/completions")
-        def _serve(req: "ChatCompletionInput"):
-            update_model = req.model != self.loaded_model
-
-            if update_model:
-                self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
-
-            if not req.stream:
-                return {"error": "Only streaming mode is supported."}
-
-            # HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
-            # request whose last message is from the assistant.
-            if req.messages[-1].role == "assistant":
-                return
-
-            # ====== TOOL PREPROCESSING LOGIC ======
-            tool_model_family = None
-            for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
-                if supported_model_families in self.model.config.architectures[0].lower():
-                    tool_model_family = supported_model_families
-                    break
-            # TODO: trigger 2 constrained generations after the tool call start token is emitted:
-            # 1. force generation to pick from the tool names
-            # 2. force generation to pick from that tool's arguments
-            # ====== END OF TOOL PREPROCESSING LOGIC ======
-
-            if tool_model_family is not None:
-                text = self.tokenizer.apply_chat_template(
-                    req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
-                )
-            else:
-                text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
-
-            inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
-            request_id = req.request_id if req.request_id is not None else "req_0"
-
-            generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
-
-            generation_config = create_generation_config_from_req(req)
-            max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
-            generation_config.max_new_tokens = max_new_tokens
-
-            last_kv_cache = None
-            if self.is_continuation(req) and not update_model:
-                last_kv_cache = self.last_kv_cache
-
-            generation_kwargs = {
-                "inputs": inputs,
-                "attention_mask": torch.ones_like(inputs),
-                "streamer": generation_streamer,
-                "generation_config": generation_config,
-                "return_dict_in_generate": True,
-                "past_key_values": last_kv_cache,
-            }
-
-            def stream_response(streamer, _request_id):
-                # Thin wrapper to save the KV cache after generation
-                def generate_with_cache(**kwargs):
-                    generate_output = self.model.generate(**kwargs)
-                    self.last_kv_cache = generate_output.past_key_values
-
-                thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
-
-                try:
-                    thread.start()
-                    tool_state = ToolState()
-
-                    for result in streamer:
-                        # ====== TOOL CALL LOGIC ======
-                        if tool_model_family is not None:
-                            # Start of a tool call: reset state variables, set `inside_tool_call`
-                            if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
-                                tool_state.inside_tool_call = True
-                                continue
-
-                            # End of tool call: reset `inside_tool_call`, emit a `finish_reason`
-                            if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
-                                tool_state.reset()
-                                yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
-                                continue
-
-                            # Inside a tool call
-                            if tool_state.inside_tool_call:
-                                tool_state.buffer += result
-
-                                # First step: extract the tool name (may need several tokens, and we can't emit a delta
-                                # until we have the full name)
-                                if not tool_state.has_tool_name_defined:
-                                    tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
-                                    if tool_name is None:
-                                        continue
-                                    else:
-                                        tool_name = tool_name.group(1)
-                                    tool_state.has_tool_name_defined = True
-                                    tool = ChatCompletionStreamOutputDeltaToolCall(
-                                        function=ChatCompletionStreamOutputFunction(
-                                            name=tool_name,
-                                            arguments=None,
-                                        ),
-                                        index=0,
-                                        type="function",
-                                        id=_request_id + "_tool_call",  # Only the first tool call delta has an id
-                                    )
-
-                                # Second step: extract tool arguments. The tool arguments can be seen as a json string
-                                # within the tool json string. We emit a delta for the arguments.
-                                else:
-                                    # Empty text: skip
-                                    if result == "":
-                                        continue
-                                    # Until we see the `"arguments": {` in the buffer, we skip
-                                    # TODO: other models will likely need more elaborate processing here
-                                    if '"arguments": {' not in tool_state.buffer:
-                                        continue
-
-                                    # Handle nesting. We want to exclude the last } from the emitted arguments (it's
-                                    # closing the outermost nesting level, outside the arguments block)
-                                    tool_state.arg_nesting_level += result.count("{")
-                                    tool_state.arg_nesting_level -= result.count("}")
-                                    if tool_state.arg_nesting_level < 0:
-                                        result = "".join(result.split("}")[:-2]) + "}"  # e.g. "4}}\n" -> "4}"
-
-                                    tool = ChatCompletionStreamOutputDeltaToolCall(
-                                        function=ChatCompletionStreamOutputFunction(
-                                            arguments=result,
-                                        ),
-                                        index=0,
-                                        type="function",
-                                        id=None,
-                                    )
-
-                                yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
-                                continue
-                        # ====== END OF TOOL CALL LOGIC ======
-
-                        # All non-tool related tokens are emitted as assistant messages
-                        yield self.build_chunk(result, _request_id, role="assistant")
-                    yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
-
-                    thread.join()
-                except Exception as e:
-                    logger.error(str(e))
-                    raise
-                    yield f'data: {{"error": "{str(e)}"}}'
-
-                finally:
-                    thread.join()
-
-            return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")

     @staticmethod
     def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]: