mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Create local Transformers Engine (#33218)
* Create local Transformers Engine
This commit is contained in:
parent
b017a9eb11
commit
c79bfc71b8
@ -126,12 +126,13 @@ Additionally, `llm_engine` can also take a `grammar` argument. In the case where
|
||||
|
||||
You will also need a `tools` argument which accepts a list of `Tools` - it can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
|
||||
|
||||
Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood.
|
||||
Now you can create an agent, like [`CodeAgent`], and run it. You can also create a [`TransformersEngine`] with a pre-initialized pipeline to run inference on your local machine using `transformers`.
|
||||
For convenience, since agentic behaviours generally require stronger models such as `Llama-3.1-70B-Instruct` that are harder to run locally for now, we also provide the [`HfApiEngine`] class that initializes a `huggingface_hub.InferenceClient` under the hood.
|
||||
|
||||
```python
|
||||
from transformers import CodeAgent, HfEngine
|
||||
from transformers import CodeAgent, HfApiEngine
|
||||
|
||||
llm_engine = HfEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
llm_engine = HfApiEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
|
||||
|
||||
agent.run(
|
||||
@ -141,7 +142,7 @@ agent.run(
|
||||
```
|
||||
|
||||
This will be handy in case of emergency baguette need!
|
||||
You can even leave the argument `llm_engine` undefined, and an [`HfEngine`] will be created by default.
|
||||
You can even leave the argument `llm_engine` undefined, and an [`HfApiEngine`] will be created by default.
|
||||
|
||||
```python
|
||||
from transformers import CodeAgent
|
||||
@ -521,14 +522,14 @@ import gradio as gr
|
||||
from transformers import (
|
||||
load_tool,
|
||||
ReactCodeAgent,
|
||||
HfEngine,
|
||||
HfApiEngine,
|
||||
stream_to_gradio,
|
||||
)
|
||||
|
||||
# Import tool from Hub
|
||||
image_generation_tool = load_tool("m-ric/text-to-image")
|
||||
|
||||
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
llm_engine = HfApiEngine("meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
|
||||
# Initialize the agent with the image generation tool
|
||||
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
|
||||
|
@ -87,12 +87,33 @@ These engines have the following specification:
|
||||
1. Follow the [messages format](../chat_templating.md) for its input (`List[Dict[str, str]]`) and return a string.
|
||||
2. Stop generating outputs *before* the sequences passed in the argument `stop_sequences`
|
||||
|
||||
### HfEngine
|
||||
### TransformersEngine
|
||||
|
||||
For convenience, we have added a `HfEngine` that implements the points above and uses an inference endpoint for the execution of the LLM.
|
||||
For convenience, we have added a `TransformersEngine` that implements the points above, taking a pre-initialized `Pipeline` as input.
|
||||
|
||||
```python
|
||||
>>> from transformers import HfEngine
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TransformersEngine
|
||||
|
||||
>>> model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
|
||||
>>> pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
||||
|
||||
>>> engine = TransformersEngine(pipe)
|
||||
>>> engine([{"role": "user", "content": "Ok!"}], stop_sequences=["great"])
|
||||
|
||||
"What a "
|
||||
```
|
||||
|
||||
[[autodoc]] TransformersEngine
|
||||
|
||||
### HfApiEngine
|
||||
|
||||
The `HfApiEngine` is an engine that wraps an [HF Inference API](https://huggingface.co/docs/api-inference/index) client for the execution of the LLM.
|
||||
|
||||
```python
|
||||
>>> from transformers import HfApiEngine
|
||||
|
||||
>>> messages = [
|
||||
... {"role": "user", "content": "Hello, how are you?"},
|
||||
@ -100,12 +121,12 @@ For convenience, we have added a `HfEngine` that implements the points above and
|
||||
... {"role": "user", "content": "No need to help, take it easy."},
|
||||
... ]
|
||||
|
||||
>>> HfEngine()(messages, stop_sequences=["conversation"])
|
||||
>>> HfApiEngine()(messages, stop_sequences=["conversation"])
|
||||
|
||||
"That's very kind of you to say! It's always nice to have a relaxed "
|
||||
```
|
||||
|
||||
[[autodoc]] HfEngine
|
||||
[[autodoc]] HfApiEngine
|
||||
|
||||
|
||||
## Agent Types
|
||||
|
@ -83,12 +83,12 @@ API나 기반 모델이 자주 업데이트되므로, 에이전트가 제공하
|
||||
1. 입력(`List[Dict[str, str]]`)에 대한 [메시지 형식](../chat_templating.md)을 따르고 문자열을 반환해야 합니다.
|
||||
2. 인수 `stop_sequences`에 시퀀스가 전달되기 *전에* 출력을 생성하는 것을 중지해야 합니다.
|
||||
|
||||
### HfEngine [[hfengine]]
|
||||
### HfApiEngine [[HfApiEngine]]
|
||||
|
||||
편의를 위해, 위의 사항을 구현하고 대규모 언어 모델 실행을 위해 추론 엔드포인트를 사용하는 `HfEngine`을 추가했습니다.
|
||||
편의를 위해, 위의 사항을 구현하고 대규모 언어 모델 실행을 위해 추론 엔드포인트를 사용하는 `HfApiEngine`을 추가했습니다.
|
||||
|
||||
```python
|
||||
>>> from transformers import HfEngine
|
||||
>>> from transformers import HfApiEngine
|
||||
|
||||
>>> messages = [
|
||||
... {"role": "user", "content": "Hello, how are you?"},
|
||||
@ -96,12 +96,12 @@ API나 기반 모델이 자주 업데이트되므로, 에이전트가 제공하
|
||||
... {"role": "user", "content": "No need to help, take it easy."},
|
||||
... ]
|
||||
|
||||
>>> HfEngine()(messages, stop_sequences=["conversation"])
|
||||
>>> HfApiEngine()(messages, stop_sequences=["conversation"])
|
||||
|
||||
"That's very kind of you to say! It's always nice to have a relaxed "
|
||||
```
|
||||
|
||||
[[autodoc]] HfEngine
|
||||
[[autodoc]] HfApiEngine
|
||||
|
||||
|
||||
## 에이전트 유형 [[agent-types]]
|
||||
|
@ -57,7 +57,7 @@ _import_structure = {
|
||||
"agents": [
|
||||
"Agent",
|
||||
"CodeAgent",
|
||||
"HfEngine",
|
||||
"HfApiEngine",
|
||||
"PipelineTool",
|
||||
"ReactAgent",
|
||||
"ReactCodeAgent",
|
||||
@ -65,6 +65,7 @@ _import_structure = {
|
||||
"Tool",
|
||||
"Toolbox",
|
||||
"ToolCollection",
|
||||
"TransformersEngine",
|
||||
"launch_gradio_demo",
|
||||
"load_tool",
|
||||
"stream_to_gradio",
|
||||
@ -4806,7 +4807,7 @@ if TYPE_CHECKING:
|
||||
from .agents import (
|
||||
Agent,
|
||||
CodeAgent,
|
||||
HfEngine,
|
||||
HfApiEngine,
|
||||
PipelineTool,
|
||||
ReactAgent,
|
||||
ReactCodeAgent,
|
||||
@ -4814,6 +4815,7 @@ if TYPE_CHECKING:
|
||||
Tool,
|
||||
Toolbox,
|
||||
ToolCollection,
|
||||
TransformersEngine,
|
||||
launch_gradio_demo,
|
||||
load_tool,
|
||||
stream_to_gradio,
|
||||
|
@ -25,7 +25,7 @@ from ..utils import (
|
||||
|
||||
_import_structure = {
|
||||
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
|
||||
"llm_engine": ["HfEngine"],
|
||||
"llm_engine": ["HfApiEngine", "TransformersEngine"],
|
||||
"monitoring": ["stream_to_gradio"],
|
||||
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
|
||||
}
|
||||
@ -45,7 +45,7 @@ else:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||
from .llm_engine import HfEngine
|
||||
from .llm_engine import HfApiEngine, TransformersEngine
|
||||
from .monitoring import stream_to_gradio
|
||||
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
|
||||
|
||||
|
@ -24,7 +24,7 @@ from ..utils import logging as transformers_logging
|
||||
from ..utils.import_utils import is_pygments_available
|
||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||
from .llm_engine import HfEngine, MessageRole
|
||||
from .llm_engine import HfApiEngine, MessageRole
|
||||
from .prompts import (
|
||||
DEFAULT_CODE_SYSTEM_PROMPT,
|
||||
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
@ -327,7 +327,7 @@ class Agent:
|
||||
def __init__(
|
||||
self,
|
||||
tools: Union[List[Tool], Toolbox],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template=None,
|
||||
additional_args={},
|
||||
@ -532,7 +532,7 @@ class CodeAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
grammar: Dict[str, str] = None,
|
||||
@ -655,7 +655,7 @@ class ReactAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
grammar: Dict[str, str] = None,
|
||||
@ -886,7 +886,7 @@ class ReactJsonAgent(ReactAgent):
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
grammar: Dict[str, str] = None,
|
||||
@ -992,7 +992,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
llm_engine: Callable = HfApiEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
grammar: Dict[str, str] = None,
|
||||
|
@ -20,6 +20,8 @@ from typing import Dict, List, Optional
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
from ..pipelines.base import Pipeline
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
USER = "user"
|
||||
@ -65,7 +67,9 @@ llama_role_conversions = {
|
||||
}
|
||||
|
||||
|
||||
class HfEngine:
|
||||
class HfApiEngine:
|
||||
"""This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint."""
|
||||
|
||||
def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
|
||||
self.model = model
|
||||
self.client = InferenceClient(self.model, timeout=120)
|
||||
@ -93,6 +97,36 @@ class HfEngine:
|
||||
return response
|
||||
|
||||
|
||||
class TransformersEngine:
|
||||
"""This engine uses a pre-initialized local text-generation pipeline."""
|
||||
|
||||
def __init__(self, pipeline: Pipeline):
|
||||
self.pipeline = pipeline
|
||||
|
||||
def __call__(
|
||||
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
|
||||
) -> str:
|
||||
# Get clean message list
|
||||
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
||||
|
||||
# Get LLM output
|
||||
output = self.pipeline(
|
||||
messages,
|
||||
stop_strings=stop_sequences,
|
||||
max_length=1500,
|
||||
tokenizer=self.pipeline.tokenizer,
|
||||
)
|
||||
|
||||
response = output[0]["generated_text"][-1]["content"]
|
||||
|
||||
# Remove stop sequences from LLM output
|
||||
if stop_sequences is not None:
|
||||
for stop_seq in stop_sequences:
|
||||
if response[-len(stop_seq) :] == stop_seq:
|
||||
response = response[: -len(stop_seq)]
|
||||
return response
|
||||
|
||||
|
||||
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
|
||||
"type": "regex",
|
||||
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',
|
||||
|
Loading…
Reference in New Issue
Block a user