diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 2d0ec77f4b4..49f969306f9 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -31,7 +31,7 @@ from huggingface_hub import (
 
 from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
 
-from .. import PreTrainedTokenizerFast, TextIteratorStreamer
+from .. import LogitsProcessorList, PreTrainedTokenizerFast, TextIteratorStreamer
 from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
 from ..utils import is_torch_available, logging
 from . import BaseTransformersCLICommand
@@ -111,7 +111,7 @@ def serve_command_factory(args: Namespace):
     return ServeCommand(args)
 
 
-def create_generation_config_from_req(req: "ChatCompletionInput") -> "GenerationConfig":
+def create_generation_config_from_req(req: "ChatCompletionInput", **kwargs) -> "GenerationConfig":
     """
     Creates a generation config from the parameters of the request. Note that we can pass a `GenerationConfig`
     (serialized into a `dict`) in `extra_body`, for full `generate` parameterization.
@@ -125,12 +125,12 @@ def create_generation_config_from_req(req: "ChatCompletionInput") -> "Generation
     if req.extra_body is not None and "generation_config" in req.extra_body:
         for key in req.extra_body["generation_config"].keys():
             if key in ChatCompletionInput.base_field_names.keys():
-                return {"error": "Duplicated key in the root request and in the passed generation config."}
+                raise ValueError("Duplicated key in the root request and in the passed generation config.")
 
     if req.extra_body is not None and "generation_config" in req.extra_body:
-        generation_config = GenerationConfig(**(req.extra_body["generation_config"]))
+        generation_config = GenerationConfig(**(req.extra_body["generation_config"]), **kwargs)
     else:
-        generation_config = GenerationConfig()
+        generation_config = GenerationConfig(**kwargs)
 
     if req.frequency_penalty is not None:
         generation_config.repetition_penalty = req.frequency_penalty
@@ -213,6 +213,8 @@ class ServeArguments:
 
 class ServeCommand(BaseTransformersCLICommand):
     loaded_model: Optional[str] = None
 
+    running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None
+
     model: PreTrainedModel
     tokenizer: PreTrainedTokenizerFast
@@ -326,46 +328,50 @@ class ServeCommand(BaseTransformersCLICommand):
         uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
 
     def continuous_batching(self, app):
-        generation_config = GenerationConfig(
-            eos_token_id=self.tokenizer.eos_token_id,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=False,
-            num_blocks=1,
-            block_size=1024,
-            do_sample=False,
-            max_batch_tokens=10,
-            scheduler="fifo",
-        )
-
-        manager: ContinuousBatchingManager = self.model.init_continuous_batching(
-            generation_config=generation_config, streaming=True
-        )
-        manager.start()
-
         @app.post("/v1/chat/completions")
         def _serve(req: "ChatCompletionInput"):
             if not req.stream:
                 return {"error": "Only streaming mode is supported."}
 
             update_model = req.model != self.loaded_model
-
             if update_model:
                 self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
 
-            chat = req.messages
-            inputs = self.tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
-                self.model.device
+            generation_config = create_generation_config_from_req(
+                req,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=False,
+                num_blocks=1,
+                block_size=1024,
+                do_sample=False,
+                max_batch_tokens=10,
+                scheduler="fifo",
             )
 
-            generation_config = create_generation_config_from_req(req)
+            if self.running_continuous_batching_manager is None or update_model:
+                self.running_continuous_batching_manager = self.model.init_continuous_batching(
+                    generation_config=generation_config, streaming=True
+                )
+                self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
+                self.running_continuous_batching_manager.start()
+
+            inputs = self.tokenizer.apply_chat_template(
+                req.messages, return_tensors="pt", add_generation_prompt=True
+            ).to(self.model.device)
 
             def stream_response(_inputs):
                 try:
                     max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
-                    request_id = manager.add_request(_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens)
+                    request_id = self.running_continuous_batching_manager.add_request(
+                        _inputs, request_id=req.request_id, max_new_tokens=max_new_tokens
+                    )
                     queue_is_flushed = False
 
-                    for result in manager:
+                    for result in self.running_continuous_batching_manager:
+                        if result.request_id != request_id:
+                            continue
+
                         if req.request_id is not None and not queue_is_flushed:
                             if result.status == RequestStatus.FINISHED:
                                 continue
@@ -373,7 +379,10 @@ class ServeCommand(BaseTransformersCLICommand):
                             queue_is_flushed = True
 
                         finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
-                        yield self.build_chunk(result.next_token, request_id=request_id, finish_reason=finish_reason)
+                        next_token = self.build_chunk(
+                            result.next_token, request_id=request_id, finish_reason=finish_reason
+                        )
+                        yield next_token
                         if result.status == RequestStatus.FINISHED:
                             break
 
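With the serving changes above, the /v1/chat/completions route builds its GenerationConfig through create_generation_config_from_req, so both the request-level parameters and a serialized generation_config passed in extra_body reach the continuous batching manager. Below is a minimal client-side sketch, assuming a local `transformers serve` instance on port 8000 and the `openai` Python client; the model id and the sampling keys are placeholders, and note that keys colliding with root request fields (e.g. temperature) now raise a ValueError, while keys the server already pins as kwargs (e.g. do_sample) would be rejected by GenerationConfig:

from openai import OpenAI

# Assumed local endpoint; the api_key is unused by the transformers server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,  # the endpoint only supports streaming mode
    # Extra generate() parameters travel in extra_body["generation_config"].
    extra_body={"generation_config": {"top_k": 50, "repetition_penalty": 1.1}},
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")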
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index d4c08e270bb..05585349ff9 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -88,9 +88,9 @@ class LogitsProcessorList(list):
                     f"Make sure that all the required parameters: {list(function_args.keys())} for "
                     f"{processor.__class__} are passed to the logits processor."
                 )
-                scores = processor(input_ids, scores, **kwargs)
+                scores = processor(input_ids, scores.to(input_ids.device), **kwargs)
             else:
-                scores = processor(input_ids, scores)
+                scores = processor(input_ids, scores.to(input_ids.device))
         return scores
 
 
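For reference, a standalone sketch of the LogitsProcessorList call path touched above, assuming CPU tensors and a stock TemperatureLogitsWarper (the continuous batching manager drives the same __call__ from _process_logit). Moving scores to input_ids.device co-locates the logits with the token ids without altering their floating-point dtype:

import torch

from transformers import LogitsProcessorList, TemperatureLogitsWarper

# A single warper is enough to exercise the list's dispatch logic.
processors = LogitsProcessorList([TemperatureLogitsWarper(0.7)])

input_ids = torch.tensor([[1, 2, 3]])  # int64 token ids, as in batch_data["input_ids"]
scores = torch.randn(1, 32000)         # float logits for one decoding step

warped = processors(input_ids, scores)
print(warped.dtype, warped.shape)  # torch.float32 torch.Size([1, 32000])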