mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Prevent loading same model twice
This commit is contained in:
parent
244d34de86
commit
153948f979
@ -347,7 +347,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
if not req.stream:
|
||||
return {"error": "Only streaming mode is supported."}
|
||||
|
||||
update_model = req.model != self.loaded_model
|
||||
update_model = self.canonicalized_model_name(req.model) != self.loaded_model
|
||||
|
||||
if update_model:
|
||||
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
||||
@ -417,7 +417,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
def generate(self, app):
|
||||
@app.post("/v1/chat/completions")
|
||||
def _serve(req: "ChatCompletionInput"):
|
||||
update_model = req.model != self.loaded_model
|
||||
update_model = self.canonicalized_model_name(req.model) != self.loaded_model
|
||||
|
||||
if update_model:
|
||||
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
||||
@ -585,6 +585,11 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
|
||||
return quantization_config
|
||||
|
||||
def canonicalized_model_name(self, model_id: str) -> str:
    """Return *model_id* in ``name@revision`` form.

    An identifier that already contains ``@`` is returned untouched;
    otherwise ``@main`` is appended. The result is what callers compare
    against ``self.loaded_model`` to decide whether a reload is needed.

    Args:
        model_id: Model identifier, optionally suffixed with ``@<revision>``.

    Returns:
        The identifier with a revision suffix.
    """
    has_revision = "@" in model_id
    return model_id if has_revision else f"{model_id}@main"
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
self, model_id_and_revision: str, args: ServeArguments
|
||||
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
|
||||
@ -621,9 +626,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
if getattr(model, "hf_device_map", None) is None:
|
||||
model = model.to(args.device)
|
||||
|
||||
self.loaded_model = model_id_and_revision
|
||||
self.loaded_model = f"{model_id}@{revision}"
|
||||
|
||||
print("Loaded model", model_id_and_revision)
|
||||
print("Loaded model", self.loaded_model)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user