Scaffolding

This commit is contained in:
Lysandre 2025-07-01 18:23:16 +02:00
parent feb1eb7531
commit 4eb99fed12

View File

@@ -19,7 +19,7 @@ import time
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
-from typing import Any, Optional, Generator, Literal
+from typing import Any, Generator, Optional
from huggingface_hub import (
ChatCompletionStreamOutputChoice,
@@ -27,7 +27,7 @@ from huggingface_hub import (
ChatCompletionStreamOutputDeltaToolCall,
ChatCompletionStreamOutputFunction,
ModelInfo,
-model_info, ChatCompletionStreamOutput,
+model_info,
)
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@@ -340,8 +340,6 @@ class ServeCommand(BaseTransformersCLICommand):
}
return f"data: {json.dumps(payload)}\n\n"
def run(self):
app = FastAPI()
@@ -426,9 +424,9 @@ class ServeCommand(BaseTransformersCLICommand):
self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
self.running_continuous_batching_manager.start()
-inputs = self.tokenizer.apply_chat_template(
-req.messages, return_tensors="pt", add_generation_prompt=True
-).to(self.model.device)
+inputs = self.tokenizer.apply_chat_template(req.messages, return_tensors="pt", add_generation_prompt=True).to(
+self.model.device
+)
def stream_response(_inputs):
try:
@@ -687,7 +685,6 @@ class ServeCommand(BaseTransformersCLICommand):
self.last_messages = req.messages
return req_continues_last_messages
@staticmethod
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
if model_args.load_in_4bit: