Random serve fixes (#39176)

* Fix index out of bounds exception on wrong kv reuse

* Prevent loading same model twice
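
The gist of the second fix, as a minimal standalone sketch (the model id below is a made-up example, not from this commit): a request that omits a revision is canonicalized to "@main" before being compared against the loaded model, so "org/model" and "org/model@main" no longer trigger a second load of the same checkpoint.

def canonicalized_model_name(model_id: str) -> str:
    # A model id without an explicit revision is treated as pinned to "main".
    if "@" in model_id:
        return model_id
    return f"{model_id}@main"

loaded_model = "meta-llama/Llama-3.2-1B@main"  # hypothetical already-loaded model

for requested in ("meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-1B@main"):
    update_model = canonicalized_model_name(requested) != loaded_model
    print(requested, "-> reload" if update_model else "-> reuse")
# Both requests now reuse the loaded model instead of loading it a second time.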

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
Pedro Cuenca authored on 2025-07-02 22:09:58 +02:00, committed by GitHub
commit 25cd65ac43 (parent 548794b886)

@@ -347,7 +347,7 @@ class ServeCommand(BaseTransformersCLICommand):
             if not req.stream:
                 return {"error": "Only streaming mode is supported."}

-            update_model = req.model != self.loaded_model
+            update_model = self.canonicalized_model_name(req.model) != self.loaded_model
             if update_model:
                 self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
@@ -402,7 +402,7 @@ class ServeCommand(BaseTransformersCLICommand):
             if self.last_messages is None:
                 req_continues_last_messages = False
             # The new request has fewer rounds of conversation: this is a new request
-            elif len(self.last_messages) > len(req.messages):
+            elif len(self.last_messages) >= len(req.messages):
                 req_continues_last_messages = False
             # Otherwise, check that the last messages are a subset of the new request
             else:
@@ -417,7 +417,7 @@ class ServeCommand(BaseTransformersCLICommand):
     def generate(self, app):
         @app.post("/v1/chat/completions")
         def _serve(req: "ChatCompletionInput"):
-            update_model = req.model != self.loaded_model
+            update_model = self.canonicalized_model_name(req.model) != self.loaded_model
             if update_model:
                 self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
@@ -585,6 +585,11 @@ class ServeCommand(BaseTransformersCLICommand):
         return quantization_config

+    def canonicalized_model_name(self, model_id: str) -> str:
+        if "@" in model_id:
+            return model_id
+        return f"{model_id}@main"
+
     def load_model_and_tokenizer(
         self, model_id_and_revision: str, args: ServeArguments
     ) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
@@ -621,9 +626,9 @@ class ServeCommand(BaseTransformersCLICommand):
         if getattr(model, "hf_device_map", None) is None:
             model = model.to(args.device)

-        self.loaded_model = model_id_and_revision
+        self.loaded_model = f"{model_id}@{revision}"

-        logger.warning(f"Loaded model {model_id_and_revision}")
+        logger.warning(f"Loaded model {self.loaded_model}")
         return model, tokenizer
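
For the first fix (the index-out-of-bounds on wrong KV reuse), the relevant change above is ">" becoming ">=": a request with the same number of messages as the cached conversation is now treated as a new request rather than a continuation. A rough sketch of the failure mode, assuming (this part is an assumption, not shown in the hunk) that a continuation slices off the already-processed messages and then indexes into the remainder:

def request_continues_last_messages(last_messages, new_messages) -> bool:
    # Mirrors the fixed check: equal-length conversations are NOT continuations.
    if last_messages is None:
        return False
    if len(last_messages) >= len(new_messages):
        return False
    # Otherwise the cached conversation must be a prefix of the new one.
    return new_messages[: len(last_messages)] == last_messages

last = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
req = list(last)  # client resends the exact same conversation

if request_continues_last_messages(last, req):
    new_part = req[len(last):]      # empty when the lengths are equal
    print(new_part[-1])             # with the old ">" check this could raise IndexError
else:
    print("Treated as a new request; the cached KV state is not reused.")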