mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Random serve fixes (#39176)
* Fix index out of bounds exception on wrong kv reuse
* Prevent loading same model twice

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
parent
548794b886
commit
25cd65ac43
@ -347,7 +347,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
if not req.stream:
|
if not req.stream:
|
||||||
return {"error": "Only streaming mode is supported."}
|
return {"error": "Only streaming mode is supported."}
|
||||||
|
|
||||||
update_model = req.model != self.loaded_model
|
update_model = self.canonicalized_model_name(req.model) != self.loaded_model
|
||||||
|
|
||||||
if update_model:
|
if update_model:
|
||||||
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
||||||
@ -402,7 +402,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
if self.last_messages is None:
|
if self.last_messages is None:
|
||||||
req_continues_last_messages = False
|
req_continues_last_messages = False
|
||||||
# The new request has fewer rounds of conversation: this is a new request
|
# The new request has fewer rounds of conversation: this is a new request
|
||||||
elif len(self.last_messages) > len(req.messages):
|
elif len(self.last_messages) >= len(req.messages):
|
||||||
req_continues_last_messages = False
|
req_continues_last_messages = False
|
||||||
# Otherwise, check that the last messages are a subset of the new request
|
# Otherwise, check that the last messages are a subset of the new request
|
||||||
else:
|
else:
|
||||||
@ -417,7 +417,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
def generate(self, app):
|
def generate(self, app):
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
def _serve(req: "ChatCompletionInput"):
|
def _serve(req: "ChatCompletionInput"):
|
||||||
update_model = req.model != self.loaded_model
|
update_model = self.canonicalized_model_name(req.model) != self.loaded_model
|
||||||
|
|
||||||
if update_model:
|
if update_model:
|
||||||
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
||||||
@ -585,6 +585,11 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
return quantization_config
|
return quantization_config
|
||||||
|
|
||||||
|
def canonicalized_model_name(self, model_id: str) -> str:
|
||||||
|
if "@" in model_id:
|
||||||
|
return model_id
|
||||||
|
return f"{model_id}@main"
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
self, model_id_and_revision: str, args: ServeArguments
|
self, model_id_and_revision: str, args: ServeArguments
|
||||||
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
|
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
|
||||||
@ -621,9 +626,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
if getattr(model, "hf_device_map", None) is None:
|
if getattr(model, "hf_device_map", None) is None:
|
||||||
model = model.to(args.device)
|
model = model.to(args.device)
|
||||||
|
|
||||||
self.loaded_model = model_id_and_revision
|
self.loaded_model = f"{model_id}@{revision}"
|
||||||
|
|
||||||
logger.warning(f"Loaded model {model_id_and_revision}")
|
logger.warning(f"Loaded model {self.loaded_model}")
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user