mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-14 01:58:22 +06:00
Scaffolding
This commit is contained in:
parent
feb1eb7531
commit
4eb99fed12
@ -19,7 +19,7 @@ import time
|
|||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Optional, Generator, Literal
|
from typing import Any, Generator, Optional
|
||||||
|
|
||||||
from huggingface_hub import (
|
from huggingface_hub import (
|
||||||
ChatCompletionStreamOutputChoice,
|
ChatCompletionStreamOutputChoice,
|
||||||
@ -27,7 +27,7 @@ from huggingface_hub import (
|
|||||||
ChatCompletionStreamOutputDeltaToolCall,
|
ChatCompletionStreamOutputDeltaToolCall,
|
||||||
ChatCompletionStreamOutputFunction,
|
ChatCompletionStreamOutputFunction,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
model_info, ChatCompletionStreamOutput,
|
model_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
|
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
|
||||||
@ -340,8 +340,6 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
}
|
}
|
||||||
return f"data: {json.dumps(payload)}\n\n"
|
return f"data: {json.dumps(payload)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@ -426,9 +424,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
|
self.running_continuous_batching_manager.logit_processor = LogitsProcessorList()
|
||||||
self.running_continuous_batching_manager.start()
|
self.running_continuous_batching_manager.start()
|
||||||
|
|
||||||
inputs = self.tokenizer.apply_chat_template(
|
inputs = self.tokenizer.apply_chat_template(req.messages, return_tensors="pt", add_generation_prompt=True).to(
|
||||||
req.messages, return_tensors="pt", add_generation_prompt=True
|
self.model.device
|
||||||
).to(self.model.device)
|
)
|
||||||
|
|
||||||
def stream_response(_inputs):
|
def stream_response(_inputs):
|
||||||
try:
|
try:
|
||||||
@ -687,7 +685,6 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self.last_messages = req.messages
|
self.last_messages = req.messages
|
||||||
return req_continues_last_messages
|
return req_continues_last_messages
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
|
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
|
||||||
if model_args.load_in_4bit:
|
if model_args.load_in_4bit:
|
||||||
|
Loading…
Reference in New Issue
Block a user