From 25cd65ac43ee1a96cef4692bda0b110d1e3c6903 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Wed, 2 Jul 2025 22:09:58 +0200
Subject: [PATCH] Random serve fixes (#39176)

* Fix index out of bounds exception on wrong kv reuse

* Prevent loading same model twice

---------

Co-authored-by: Joao Gante
Co-authored-by: Lysandre Debut
---
 src/transformers/commands/serving.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index f8b4131a463..d8f61603692 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -347,7 +347,7 @@ class ServeCommand(BaseTransformersCLICommand):
         if not req.stream:
             return {"error": "Only streaming mode is supported."}
 
-        update_model = req.model != self.loaded_model
+        update_model = self.canonicalized_model_name(req.model) != self.loaded_model
         if update_model:
             self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
 
@@ -402,7 +402,7 @@ class ServeCommand(BaseTransformersCLICommand):
         if self.last_messages is None:
             req_continues_last_messages = False
         # The new request has fewer rounds of conversation: this is a new request
-        elif len(self.last_messages) > len(req.messages):
+        elif len(self.last_messages) >= len(req.messages):
             req_continues_last_messages = False
         # Otherwise, check that the last messages are a subset of the new request
         else:
@@ -417,7 +417,7 @@ class ServeCommand(BaseTransformersCLICommand):
     def generate(self, app):
         @app.post("/v1/chat/completions")
         def _serve(req: "ChatCompletionInput"):
-            update_model = req.model != self.loaded_model
+            update_model = self.canonicalized_model_name(req.model) != self.loaded_model
             if update_model:
                 self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
 
@@ -585,6 +585,11 @@ class ServeCommand(BaseTransformersCLICommand):
 
         return quantization_config
 
+    def canonicalized_model_name(self, model_id: str) -> str:
+        if "@" in model_id:
+            return model_id
+        return f"{model_id}@main"
+
     def load_model_and_tokenizer(
         self, model_id_and_revision: str, args: ServeArguments
     ) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
@@ -621,9 +626,9 @@ class ServeCommand(BaseTransformersCLICommand):
         if getattr(model, "hf_device_map", None) is None:
             model = model.to(args.device)
 
-        self.loaded_model = model_id_and_revision
+        self.loaded_model = f"{model_id}@{revision}"
 
-        logger.warning(f"Loaded model {model_id_and_revision}")
+        logger.warning(f"Loaded model {self.loaded_model}")
 
         return model, tokenizer
 
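
Note (not part of the patch): a minimal standalone sketch of the two fixes above. The function body is copied from the diff; `loaded_model` and the sample requests are illustrative values, not values taken from the repository.

    # Mirrors ServeCommand.canonicalized_model_name: a model id without an
    # explicit "@revision" suffix is treated as pinned to "main", so "gpt2"
    # and "gpt2@main" compare equal and the same model is not loaded twice.
    def canonicalized_model_name(model_id: str) -> str:
        if "@" in model_id:
            return model_id
        return f"{model_id}@main"

    loaded_model = "gpt2@main"  # as now stored by load_model_and_tokenizer
    for requested in ("gpt2", "gpt2@main"):
        # Before the fix, "gpt2" != "gpt2@main" triggered a spurious reload.
        assert canonicalized_model_name(requested) == loaded_model

    # Mirrors the ">" -> ">=" change: a request with the *same* number of
    # message rounds as the previous one is a fresh request, not a
    # continuation. Before the fix, the equal-length case fell through to
    # the prefix check, so resubmitting an identical conversation reused the
    # KV cache with no new tokens and indexed past the end of the cache.
    last_messages = [{"role": "user", "content": "hi"}]
    req_messages = [{"role": "user", "content": "hi"}]  # resubmission
    req_continues_last_messages = len(last_messages) < len(req_messages)
    assert req_continues_last_messages is False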