Split transformers chat and transformers serve (#38443)

* Next token

* Split chat and serve

* Support both generation methods

* Style

* Generation Config

* temp

* temp

* Finalize serving.py

Co-authored-by: célina <hanouticelina@gmail.com>

* Finalize chat.py

* Update src/transformers/commands/serving.py

Co-authored-by: célina <hanouticelina@gmail.com>

* Lucain's comments

Co-authored-by: Lucain <lucain@huggingface.co>

* Update

* Last comments on PR

* Better error handling

* Better error handling

* CI errors

* CI errors

* Add tests

* Fix tests

* Fix tests

* [chat] Split chat/serve (built on top of lysandre's PR) (#39031)

* Next token

* Split chat and serve

* Support both generation methods

* Style

* Generation Config

* temp

* temp

* Finalize serving.py

Co-authored-by: célina <hanouticelina@gmail.com>

* Finalize chat.py

* Update src/transformers/commands/serving.py

Co-authored-by: célina <hanouticelina@gmail.com>

* Lucain's comments

Co-authored-by: Lucain <lucain@huggingface.co>

* Update

* Last comments on PR

* Better error handling

* Better error handling

* CI errors

* CI errors

* Add tests

* Fix tests

* Fix tests

* streaming tool call

* abstract tool state; set tool start as eos

* todos

* server working on models without tools

* rm chat's deprecated flags

* chat defaults

* kv cache persists across calls

* add server docs

* link

* Update src/transformers/commands/serving.py

* Apply suggestions from code review

* i love merge conflicts

* solve multi turn with tiny-agents

* On the fly switching of the models

* Remove required positional arg

---------

Co-authored-by: Lysandre <hi@lysand.re>
Co-authored-by: célina <hanouticelina@gmail.com>
Co-authored-by: Lucain <lucain@huggingface.co>

* Protect names

* Fix tests

---------

Co-authored-by: célina <hanouticelina@gmail.com>
Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Lysandre Debut 2025-06-30 15:10:53 +02:00 committed by GitHub
parent 539c6c2fa8
commit e8f90b5397
9 changed files with 924 additions and 319 deletions


@ -27,6 +27,9 @@ This guide shows you how to quickly start chatting with Transformers from the co
## transformers CLI ## transformers CLI
### Interactive chat session
After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session. After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.
```bash ```bash
@ -51,6 +54,68 @@ transformers chat -h
The chat is implemented on top of the [AutoClass](./model_doc/auto), using tooling from [text generation](./llm_tutorial) and [chat](./chat_templating). The chat is implemented on top of the [AutoClass](./model_doc/auto), using tooling from [text generation](./llm_tutorial) and [chat](./chat_templating).
### Serving a model and using MCP tools
> [!WARNING]
> This section is experimental and subject to change in future versions.
Powering the `chat` interface, we have a server that takes user messages and returns completions. The server exposes a chat completion API compatible with the OpenAI SDK, so you can also quickly experiment with `transformers` models in existing applications. To launch a server separately, use the `transformers serve` CLI:
```bash
transformers serve Menlo/Jan-nano
```
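Because the server implements the OpenAI chat completions protocol, any OpenAI-compatible client can talk to it. Below is a minimal sketch using the `openai` Python SDK, assuming you install the SDK separately and that the server above is running on the default `localhost:8000`. Keep in mind that the server currently only supports streaming responses:

```python
from openai import OpenAI

# The local server ignores authentication, but the SDK requires an API key string.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

stream = client.chat.completions.create(
    model="Menlo/Jan-nano",  # model to serve; the server loads it on the fly if needed
    messages=[{"role": "user", "content": "Write a haiku about transformers."}],
    stream=True,  # the server only supports streaming mode
)
for chunk in stream:
    # Each chunk carries an incremental piece of the assistant message.
    print(chunk.choices[0].delta.content or "", end="", flush=True)
```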
Under the hood, the `chat` CLI launches and uses `transformers serve`. This server is also an MCP client, which can receive information from available MCP servers (i.e. tools), massage their information into the model prompt, and prepare calls to these tools when the model requests them. Naturally, this requires a model that is trained to use tools.
At the moment, MCP tool usage in `transformers` has the following constraints:
- `chat` can't handle tools, but the [`tiny-agents`](https://huggingface.co/blog/python-tiny-agents) CLI can;
- Only the `qwen` family of models is supported.
The first step to use MCP tools is to let the model know which tools are available. As an example, let's consider a `tiny-agents` configuration file with a reference to an [image generation MCP server](https://evalstate-flux1-schnell.hf.space/).
> [!TIP]
> Many Hugging Face Spaces can be used as MCP servers. You can find all compatible Spaces [here](https://huggingface.co/spaces?filter=mcp-server).
```json
{
"model": "http://localhost:8000",
"provider": "local",
"servers": [
{
"type": "sse",
"config": {
"url": "https://evalstate-flux1-schnell.hf.space/gradio_api/mcp/sse"
}
}
]
}
```
You can then launch your `tiny-agents` chat interface with the following command.
```bash
tiny-agents run path/to/your/config.json
```
If you have a server (from `transformers serve`) running in the background, you're ready to use MCP tools from a local model! For instance, here's an example chat session:
```bash
Agent loaded with 1 tools:
• flux1_schnell_infer
» Generate an image of a cat on the moon
<Tool req_0_tool_call>flux1_schnell_infer {"prompt": "a cat on the moon", "seed": 42, "randomize_seed": true, "width": 1024, "height": 1024, "num_inference_steps": 4}
Tool req_0_tool_call
[Binary Content: Image image/webp, 57732 bytes]
The task is complete and the content accessible to the User
Image URL: https://evalstate-flux1-schnell.hf.space/gradio_api/file=/tmp/gradio/3dbddc0e53b5a865ed56a4e3dbdd30f3f61cf3b8aabf1b456f43e5241bd968b8/image.webp
380576952
I have generated an image of a cat on the moon using the Flux 1 Schnell Image Generator. The image is 1024x1024 pixels and was created with 4 inference steps. Let me know if you would like to make any changes or need further assistance!
```
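The server also exposes a `GET /v1/models` route listing a small, curated set of known text-generation models, so that third-party integrations have a discoverable model list; any chat model that works with `generate` can still be requested by name. A quick sketch to inspect it, assuming the same local server and the `requests` library:

```python
import requests

# Returns an OpenAI-style listing: {"object": "list", "data": [{"id": ..., ...}, ...]}
models = requests.get("http://localhost:8000/v1/models").json()
for entry in models["data"]:
    print(entry["id"], "-", entry["owned_by"])
```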
## TextGenerationPipeline ## TextGenerationPipeline
[`TextGenerationPipeline`] is a high-level text generation class with a "chat mode". Chat mode is enabled when a conversational model is detected and the chat prompt is [properly formatted](./llm_tutorial#wrong-prompt-format). [`TextGenerationPipeline`] is a high-level text generation class with a "chat mode". Chat mode is enabled when a conversational model is detected and the chat prompt is [properly formatted](./llm_tutorial#wrong-prompt-format).


@ -148,7 +148,7 @@ _deps = [
"protobuf", "protobuf",
"psutil", "psutil",
"pyyaml>=5.1", "pyyaml>=5.1",
"pydantic", "pydantic>=2",
"pytest>=7.2.0", "pytest>=7.2.0",
"pytest-asyncio", "pytest-asyncio",
"pytest-rerunfailures", "pytest-rerunfailures",


@ -13,33 +13,30 @@
# limitations under the License. # limitations under the License.
import copy import asyncio
import json import json
import os import os
import platform import platform
import re import re
import string import string
import time import time
import warnings
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
from threading import Thread from threading import Thread
from typing import Optional from typing import AsyncIterator, Optional
import yaml import yaml
from huggingface_hub.utils import disable_progress_bars from huggingface_hub import AsyncInferenceClient, ChatCompletionStreamOutput
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
GenerationConfig, GenerationConfig,
PreTrainedTokenizer, PreTrainedTokenizer,
TextIteratorStreamer,
logging,
) )
from transformers.commands import BaseTransformersCLICommand
from transformers.commands.serving import ServeArguments, ServeCommand
from transformers.utils import is_rich_available, is_torch_available from transformers.utils import is_rich_available, is_torch_available
from . import BaseTransformersCLICommand
if platform.system() != "Windows": if platform.system() != "Windows":
import pwd import pwd
@ -52,8 +49,12 @@ if is_rich_available():
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
)
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace) ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
ALLOWED_VALUE_CHARS = set( ALLOWED_VALUE_CHARS = set(
@ -107,19 +108,6 @@ If you're a new user, check this basic flag guide: https://huggingface.co/docs/t
- **!exit**: closes the interface - **!exit**: closes the interface
""" """
# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag)
_DEPRECATION_MAP = [
("max_new_tokens", 256, "max_new_tokens"),
("do_sample", True, "do_sample"),
("num_beams", 1, "num_beams"),
("temperature", 1.0, "temperature"),
("top_k", 50, "top_k"),
("top_p", 1.0, "top_p"),
("repetition_penalty", 1.0, "repetition_penalty"),
("eos_tokens", None, "eos_token_id"),
("eos_token_ids", None, "eos_token_id"),
]
class RichInterface: class RichInterface:
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None): def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
@ -133,21 +121,21 @@ class RichInterface:
else: else:
self.user_name = user_name self.user_name = user_name
def stream_output(self, output_stream: TextIteratorStreamer) -> str: async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, int]:
"""Stream output from a role, and return the generated text after it's done steaming."""
# This method is originally from the FastChat CLI:
# https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
# Create a Live context for updating the console output
text = ""
self._console.print(f"[bold blue]<{self.model_name}>:") self._console.print(f"[bold blue]<{self.model_name}>:")
with Live(console=self._console, refresh_per_second=4) as live: with Live(console=self._console, refresh_per_second=4) as live:
# Read lines from the stream text = ""
for i, outputs in enumerate(output_stream): async for token in await stream:
if not outputs or i == 0: outputs = token.choices[0].delta.content
request_id = token.id
if not outputs:
continue continue
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown. # Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>) # It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs) outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)
text += outputs text += outputs
# Render the accumulated text as Markdown # Render the accumulated text as Markdown
# NOTE: this is a workaround for the rendering "unstandard markdown" # NOTE: this is a workaround for the rendering "unstandard markdown"
@ -160,6 +148,7 @@ class RichInterface:
# introduce trailing spaces (only) in code block, but it works well # introduce trailing spaces (only) in code block, but it works well
# especially for console output, because in general the console does not # especially for console output, because in general the console does not
# care about trailing spaces. # care about trailing spaces.
lines = [] lines = []
for line in text.splitlines(): for line in text.splitlines():
lines.append(line) lines.append(line)
@ -169,11 +158,15 @@ class RichInterface:
lines.append("\n") lines.append("\n")
else: else:
lines.append(" \n") lines.append(" \n")
markdown = Markdown("".join(lines).strip(), code_theme="github-dark") markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
# Update the Live console output # Update the Live console output
live.update(markdown) live.update(markdown, refresh=True)
self._console.print() self._console.print()
return text
return text, request_id
def input(self) -> str: def input(self) -> str:
"""Gets user input from the console.""" """Gets user input from the console."""
@ -245,25 +238,6 @@ class ChatArguments:
), ),
}, },
) )
# Deprecated CLI args start here
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation."})
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling."})
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling."})
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
eos_tokens: Optional[str] = field(
default=None,
metadata={
"help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated."
},
)
eos_token_ids: Optional[str] = field(
default=None,
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
)
# Deprecated CLI args end here
# Model loading # Model loading
model_revision: str = field( model_revision: str = field(
@ -300,6 +274,10 @@ class ChatArguments:
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]}) bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."}) use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
# Serving settings
host: str = field(default="localhost", metadata={"help": "Interface the server will listen on."})
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
def chat_command_factory(args: Namespace): def chat_command_factory(args: Namespace):
""" """
@ -322,7 +300,10 @@ class ChatCommand(BaseTransformersCLICommand):
group = chat_parser.add_argument_group("Positional arguments") group = chat_parser.add_argument_group("Positional arguments")
group.add_argument( group.add_argument(
"model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model." "model_name_or_path_or_address",
type=str,
default=None,
help="Name of the pre-trained model or address to connect to.",
) )
group.add_argument( group.add_argument(
"generate_flags", "generate_flags",
@ -332,57 +313,45 @@ class ChatCommand(BaseTransformersCLICommand):
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, " "Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
"and lists of integers, more advanced parameterization should be set through --generation-config. " "and lists of integers, more advanced parameterization should be set through --generation-config. "
"Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. " "Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
"If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options" "If you're a new user, check this basic flag guide: "
"https://huggingface.co/docs/transformers/llm_tutorial#common-options"
), ),
nargs="*", nargs="*",
) )
chat_parser.set_defaults(func=chat_command_factory) chat_parser.set_defaults(func=chat_command_factory)
def __init__(self, args): def __init__(self, args):
args = self._handle_deprecated_args(args) if args.model_name_or_path_or_address is not None:
name = args.model_name_or_path_or_address
if name.startswith("http") or name.startswith("https") or name.startswith("localhost"):
self.spawn_backend = False
if args.host != "localhost" or args.port != 8000:
raise ValueError(
"Looks like youve set both a server address and a custom host/port. "
"Please pick just one way to specify the server."
)
args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
else:
self.spawn_backend = True
args.model_name_or_path = args.model_name_or_path_or_address
if not is_rich_available() and (not is_torch_available() and self.spawn_backend):
raise ImportError(
"You need to install rich to use the chat interface. Additionally, you have not specified a remote "
"endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`)"
)
elif not is_rich_available():
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
elif not is_torch_available() and self.spawn_backend:
raise ImportError(
"You have not specified a remote endpoint and are therefore spawning a backend. Torch is required "
"for this: (`pip install rich torch`)"
)
self.args = args self.args = args
def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
"""
Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
args.
"""
has_warnings = False
# 1. Model as a positional argument
args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
if args.model_name_or_path_positional is None:
raise ValueError(
"One of the following must be provided:"
"\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
"\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
)
elif args.model_name_or_path is not None:
has_warnings = True
warnings.warn(
"The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
"argument instead, e.g. `transformers chat <model_repo>`.",
FutureWarning,
)
# 2. Named generate option args
for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
value = getattr(args, deprecated_arg)
if value != default_value:
has_warnings = True
warnings.warn(
f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two "
"alternative solutions to specify this generation option: \n"
"1. Pass `--generation-config <path_to_file/Hub repo>` to specify a generation config.\n"
"2. Pass `generate` flags through positional arguments, e.g. `transformers chat <model_repo> "
f"{new_arg}={value}`",
FutureWarning,
)
if has_warnings:
print("\n(Press enter to continue)")
input()
return args
# ----------------------------------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------------------------------
# Chat session methods # Chat session methods
@staticmethod @staticmethod
@ -404,7 +373,7 @@ class ChatCommand(BaseTransformersCLICommand):
if filename is None: if filename is None:
time_str = time.strftime("%Y-%m-%d_%H-%M-%S") time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json" filename = f"{args.model_name_or_path_or_address}/chat_{time_str}.json"
filename = os.path.join(folder, filename) filename = os.path.join(folder, filename)
os.makedirs(os.path.dirname(filename), exist_ok=True) os.makedirs(os.path.dirname(filename), exist_ok=True)
@ -477,40 +446,23 @@ class ChatCommand(BaseTransformersCLICommand):
) )
return processed_generate_flags return processed_generate_flags
def get_generation_parameterization( def get_generation_parameterization(self, args: ChatArguments) -> tuple[GenerationConfig, dict]:
self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
) -> tuple[GenerationConfig, dict]:
""" """
Returns a GenerationConfig object holding the generation parameters for the CLI command. Returns a GenerationConfig object holding the generation parameters for the CLI command.
""" """
# No generation config arg provided -> use default generation config, apply CLI defaults # No generation config arg provided -> use base generation config, apply CLI defaults
if args.generation_config is None: if args.generation_config is not None:
# We start off from the checkpoint's generation config
generation_config = copy.deepcopy(model.generation_config)
# Apply deprecated CLI args on top of the default generation config
pad_token_id, eos_token_ids = self.parse_eos_tokens(
tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
)
deprecated_kwargs = {
"max_new_tokens": args.max_new_tokens,
"do_sample": args.do_sample,
"num_beams": args.num_beams,
"temperature": args.temperature,
"top_k": args.top_k,
"top_p": args.top_p,
"repetition_penalty": args.repetition_penalty,
"pad_token_id": pad_token_id,
"eos_token_id": eos_token_ids,
}
generation_config.update(**deprecated_kwargs)
# generation config arg provided -> use it as the base parameterization
else:
if ".json" in args.generation_config: # is a local file if ".json" in args.generation_config: # is a local file
dirname = os.path.dirname(args.generation_config) dirname = os.path.dirname(args.generation_config)
filename = os.path.basename(args.generation_config) filename = os.path.basename(args.generation_config)
generation_config = GenerationConfig.from_pretrained(dirname, filename) generation_config = GenerationConfig.from_pretrained(dirname, filename)
else: else:
generation_config = GenerationConfig.from_pretrained(args.generation_config) generation_config = GenerationConfig.from_pretrained(args.generation_config)
else:
# !!!!!!!!!
# This is a chat session, so we have a few non-standard defaults
# !!!!!!!!!
generation_config = GenerationConfig(do_sample=True, max_new_tokens=256)
# Finally: parse and apply `generate_flags` # Finally: parse and apply `generate_flags`
parsed_generate_flags = self.parse_generate_flags(args.generate_flags) parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
@ -664,7 +616,7 @@ class ChatCommand(BaseTransformersCLICommand):
elif user_input == "!status": elif user_input == "!status":
interface.print_status( interface.print_status(
model_name=args.model_name_or_path_positional, model_name=args.model_name_or_path,
generation_config=generation_config, generation_config=generation_config,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
) )
@ -679,10 +631,33 @@ class ChatCommand(BaseTransformersCLICommand):
# ----------------------------------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------------------------------
# Main logic # Main logic
def run(self): def run(self):
if not is_rich_available(): asyncio.run(self._inner_run())
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
if not is_torch_available(): async def _inner_run(self):
raise ImportError("You need to install torch to use the chat interface. (`pip install torch`)") if self.spawn_backend:
serve_args = ServeArguments(
model_revision=self.args.model_revision,
device=self.args.device,
torch_dtype=self.args.torch_dtype,
trust_remote_code=self.args.trust_remote_code,
attn_implementation=self.args.attn_implementation,
load_in_8bit=self.args.load_in_8bit,
load_in_4bit=self.args.load_in_4bit,
bnb_4bit_quant_type=self.args.bnb_4bit_quant_type,
use_bnb_nested_quant=self.args.use_bnb_nested_quant,
host=self.args.host,
port=self.args.port,
log_level="error",
)
serve_args.model_name_or_path = self.args.model_name_or_path
serve_command = ServeCommand(serve_args)
thread = Thread(target=serve_command.run)
thread.daemon = True
thread.start()
host = "http://localhost" if self.args.host == "localhost" else self.args.host
client = AsyncInferenceClient(f"{host}:{self.args.port}")
args = self.args args = self.args
if args.examples_path is None: if args.examples_path is None:
@ -696,19 +671,14 @@ class ChatCommand(BaseTransformersCLICommand):
else: else:
user = args.user user = args.user
model, tokenizer = self.load_model_and_tokenizer(args) generation_config, model_kwargs = self.get_generation_parameterization(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
# if not verbose -> disable warnings, progress bars, etc in the chat interface interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
if not args.verbose:
logging.set_verbosity_error()
disable_progress_bars()
interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
interface.clear() interface.clear()
chat = self.clear_chat_history(args.system_prompt) chat = self.clear_chat_history(args.system_prompt)
request_id = None
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck # Starts the session with a minimal help message at the top, so that a user doesn't get stuck
interface.print_help(minimal=True) interface.print_help(minimal=True)
while True: while True:
@ -736,23 +706,25 @@ class ChatCommand(BaseTransformersCLICommand):
else: else:
chat.append({"role": "user", "content": user_input}) chat.append({"role": "user", "content": user_input})
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( stream = client.chat_completion(
model.device chat,
stream=True,
extra_body={"request_id": request_id, "generation_config": {**generation_config.to_dict()}},
) )
attention_mask = torch.ones_like(inputs)
generation_kwargs = {
"inputs": inputs,
"attention_mask": attention_mask,
"streamer": generation_streamer,
"generation_config": generation_config,
**model_kwargs,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs) model_output, request_id = await interface.stream_output(stream)
thread.start()
model_output = interface.stream_output(generation_streamer)
thread.join()
chat.append({"role": "assistant", "content": model_output}) chat.append({"role": "assistant", "content": model_output})
except KeyboardInterrupt: except KeyboardInterrupt:
break break
finally:
await client.close()
if __name__ == "__main__":
args = ChatArguments()
args.model_name_or_path_or_address = "meta-llama/Llama-3.2-3b-Instruct"
args.model_name_or_path_or_address = "http://localhost:8000"
chat = ChatCommand(args)
chat.run()


@ -11,33 +11,95 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
import json
import re
import time
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Optional from typing import Any, Optional
from ..pipelines import Pipeline, get_supported_tasks, pipeline from huggingface_hub import (
from ..utils import logging ChatCompletionStreamOutputChoice,
ChatCompletionStreamOutputDelta,
ChatCompletionStreamOutputDeltaToolCall,
ChatCompletionStreamOutputFunction,
ModelInfo,
model_info,
)
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
from .. import PreTrainedTokenizerFast, TextIteratorStreamer
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
from ..utils import is_torch_available, logging
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
try: if is_torch_available():
from fastapi import Body, FastAPI, HTTPException import torch
from fastapi.routing import APIRoute
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
PreTrainedModel,
)
if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available():
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from starlette.responses import JSONResponse
from uvicorn import run
_serve_dependencies_installed = True class Message(BaseModel):
except (ImportError, AttributeError): role: str
BaseModel = object content: str
def Body(*x, **y): class ChatCompletionInput(BaseModel):
pass messages: list[Message]
_serve_dependencies_installed = False stream: Optional[bool] = False
model: Optional[str] = None
request_id: Optional[str] = None
extra_body: Optional[dict] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[list[float]] = None
max_tokens: Optional[int] = None
stop: Optional[list[str]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
seed: Optional[int] = None
# Additional options supported by the HFH InferenceClient
# that aren't yet supported here.
# logprobs: Optional[bool] = None
tools: Any = None
# n: Optional[int] = None
# presence_penalty: Optional[float] = None
# response_format: Optional[ChatCompletionInputGrammarType] = None
# stream_options: Optional[ChatCompletionInputStreamOptions] = None
# tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None
# tool_prompt: Optional[str] = None
# top_logprobs: Optional[int] = None
logger = logging.get_logger("transformers/serving") logger = logging.get_logger(__name__)
# Possible tokens that indicate the start/end of a tool call
# TODO (joao, matt): streamline tool token detection logic
_TOOL_CALL_TOKENS = {
"qwen": {
"start": "<tool_call>",
"end": "</tool_call>",
},
}
_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())
def serve_command_factory(args: Namespace): def serve_command_factory(args: Namespace):
@ -46,50 +108,114 @@ def serve_command_factory(args: Namespace):
Returns: ServeCommand Returns: ServeCommand
""" """
nlp = pipeline( return ServeCommand(args)
task=args.task,
model=args.model if args.model else None,
config=args.config, def create_generation_config_from_req(req: "ChatCompletionInput") -> "GenerationConfig":
tokenizer=args.tokenizer, """
device=args.device, Creates a generation config from the parameters of the request. Note that we can pass a `GenerationConfig`
(serialized into a `dict`) in `extra_body`, for full `generate` parameterization.
Args:
req (`ChatCompletionInput`): The request which may optionally contain generation parameters.
Returns:
The prepared `GenerationConfig` object.
"""
if req.extra_body is not None and "generation_config" in req.extra_body:
for key in req.extra_body["generation_config"].keys():
if key in ChatCompletionInput.base_field_names.keys():
return {"error": "Duplicated key in the root request and in the passed generation config."}
if req.extra_body is not None and "generation_config" in req.extra_body:
generation_config = GenerationConfig(**(req.extra_body["generation_config"]))
else:
generation_config = GenerationConfig()
if req.frequency_penalty is not None:
generation_config.repetition_penalty = req.frequency_penalty
if req.logit_bias is not None:
generation_config.sequence_bias = req.logit_bias
if req.stop is not None:
generation_config.stop_strings = req.stop
if req.temperature is not None:
generation_config.temperature = req.temperature
if req.top_p is not None:
generation_config.top_p = req.top_p
if req.seed is not None:
torch.manual_seed(req.seed)
return generation_config
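As a usage sketch of this request flow (assuming a recent `huggingface_hub` and a server on `localhost:8000`), a client can pass a serialized `GenerationConfig` through `extra_body` for full `generate` parameterization, mirroring what the new `chat` CLI does:

```python
import asyncio

from huggingface_hub import AsyncInferenceClient


async def main():
    client = AsyncInferenceClient("http://localhost:8000")
    stream = await client.chat_completion(
        [{"role": "user", "content": "Hi!"}],
        stream=True,
        # A serialized GenerationConfig travels in `extra_body`; the server uses it
        # as the base parameterization for `generate`.
        extra_body={"generation_config": {"max_new_tokens": 64, "do_sample": False}},
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="", flush=True)
    await client.close()


asyncio.run(main())
```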
class ToolState:
"""Lightweight class to keep track of the tool call state."""
def __init__(self):
self.reset()
def reset(self):
"""Reset the tool call state (assumes we're outside a tool call)."""
self.inside_tool_call = False
self.has_tool_name_defined = False
self.arg_nesting_level = 0
self.buffer = ""
@dataclass
class ServeArguments:
r"""
Arguments for the serve CLI.
See the metadata arg for each argument's description -- the metadata will be printed with
`transformers serve --help`
"""
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
torch_dtype: Optional[str] = field(
default="auto",
metadata={
"help": "Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, "
"the dtype will be automatically derived from the model's weights.",
"choices": ["auto", "bfloat16", "float16", "float32"],
},
) )
return ServeCommand(nlp, args.host, args.port, args.workers) trust_remote_code: bool = field(
default=False, metadata={"help": "Whether to trust remote code when loading a model."}
)
attn_implementation: Optional[str] = field(
default=None,
metadata={
"help": "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in "
"which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
},
)
load_in_8bit: bool = field(
default=False,
metadata={"help": "Whether to use 8 bit precision for the base model - works only with LoRA."},
)
load_in_4bit: bool = field(
default=False,
metadata={"help": "Whether to use 4 bit precision for the base model - works only with LoRA."},
)
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
# Serving settings
host: str = field(default="localhost", metadata={"help": "Interface the server will listen on."})
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
class ServeModelInfoResult(BaseModel): # Other settings
""" log_level: str = field(
Expose model information default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
""" )
infos: dict
class ServeTokenizeResult(BaseModel):
"""
Tokenize result model
"""
tokens: list[str]
tokens_ids: Optional[list[int]]
class ServeDeTokenizeResult(BaseModel):
"""
DeTokenize result model
"""
text: str
class ServeForwardResult(BaseModel):
"""
Forward result model
"""
output: Any
class ServeCommand(BaseTransformersCLICommand): class ServeCommand(BaseTransformersCLICommand):
loaded_model: Optional[str] = None
model: PreTrainedModel
tokenizer: PreTrainedTokenizerFast
@staticmethod @staticmethod
def register_subcommand(parser: ArgumentParser): def register_subcommand(parser: ArgumentParser):
""" """
@ -98,131 +224,409 @@ class ServeCommand(BaseTransformersCLICommand):
Args: Args:
parser: Root parser to register command-specific arguments parser: Root parser to register command-specific arguments
""" """
serve_parser = parser.add_parser( dataclass_types = (ServeArguments,)
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints." serve_parser = parser.add_parser("serve", dataclass_types=dataclass_types)
)
serve_parser.add_argument(
"--task",
type=str,
choices=get_supported_tasks(),
help="The task to run the pipeline on",
)
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
serve_parser.add_argument(
"--device",
type=int,
default=-1,
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
)
serve_parser.set_defaults(func=serve_command_factory) serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int): def __init__(self, args: ServeArguments):
self._pipeline = pipeline if not is_pydantic_available() or not is_fastapi_available() or not is_uvicorn_available():
raise ImportError(
self.host = host "Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`"
self.port = port
self.workers = workers
if not _serve_dependencies_installed:
raise RuntimeError(
"Using serve command requires FastAPI and uvicorn. "
'Please install transformers with [serving]: pip install "transformers[serving]". '
"Or install FastAPI and uvicorn separately."
) )
else:
logger.info(f"Serving model over {host}:{port}") self.args = args
self._app = FastAPI( self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
routes=[
APIRoute( # State: preserves information about the last call and last KV cache, to determine whether we can reuse the KV
"/", # cache and avoid re-running prefil
self.model_info, self.last_messages = None
response_model=ServeModelInfoResult, self.last_kv_cache = None
response_class=JSONResponse,
methods=["GET"], transformers_logger = logging.get_logger("transformers")
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
cb_logger = logging.get_logger("transformers.generation.continuous_batching")
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
def build_chunk(
self,
content: str,
request_id: str,
role: Optional[str] = None,
finish_reason: Optional[str] = None,
tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
) -> str:
payload = {
"object": "chat.completion.chunk",
"id": request_id,
"created": int(time.time()),
"model": self.loaded_model,
"system_fingerprint": "",
"choices": [
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(
role=role,
content=content,
tool_calls=tool_calls,
), ),
APIRoute( index=0,
"/tokenize", logprobs=None,
self.tokenize, finish_reason=finish_reason,
response_model=ServeTokenizeResult, ),
response_class=JSONResponse, ],
methods=["POST"], }
), return f"data: {json.dumps(payload)}\n\n"
APIRoute(
"/detokenize",
self.detokenize,
response_model=ServeDeTokenizeResult,
response_class=JSONResponse,
methods=["POST"],
),
APIRoute(
"/forward",
self.forward,
response_model=ServeForwardResult,
response_class=JSONResponse,
methods=["POST"],
),
],
timeout=600,
)
def run(self): def run(self):
run(self._app, host=self.host, port=self.port, workers=self.workers) app = FastAPI()
def model_info(self): if self.use_continuous_batching:
return ServeModelInfoResult(infos=vars(self._pipeline.model.config)) self.continuous_batching(app)
else:
self.generate(app)
def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)): @functools.lru_cache(maxsize=None)
def get_text_gen_models() -> list[ModelInfo]:
"""
This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
model that works with `generate` will work.
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
integrations.
"""
return [
model_info("Menlo/Jan-nano"),
model_info("Menlo/Jan-nano-128k"),
model_info("Qwen/Qwen2.5-0.5B-Instruct"),
model_info("Qwen/Qwen2.5-3B-Instruct"),
model_info("Qwen/Qwen2.5-7B-Instruct"),
model_info("Qwen/Qwen2.5-14B-Instruct"),
model_info("meta-llama/Llama-3.1-8B-Instruct"),
model_info("meta-llama/Llama-3.2-1B-Instruct"),
model_info("meta-llama/Llama-3.3-70B-Instruct"),
]
@app.get("/v1/models")
def get_all_models():
return JSONResponse(
{
"object": "list",
"data": [
{
"id": model.id,
"object": "model",
"crated": model.created_at.timestamp(),
"owned_by": model.author,
}
for model in get_text_gen_models()
],
}
)
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
def continuous_batching(self, app):
generation_config = GenerationConfig(
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=False,
num_blocks=1,
block_size=1024,
do_sample=False,
max_batch_tokens=10,
scheduler="fifo",
)
manager: ContinuousBatchingManager = self.model.init_continuous_batching(
generation_config=generation_config, streaming=True
)
manager.start()
@app.post("/v1/chat/completions")
def _serve(req: "ChatCompletionInput"):
if not req.stream:
return {"error": "Only streaming mode is supported."}
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
chat = req.messages
inputs = self.tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
self.model.device
)
generation_config = create_generation_config_from_req(req)
def stream_response(_inputs):
try:
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
request_id = manager.add_request(_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens)
queue_is_flushed = False
for result in manager:
if req.request_id is not None and not queue_is_flushed:
if result.status == RequestStatus.FINISHED:
continue
else:
queue_is_flushed = True
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
yield self.build_chunk(result.next_token, request_id=request_id, finish_reason=finish_reason)
if result.status == RequestStatus.FINISHED:
break
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(str(e))
yield f'data: {{"error": "{str(e)}"}}'
return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
def is_continuation(self, req: "ChatCompletionInput") -> bool:
""" """
Tokenize the provided input and eventually returns corresponding tokens id: - **text_input**: String to Determines whether the current request is a continuation of the last request. In other words, if it is the
tokenize - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer same chat session.
mapping.
"""
try:
tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
if return_ids: Args:
tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt) req (`ChatCompletionInput`): The request to check.
return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
Returns:
`True` if the request is a continuation of the last request, `False` otherwise.
"""
req_continues_last_messages = True
# No cached messages: this is a new request
if self.last_messages is None:
req_continues_last_messages = False
# The new request has fewer rounds of conversation: this is a new request
elif len(self.last_messages) > len(req.messages):
req_continues_last_messages = False
# Otherwise, check that the last messages are a subset of the new request
else:
for i in range(len(self.last_messages)):
if self.last_messages[i] != req.messages[i]:
req_continues_last_messages = False
break
self.last_messages = req.messages
return req_continues_last_messages
def generate(self, app):
@app.post("/v1/chat/completions")
def _serve(req: "ChatCompletionInput"):
update_model = req.model != self.loaded_model
if update_model:
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
if not req.stream:
return {"error": "Only streaming mode is supported."}
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
# request whose last message is from the assistant.
if req.messages[-1].role == "assistant":
return
# ====== TOOL PREPROCESSING LOGIC ======
tool_model_family = None
for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
if supported_model_families in self.model.config.architectures[0].lower():
tool_model_family = supported_model_families
break
# TODO: trigger 2 constrained generations after the tool call start token is emitted:
# 1. force generation to pick from the tool names
# 2. force generation to pick from that tool's arguments
# ====== END OF TOOL PREPROCESSING LOGIC ======
if tool_model_family is not None:
text = self.tokenizer.apply_chat_template(
req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
)
else: else:
return ServeTokenizeResult(tokens=tokens_txt) text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
except Exception as e: inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) request_id = req.request_id if req.request_id is not None else "req_0"
def detokenize( generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
self,
tokens_ids: list[int] = Body(None, embed=True),
skip_special_tokens: bool = Body(False, embed=True),
cleanup_tokenization_spaces: bool = Body(True, embed=True),
):
"""
Detokenize the provided tokens ids to readable text: - **tokens_ids**: List of tokens ids -
**skip_special_tokens**: Flag indicating to not try to decode special tokens - **cleanup_tokenization_spaces**:
Flag indicating to remove all leading/trailing spaces and intermediate ones.
"""
try:
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
return ServeDeTokenizeResult(model="", text=decoded_str)
except Exception as e:
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
async def forward(self, inputs=Body(None, embed=True)): generation_config = create_generation_config_from_req(req)
""" max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
**inputs**: **attention_mask**: **tokens_type_ids**: generation_config.max_new_tokens = max_new_tokens
"""
# Check we don't have empty string last_kv_cache = None
if len(inputs) == 0: if self.is_continuation(req) and not update_model:
return ServeForwardResult(output=[], attention=[]) last_kv_cache = self.last_kv_cache
try: generation_kwargs = {
# Forward through the model "inputs": inputs,
output = self._pipeline(inputs) "attention_mask": torch.ones_like(inputs),
return ServeForwardResult(output=output) "streamer": generation_streamer,
except Exception as e: "generation_config": generation_config,
raise HTTPException(500, {"error": str(e)}) "return_dict_in_generate": True,
"past_key_values": last_kv_cache,
}
def stream_response(streamer, _request_id):
# Thin wrapper to save the KV cache after generation
def generate_with_cache(**kwargs):
generate_output = self.model.generate(**kwargs)
self.last_kv_cache = generate_output.past_key_values
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
try:
thread.start()
tool_state = ToolState()
for result in streamer:
# ====== TOOL CALL LOGIC ======
if tool_model_family is not None:
# Start of a tool call: reset state variables, set `inside_tool_call`
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
tool_state.inside_tool_call = True
continue
# End of tool call: reset `inside_tool_call`, emit a `finish_reason`
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
tool_state.reset()
yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
continue
# Inside a tool call
if tool_state.inside_tool_call:
tool_state.buffer += result
# First step: extract the tool name (may need several tokens, and we can't emit a delta
# until we have the full name)
if not tool_state.has_tool_name_defined:
tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
if tool_name is None:
continue
else:
tool_name = tool_name.group(1)
tool_state.has_tool_name_defined = True
tool = ChatCompletionStreamOutputDeltaToolCall(
function=ChatCompletionStreamOutputFunction(
name=tool_name,
arguments=None,
),
index=0,
type="function",
id=_request_id + "_tool_call", # Only the first tool call delta has an id
)
# Second step: extract tool arguments. The tool arguments can be seen as a json string
# within the tool json string. We emit a delta for the arguments.
else:
# Empty text: skip
if result == "":
continue
# Until we see the `"arguments": {` in the buffer, we skip
# TODO: other models will likely need more elaborate processing here
if '"arguments": {' not in tool_state.buffer:
continue
# Handle nesting. We want to exclude the last } from the emitted arguments (it's
# closing the outermost nesting level, outside the arguments block)
tool_state.arg_nesting_level += result.count("{")
tool_state.arg_nesting_level -= result.count("}")
if tool_state.arg_nesting_level < 0:
result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}"
tool = ChatCompletionStreamOutputDeltaToolCall(
function=ChatCompletionStreamOutputFunction(
arguments=result,
),
index=0,
type="function",
id=None,
)
yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
continue
# ====== END OF TOOL CALL LOGIC ======
# All non-tool related tokens are emitted as assistant messages
yield self.build_chunk(result, _request_id, role="assistant")
yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
thread.join()
except Exception as e:
logger.error(str(e))
raise
yield f'data: {{"error": "{str(e)}"}}'
finally:
thread.join()
return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
@staticmethod
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
# For consistency with model weights, we use the same value as `torch_dtype`
bnb_4bit_compute_dtype=model_args.torch_dtype,
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.torch_dtype,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
else:
quantization_config = None
return quantization_config
def load_model_and_tokenizer(
self, model_id_and_revision: str, args: ServeArguments
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
logger.warning(f"Loading {model_id_and_revision}")
if "@" in model_id_and_revision:
model_id, revision = model_id_and_revision.split("@", 1)
else:
model_id, revision = model_id_and_revision, "main"
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
trust_remote_code=args.trust_remote_code,
)
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
quantization_config = self.get_quantization_config(args)
model_kwargs = {
"revision": revision,
"attn_implementation": args.attn_implementation,
"torch_dtype": torch_dtype,
"device_map": "auto",
"quantization_config": quantization_config,
"trust_remote_code": args.trust_remote_code,
}
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
if model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 256:
model.generation_config.max_new_tokens = 256
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
self.loaded_model = model_id_and_revision
print("Loaded model", model_id_and_revision)
return model, tokenizer
if __name__ == "__main__":
serve = ServeCommand()
serve.run()


@ -54,7 +54,7 @@ deps = {
"protobuf": "protobuf", "protobuf": "protobuf",
"psutil": "psutil", "psutil": "psutil",
"pyyaml": "pyyaml>=5.1", "pyyaml": "pyyaml>=5.1",
"pydantic": "pydantic", "pydantic": "pydantic>=2",
"pytest": "pytest>=7.2.0", "pytest": "pytest>=7.2.0",
"pytest-asyncio": "pytest-asyncio", "pytest-asyncio": "pytest-asyncio",
"pytest-rerunfailures": "pytest-rerunfailures", "pytest-rerunfailures": "pytest-rerunfailures",


@ -27,6 +27,8 @@ from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from torch.profiler import profile, schedule, tensorboard_trace_handler from torch.profiler import profile, schedule, tensorboard_trace_handler
from tqdm import tqdm from tqdm import tqdm
@ -72,6 +74,7 @@ class GenerationOutput:
error: Optional[str] = None error: Optional[str] = None
status: RequestStatus = RequestStatus.PENDING status: RequestStatus = RequestStatus.PENDING
created_time: float = field(default_factory=time.time) created_time: float = field(default_factory=time.time)
next_token: Optional[int] = field(default_factory=int)
@dataclass @dataclass
@ -96,6 +99,7 @@ class RequestState:
eos_token_id: int = -1 eos_token_id: int = -1
created_time: float = field(default_factory=time.time) created_time: float = field(default_factory=time.time)
error: Optional[str] = None error: Optional[str] = None
next_token: Optional[str] = None
def current_len(self) -> int: def current_len(self) -> int:
"""Get the current length of the sequence (prompt + generated tokens).""" """Get the current length of the sequence (prompt + generated tokens)."""
@ -139,6 +143,7 @@ class RequestState:
generated_tokens=self.static_outputs, generated_tokens=self.static_outputs,
logprobs=[], logprobs=[],
error=self.error, error=self.error,
next_token=self.next_token,
) )
@ -764,6 +769,9 @@ class ContinuousBatchProcessor:
self.setup_static_tensors() self.setup_static_tensors()
self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
self.decode_stream = DecodeStream(skip_special_tokens=True)
@traced(standalone=True) @traced(standalone=True)
def setup_static_tensors(self): def setup_static_tensors(self):
T = self.max_batch_tokens T = self.max_batch_tokens
@ -995,7 +1003,7 @@ class ContinuousBatchProcessor:
def _maybe_send_output(self, state: RequestState, token: int): def _maybe_send_output(self, state: RequestState, token: int):
"""Send output to the queue based on streaming mode and request state.""" """Send output to the queue based on streaming mode and request state."""
if self.streaming: if self.streaming:
state.next_token = token state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1])
self.output_queue.put(state.to_generation_output()) self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED: elif state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output()) self.output_queue.put(state.to_generation_output())
@ -1102,6 +1110,7 @@ class ContinuousBatchingManager:
self.profile = getattr(generation_config, "profile", False) self.profile = getattr(generation_config, "profile", False)
self.manual_eviction = manual_eviction self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None self.batch_processor: Optional[ContinuousBatchProcessor] = None
self.decode_stream = DecodeStream(skip_special_tokens=True)
@traced @traced
def start(self): def start(self):


@ -292,6 +292,30 @@ except importlib.metadata.PackageNotFoundError:
_essentia_version = False _essentia_version = False
_pydantic_available = importlib.util.find_spec("pydantic") is not None
try:
_pydantic_version = importlib.metadata.version("pydantic")
logger.debug(f"Successfully imported pydantic version {_pydantic_version}")
except importlib.metadata.PackageNotFoundError:
_pydantic_available = False
_fastapi_available = importlib.util.find_spec("fastapi") is not None
try:
_fastapi_version = importlib.metadata.version("fastapi")
logger.debug(f"Successfully imported pydantic version {_fastapi_version}")
except importlib.metadata.PackageNotFoundError:
_fastapi_available = False
_uvicorn_available = importlib.util.find_spec("uvicorn") is not None
try:
_uvicorn_version = importlib.metadata.version("uvicorn")
logger.debug(f"Successfully imported pydantic version {_uvicorn_version}")
except importlib.metadata.PackageNotFoundError:
_uvicorn_available = False
_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None _pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try: try:
_pretty_midi_version = importlib.metadata.version("pretty_midi") _pretty_midi_version = importlib.metadata.version("pretty_midi")
@ -473,6 +497,18 @@ def is_essentia_available():
return _essentia_available return _essentia_available
def is_pydantic_available():
return _pydantic_available
def is_fastapi_available():
return _fastapi_available
def is_uvicorn_available():
return _uvicorn_available
def is_pretty_midi_available(): def is_pretty_midi_available():
return _pretty_midi_available return _pretty_midi_available
@ -1843,6 +1879,23 @@ VISION_IMPORT_ERROR = """
`pip install pillow`. Please note that you may need to restart your runtime after installation. `pip install pillow`. Please note that you may need to restart your runtime after installation.
""" """
# docstyle-ignore
PYDANTIC_IMPORT_ERROR = """
{0} requires the pydantic library but it was not found in your environment. You can install it with pip:
`pip install pydantic`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
FASTAPI_IMPORT_ERROR = """
{0} requires the fastapi library but it was not found in your environment. You can install it with pip:
`pip install fastapi`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
UVICORN_IMPORT_ERROR = """
{0} requires the uvicorn library but it was not found in your environment. You can install it with pip:
`pip install uvicorn`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore # docstyle-ignore
PYTESSERACT_IMPORT_ERROR = """ PYTESSERACT_IMPORT_ERROR = """
@ -1966,6 +2019,9 @@ BACKENDS_MAPPING = OrderedDict(
("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)), ("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)),
("rich", (is_rich_available, RICH_IMPORT_ERROR)), ("rich", (is_rich_available, RICH_IMPORT_ERROR)),
("keras_nlp", (is_keras_nlp_available, KERAS_NLP_IMPORT_ERROR)), ("keras_nlp", (is_keras_nlp_available, KERAS_NLP_IMPORT_ERROR)),
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
] ]
) )


@ -0,0 +1,65 @@
import os
import tempfile
import unittest
from unittest.mock import patch
import transformers.commands.transformers_cli as cli
from transformers.commands.chat import ChatArguments, ChatCommand
from transformers.testing_utils import CaptureStd
class ChatCLITest(unittest.TestCase):
def test_help(self):
with patch("sys.argv", ["transformers", "chat", "--help"]), CaptureStd() as cs:
with self.assertRaises(SystemExit):
cli.main()
self.assertIn("chat interface", cs.out.lower())
@patch.object(ChatCommand, "run")
def test_cli_dispatch(self, run_mock):
args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
with patch("sys.argv", args):
cli.main()
run_mock.assert_called_once()
def test_parsed_args(self):
with (
patch.object(ChatCommand, "__init__", return_value=None) as init_mock,
patch.object(ChatCommand, "run") as run_mock,
patch(
"sys.argv",
[
"transformers",
"chat",
"test-model",
"max_new_tokens=64",
],
),
):
cli.main()
init_mock.assert_called_once()
run_mock.assert_called_once()
parsed_args = init_mock.call_args[0][0]
self.assertEqual(parsed_args.model_name_or_path_or_address, "test-model")
self.assertEqual(parsed_args.generate_flags, ["max_new_tokens=64"])
class ChatUtilitiesTest(unittest.TestCase):
def test_save_and_clear_chat(self):
tmp_path = tempfile.mkdtemp()
args = ChatArguments(save_folder=str(tmp_path))
args.model_name_or_path_or_address = "test-model"
chat_history = [{"role": "user", "content": "hi"}]
filename = ChatCommand.save_chat(chat_history, args)
self.assertTrue(os.path.isfile(filename))
cleared = ChatCommand.clear_chat_history()
self.assertEqual(cleared, [])
def test_parse_generate_flags(self):
dummy = ChatCommand.__new__(ChatCommand)
parsed = ChatCommand.parse_generate_flags(dummy, ["temperature=0.5", "max_new_tokens=10"])
self.assertEqual(parsed["temperature"], 0.5)
self.assertEqual(parsed["max_new_tokens"], 10)


@ -0,0 +1,34 @@
import unittest
from unittest.mock import patch
import transformers.commands.transformers_cli as cli
from transformers.commands.serving import ServeCommand
from transformers.testing_utils import CaptureStd
class ServeCLITest(unittest.TestCase):
def test_help(self):
with patch("sys.argv", ["transformers", "serve", "--help"]), CaptureStd() as cs:
with self.assertRaises(SystemExit):
cli.main()
self.assertIn("serve", cs.out.lower())
def test_parsed_args(self):
with (
patch.object(ServeCommand, "__init__", return_value=None) as init_mock,
patch.object(ServeCommand, "run") as run_mock,
patch("sys.argv", ["transformers", "serve", "--host", "0.0.0.0", "--port", "9000"]),
):
cli.main()
init_mock.assert_called_once()
run_mock.assert_called_once()
parsed_args = init_mock.call_args[0][0]
self.assertEqual(parsed_args.host, "0.0.0.0")
self.assertEqual(parsed_args.port, 9000)
def test_build_chunk(self):
dummy = ServeCommand.__new__(ServeCommand)
dummy.args = type("Args", (), {})()
chunk = ServeCommand.build_chunk(dummy, "hello", "req0", finish_reason="stop")
self.assertIn("chat.completion.chunk", chunk)
self.assertIn("data:", chunk)