Mirror of https://github.com/huggingface/transformers.git
Fix CB

commit 733bcb4fed
parent e8f90b5397
@@ -31,7 +31,7 @@ from huggingface_hub import (
 from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
 
-from .. import PreTrainedTokenizerFast, TextIteratorStreamer
+from .. import LogitsProcessorList, PreTrainedTokenizerFast, TextIteratorStreamer
 from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
 from ..utils import is_torch_available, logging
 from . import BaseTransformersCLICommand
@@ -111,7 +111,7 @@ def serve_command_factory(args: Namespace):
     return ServeCommand(args)
 
 
-def create_generation_config_from_req(req: "ChatCompletionInput") -> "GenerationConfig":
+def create_generation_config_from_req(req: "ChatCompletionInput", **kwargs) -> "GenerationConfig":
     """
     Creates a generation config from the parameters of the request. Note that we can pass a `GenerationConfig`
     (serialized into a `dict`) in `extra_body`, for full `generate` parameterization.
@@ -125,12 +125,12 @@ def create_generation_config_from_req(req: "ChatCompletionInput") -> "GenerationConfig":
     if req.extra_body is not None and "generation_config" in req.extra_body:
         for key in req.extra_body["generation_config"].keys():
             if key in ChatCompletionInput.base_field_names.keys():
-                return {"error": "Duplicated key in the root request and in the passed generation config."}
+                raise ValueError("error: Duplicated key in the root request and in the passed generation config.")
 
     if req.extra_body is not None and "generation_config" in req.extra_body:
-        generation_config = GenerationConfig(**(req.extra_body["generation_config"]))
+        generation_config = GenerationConfig(**(req.extra_body["generation_config"]), **kwargs)
     else:
-        generation_config = GenerationConfig()
+        generation_config = GenerationConfig(**kwargs)
 
     if req.frequency_penalty is not None:
         generation_config.repetition_penalty = req.frequency_penalty
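Note (illustration, not part of the commit): with the new **kwargs, server-side settings are merged into whichever GenerationConfig gets built, and a generation_config dict passed in extra_body must not repeat a root-level request field (e.g. temperature), since that now raises a ValueError instead of returning an error dict. A minimal sketch of a client payload that exercises this path; the model name and the chosen keys are placeholders:

    # Hypothetical request body. "generation_config" inside extra_body is deserialized into a
    # GenerationConfig on the server and merged with the server's own kwargs; the keys used here
    # are deliberately ones that have no root-level equivalent in ChatCompletionInput.
    payload = {
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "stream": True,
        "messages": [{"role": "user", "content": "Hello!"}],
        "extra_body": {
            "generation_config": {
                "max_new_tokens": 128,
                "repetition_penalty": 1.1,
            }
        },
    }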
@@ -213,6 +213,8 @@ class ServeArguments:
 
 
 class ServeCommand(BaseTransformersCLICommand):
+    loaded_model: Optional[str] = None
+    running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None
 
     model: PreTrainedModel
     tokenizer: PreTrainedTokenizerFast
@@ -326,46 +328,50 @@ class ServeCommand(BaseTransformersCLICommand):
         uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
 
     def continuous_batching(self, app):
-        generation_config = GenerationConfig(
-            eos_token_id=self.tokenizer.eos_token_id,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=False,
-            num_blocks=1,
-            block_size=1024,
-            do_sample=False,
-            max_batch_tokens=10,
-            scheduler="fifo",
-        )
-
-        manager: ContinuousBatchingManager = self.model.init_continuous_batching(
-            generation_config=generation_config, streaming=True
-        )
-        manager.start()
-
         @app.post("/v1/chat/completions")
         def _serve(req: "ChatCompletionInput"):
             if not req.stream:
                 return {"error": "Only streaming mode is supported."}
 
             update_model = req.model != self.loaded_model
 
             if update_model:
                 self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
 
-            chat = req.messages
-            inputs = self.tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
-                self.model.device
-            )
-
-            generation_config = create_generation_config_from_req(req)
+            generation_config = create_generation_config_from_req(
+                req,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=False,
+                num_blocks=1,
+                block_size=1024,
+                do_sample=False,
+                max_batch_tokens=10,
+                scheduler="fifo",
+            )
+
+            if self.running_continuous_batching_manager is None or update_model:
+                self.running_continuous_batching_manager = self.model.init_continuous_batching(
+                    generation_config=generation_config, streaming=True
+                )
+                self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
+                self.running_continuous_batching_manager.start()
+
+            inputs = self.tokenizer.apply_chat_template(
+                req.messages, return_tensors="pt", add_generation_prompt=True
+            ).to(self.model.device)
 
             def stream_response(_inputs):
                 try:
                     max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
-                    request_id = manager.add_request(_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens)
+                    request_id = self.running_continuous_batching_manager.add_request(
+                        _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
+                    )
                     queue_is_flushed = False
 
-                    for result in manager:
+                    for result in self.running_continuous_batching_manager:
                         if result.request_id != request_id:
                             continue
 
                         if req.request_id is not None and not queue_is_flushed:
                             if result.status == RequestStatus.FINISHED:
                                 continue
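Aside (a condensed sketch, not part of the commit, using only the names that appear in this hunk): the handler now relies on a single long-lived manager that is created once and reused until the model changes. Its lifecycle is roughly the following; model and tokenizer setup, request_id handling and error paths are omitted.

    # Illustrative lifecycle of the continuous-batching manager as used above.
    manager = model.init_continuous_batching(generation_config=generation_config, streaming=True)
    manager.logit_processor = LogitsProcessorList()   # empty list until processors are registered
    manager.start()

    request_id = manager.add_request(inputs, request_id=None, max_new_tokens=256)
    for result in manager:
        if result.request_id != request_id:
            continue                                  # results from other requests are interleaved
        chunk = result.next_token                     # forwarded to build_chunk in the real handler
        if result.status == RequestStatus.FINISHED:
            break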
@@ -373,7 +379,10 @@ class ServeCommand(BaseTransformersCLICommand):
                                 queue_is_flushed = True
 
                         finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
-                        yield self.build_chunk(result.next_token, request_id=request_id, finish_reason=finish_reason)
+                        next_token = self.build_chunk(
+                            result.next_token, request_id=request_id, finish_reason=finish_reason
+                        )
+                        yield next_token
 
                         if result.status == RequestStatus.FINISHED:
                             break
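For reference (not part of the commit), a minimal client for the streaming-only endpoint defined above might look like the sketch below; the host, port and model name are placeholders, since the real values come from ServeArguments and the request.

    # Hypothetical streaming client. The endpoint path and the stream-only requirement come from
    # the handler above; the URL is an assumption for a locally running server.
    import requests

    response = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "Qwen/Qwen2.5-0.5B-Instruct",
            "stream": True,
            "messages": [{"role": "user", "content": "Say hi."}],
        },
        stream=True,
    )
    for line in response.iter_lines():
        if line:
            print(line.decode())   # each line is one chunk produced by build_chunk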
@@ -1257,6 +1257,7 @@ class ContinuousBatchingManager:
 
     @traced(span_name="logit_processing")
     def _process_logit(self, batch_data, logits):
+        print(self.logit_processor)
         return self.logit_processor(batch_data["input_ids"], logits)
 
     @traced(span_name="sampling")
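Side note (illustrative only): the added print is a debug aid; the call on the following line relies on logit_processor being a LogitsProcessorList, which the serve command now assigns explicitly. An empty list is a pass-through, so the call is safe even when no processors are registered:

    # A LogitsProcessorList with no processors returns the scores unchanged.
    import torch
    from transformers import LogitsProcessorList

    processors = LogitsProcessorList()
    input_ids = torch.tensor([[1, 2, 3]])
    logits = torch.randn(1, 32000)
    assert torch.equal(processors(input_ids, logits), logits)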
@@ -88,9 +88,9 @@ class LogitsProcessorList(list):
                         f"Make sure that all the required parameters: {list(function_args.keys())} for "
                         f"{processor.__class__} are passed to the logits processor."
                     )
-                scores = processor(input_ids, scores, **kwargs)
+                scores = processor(input_ids, scores.to(input_ids.dtype), **kwargs)
             else:
-                scores = processor(input_ids, scores)
+                scores = processor(input_ids, scores.to(input_ids.dtype))
 
         return scores