Scaffolding

This commit is contained in:
Lysandre 2025-07-01 18:23:16 +02:00
parent feb1eb7531
commit 4eb99fed12

View File

@ -19,7 +19,7 @@ import time
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from threading import Thread from threading import Thread
from typing import Any, Optional, Generator, Literal from typing import Any, Generator, Optional
from huggingface_hub import ( from huggingface_hub import (
ChatCompletionStreamOutputChoice, ChatCompletionStreamOutputChoice,
@ -27,7 +27,7 @@ from huggingface_hub import (
ChatCompletionStreamOutputDeltaToolCall, ChatCompletionStreamOutputDeltaToolCall,
ChatCompletionStreamOutputFunction, ChatCompletionStreamOutputFunction,
ModelInfo, ModelInfo,
model_info, ChatCompletionStreamOutput, model_info,
) )
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
@ -340,8 +340,6 @@ class ServeCommand(BaseTransformersCLICommand):
} }
return f"data: {json.dumps(payload)}\n\n" return f"data: {json.dumps(payload)}\n\n"
def run(self): def run(self):
app = FastAPI() app = FastAPI()
@ -426,9 +424,9 @@ class ServeCommand(BaseTransformersCLICommand):
self.running_continuous_batching_manager.logit_processor = LogitsProcessorList() self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
self.running_continuous_batching_manager.start() self.running_continuous_batching_manager.start()
inputs = self.tokenizer.apply_chat_template( inputs = self.tokenizer.apply_chat_template(req.messages, return_tensors="pt", add_generation_prompt=True).to(
req.messages, return_tensors="pt", add_generation_prompt=True self.model.device
).to(self.model.device) )
def stream_response(_inputs): def stream_response(_inputs):
try: try:
@ -687,7 +685,6 @@ class ServeCommand(BaseTransformersCLICommand):
self.last_messages = req.messages self.last_messages = req.messages
return req_continues_last_messages return req_continues_last_messages
@staticmethod @staticmethod
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]: def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
if model_args.load_in_4bit: if model_args.load_in_4bit: