This commit is contained in:
Lysandre 2025-07-01 13:08:36 +02:00
parent e8f90b5397
commit 733bcb4fed
3 changed files with 41 additions and 31 deletions

View File

@@ -31,7 +31,7 @@ from huggingface_hub import (
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
-from .. import PreTrainedTokenizerFast, TextIteratorStreamer
+from .. import LogitsProcessorList, PreTrainedTokenizerFast, TextIteratorStreamer
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
from ..utils import is_torch_available, logging
from . import BaseTransformersCLICommand
@@ -111,7 +111,7 @@ def serve_command_factory(args: Namespace):
    return ServeCommand(args)


-def create_generation_config_from_req(req: "ChatCompletionInput") -> "GenerationConfig":
+def create_generation_config_from_req(req: "ChatCompletionInput", **kwargs) -> "GenerationConfig":
    """
    Creates a generation config from the parameters of the request. Note that we can pass a `GenerationConfig`
    (serialized into a `dict`) in `extra_body`, for full `generate` parameterization.
@@ -125,12 +125,12 @@ def create_generation_config_from_req(req: "ChatCompletionInput") -> "GenerationConfig":
    if req.extra_body is not None and "generation_config" in req.extra_body:
        for key in req.extra_body["generation_config"].keys():
            if key in ChatCompletionInput.base_field_names.keys():
-                return {"error": "Duplicated key in the root request and in the passed generation config."}
+                raise ValueError("error: Duplicated key in the root request and in the passed generation config.")

    if req.extra_body is not None and "generation_config" in req.extra_body:
-        generation_config = GenerationConfig(**(req.extra_body["generation_config"]))
+        generation_config = GenerationConfig(**(req.extra_body["generation_config"]), **kwargs)
    else:
-        generation_config = GenerationConfig()
+        generation_config = GenerationConfig(**kwargs)

    if req.frequency_penalty is not None:
        generation_config.repetition_penalty = req.frequency_penalty
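
With both changes in place, the server can layer per-request defaults under whatever the caller passes in `extra_body["generation_config"]`, and a key collision with a root request field now raises instead of returning an error dict. A minimal sketch of the merge behavior (`build_config` is a hypothetical stand-in, not the commit's code):

```python
# Sketch only: `build_config` mirrors the merge logic above; it is not a
# function from the commit.
from typing import Optional

from transformers import GenerationConfig

def build_config(extra_body: Optional[dict], **kwargs) -> GenerationConfig:
    # Server-supplied defaults arrive as **kwargs; a caller's
    # generation_config dict (if any) is unpacked alongside them. A key
    # present in both would raise TypeError at the call site.
    if extra_body is not None and "generation_config" in extra_body:
        return GenerationConfig(**extra_body["generation_config"], **kwargs)
    return GenerationConfig(**kwargs)

cfg = build_config({"generation_config": {"temperature": 0.7}}, eos_token_id=2)
print(cfg.temperature, cfg.eos_token_id)  # 0.7 2
```
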
@@ -213,6 +213,8 @@ class ServeArguments:
class ServeCommand(BaseTransformersCLICommand):
+    loaded_model: Optional[str] = None
+    running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None

    model: PreTrainedModel
    tokenizer: PreTrainedTokenizerFast
@@ -326,46 +328,50 @@ class ServeCommand(BaseTransformersCLICommand):
        uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)

    def continuous_batching(self, app):
-        generation_config = GenerationConfig(
-            eos_token_id=self.tokenizer.eos_token_id,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=False,
-            num_blocks=1,
-            block_size=1024,
-            do_sample=False,
-            max_batch_tokens=10,
-            scheduler="fifo",
-        )
-
-        manager: ContinuousBatchingManager = self.model.init_continuous_batching(
-            generation_config=generation_config, streaming=True
-        )
-        manager.start()
-
        @app.post("/v1/chat/completions")
        def _serve(req: "ChatCompletionInput"):
            if not req.stream:
                return {"error": "Only streaming mode is supported."}

            update_model = req.model != self.loaded_model
            if update_model:
                self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)

-            chat = req.messages
-            inputs = self.tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
-                self.model.device
+            generation_config = create_generation_config_from_req(
+                req,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=False,
+                num_blocks=1,
+                block_size=1024,
+                do_sample=False,
+                max_batch_tokens=10,
+                scheduler="fifo",
            )

-            generation_config = create_generation_config_from_req(req)
+            if self.running_continuous_batching_manager is None or update_model:
+                self.running_continuous_batching_manager = self.model.init_continuous_batching(
+                    generation_config=generation_config, streaming=True
+                )
+                self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
+                self.running_continuous_batching_manager.start()
+
+            inputs = self.tokenizer.apply_chat_template(
+                req.messages, return_tensors="pt", add_generation_prompt=True
+            ).to(self.model.device)
            def stream_response(_inputs):
                try:
                    max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
-                    request_id = manager.add_request(_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens)
+                    request_id = self.running_continuous_batching_manager.add_request(
+                        _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
+                    )

                    queue_is_flushed = False

-                    for result in manager:
+                    for result in self.running_continuous_batching_manager:
                        if result.request_id != request_id:
                            continue

                        if req.request_id is not None and not queue_is_flushed:
                            if result.status == RequestStatus.FINISHED:
                                continue
@@ -373,7 +379,10 @@ class ServeCommand(BaseTransformersCLICommand):
                                queue_is_flushed = True

                        finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
-                        yield self.build_chunk(result.next_token, request_id=request_id, finish_reason=finish_reason)
+                        next_token = self.build_chunk(
+                            result.next_token, request_id=request_id, finish_reason=finish_reason
+                        )
+                        yield next_token

                        if result.status == RequestStatus.FINISHED:
                            break
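
Usage-wise, the endpoint's contract is unchanged: streaming only, with the manager now created lazily on the first request and reused until the requested model changes. A hypothetical client sketch (not part of the commit), assuming the server is listening on localhost:8000 and the model id is illustrative:

```python
import requests

payload = {
    "model": "Qwen/Qwen2.5-0.5B-Instruct",  # illustrative chat model id
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,  # _serve rejects non-streaming requests
    "max_tokens": 64,
    # Extra GenerationConfig fields ride along in extra_body; duplicating a
    # root request field here now raises a ValueError server-side.
    "extra_body": {"generation_config": {"temperature": 0.7}},
}

with requests.post(
    "http://localhost:8000/v1/chat/completions", json=payload, stream=True
) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode("utf-8"))  # one chunk per build_chunk call
```
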

View File

@@ -1257,6 +1257,7 @@ class ContinuousBatchingManager:
    @traced(span_name="logit_processing")
    def _process_logit(self, batch_data, logits):
+        print(self.logit_processor)
        return self.logit_processor(batch_data["input_ids"], logits)

    @traced(span_name="sampling")

View File

@@ -88,9 +88,9 @@ class LogitsProcessorList(list):
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
-                scores = processor(input_ids, scores, **kwargs)
+                scores = processor(input_ids, scores.to(input_ids.dtype), **kwargs)
            else:
-                scores = processor(input_ids, scores)
+                scores = processor(input_ids, scores.to(input_ids.dtype))
        return scores
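
The behavioral change in this file is the cast itself: every processor now receives `scores` converted to the dtype of `input_ids` before it runs. A toy illustration of what that conversion does (tensors are illustrative, independent of the commit):

```python
import torch

input_ids = torch.tensor([[1, 2, 3]])  # token ids, typically torch.int64
scores = torch.randn(1, 8)             # logits, typically torch.float32
print(scores.to(input_ids.dtype).dtype)  # torch.int64 after the cast above
```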