Create local Transformers Engine (#33218)

* Create local Transformers Engine
This commit is contained in:
Aymeric Roucher 2024-08-30 18:22:27 +02:00 committed by GitHub
parent b017a9eb11
commit c79bfc71b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 85 additions and 27 deletions

View File

@ -126,12 +126,13 @@ Additionally, `llm_engine` can also take a `grammar` argument. In the case where
You will also need a `tools` argument which accepts a list of `Tools` - it can be an empty list. You can also add the default toolbox on top of your `tools` list by defining the optional argument `add_base_tools=True`.
Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood.
Now you can create an agent, like [`CodeAgent`], and run it. You can also create a [`TransformersEngine`] with a pre-initialized pipeline to run inference on your local machine using `transformers`.
For convenience, since agentic behaviours generally require stronger models such as `Llama-3.1-70B-Instruct` that are harder to run locally for now, we also provide the [`HfApiEngine`] class that initializes a `huggingface_hub.InferenceClient` under the hood.
```python
from transformers import CodeAgent, HfEngine
from transformers import CodeAgent, HfApiEngine
llm_engine = HfEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
llm_engine = HfApiEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
agent.run(
@ -141,7 +142,7 @@ agent.run(
```
This will be handy in case of emergency baguette need!
You can even leave the argument `llm_engine` undefined, and an [`HfEngine`] will be created by default.
You can even leave the argument `llm_engine` undefined, and an [`HfApiEngine`] will be created by default.
```python
from transformers import CodeAgent
@ -521,14 +522,14 @@ import gradio as gr
from transformers import (
load_tool,
ReactCodeAgent,
HfEngine,
HfApiEngine,
stream_to_gradio,
)
# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image")
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
llm_engine = HfApiEngine("meta-llama/Meta-Llama-3-70B-Instruct")
# Initialize the agent with the image generation tool
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)

View File

@ -87,12 +87,33 @@ These engines have the following specification:
1. Follow the [messages format](../chat_templating.md) for its input (`List[Dict[str, str]]`) and return a string.
2. Stop generating outputs *before* the sequences passed in the argument `stop_sequences`
### HfEngine
### TransformersEngine
For convenience, we have added a `HfEngine` that implements the points above and uses an inference endpoint for the execution of the LLM.
For convenience, we have added a `TransformersEngine` that implements the points above, taking a pre-initialized `Pipeline` as input.
```python
>>> from transformers import HfEngine
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TransformersEngine
>>> model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
>>> model = AutoModelForCausalLM.from_pretrained(model_name)
>>> pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
>>> engine = TransformersEngine(pipe)
>>> engine([{"role": "user", "content": "Ok!"}], stop_sequences=["great"])
"What a "
```
[[autodoc]] TransformersEngine
### HfApiEngine
The `HfApiEngine` is an engine that wraps an [HF Inference API](https://huggingface.co/docs/api-inference/index) client for the execution of the LLM.
```python
>>> from transformers import HfApiEngine
>>> messages = [
... {"role": "user", "content": "Hello, how are you?"},
@ -100,12 +121,12 @@ For convenience, we have added a `HfEngine` that implements the points above and
... {"role": "user", "content": "No need to help, take it easy."},
... ]
>>> HfEngine()(messages, stop_sequences=["conversation"])
>>> HfApiEngine()(messages, stop_sequences=["conversation"])
"That's very kind of you to say! It's always nice to have a relaxed "
```
[[autodoc]] HfEngine
[[autodoc]] HfApiEngine
## Agent Types

View File

@ -83,12 +83,12 @@ API나 기반 모델이 자주 업데이트되므로, 에이전트가 제공하
1. 입력(`List[Dict[str, str]]`)에 대한 [메시지 형식](../chat_templating.md)을 따르고 문자열을 반환해야 합니다.
2. 인수 `stop_sequences`에 시퀀스가 전달되기 *전에* 출력을 생성하는 것을 중지해야 합니다.
### HfEngine [[hfengine]]
### HfApiEngine [[HfApiEngine]]
편의를 위해, 위의 사항을 구현하고 대규모 언어 모델 실행을 위해 추론 엔드포인트를 사용하는 `HfEngine`을 추가했습니다.
편의를 위해, 위의 사항을 구현하고 대규모 언어 모델 실행을 위해 추론 엔드포인트를 사용하는 `HfApiEngine`을 추가했습니다.
```python
>>> from transformers import HfEngine
>>> from transformers import HfApiEngine
>>> messages = [
... {"role": "user", "content": "Hello, how are you?"},
@ -96,12 +96,12 @@ API나 기반 모델이 자주 업데이트되므로, 에이전트가 제공하
... {"role": "user", "content": "No need to help, take it easy."},
... ]
>>> HfEngine()(messages, stop_sequences=["conversation"])
>>> HfApiEngine()(messages, stop_sequences=["conversation"])
"That's very kind of you to say! It's always nice to have a relaxed "
```
[[autodoc]] HfEngine
[[autodoc]] HfApiEngine
## 에이전트 유형 [[agent-types]]

View File

@ -57,7 +57,7 @@ _import_structure = {
"agents": [
"Agent",
"CodeAgent",
"HfEngine",
"HfApiEngine",
"PipelineTool",
"ReactAgent",
"ReactCodeAgent",
@ -65,6 +65,7 @@ _import_structure = {
"Tool",
"Toolbox",
"ToolCollection",
"TransformersEngine",
"launch_gradio_demo",
"load_tool",
"stream_to_gradio",
@ -4806,7 +4807,7 @@ if TYPE_CHECKING:
from .agents import (
Agent,
CodeAgent,
HfEngine,
HfApiEngine,
PipelineTool,
ReactAgent,
ReactCodeAgent,
@ -4814,6 +4815,7 @@ if TYPE_CHECKING:
Tool,
Toolbox,
ToolCollection,
TransformersEngine,
launch_gradio_demo,
load_tool,
stream_to_gradio,

View File

@ -25,7 +25,7 @@ from ..utils import (
_import_structure = {
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"llm_engine": ["HfEngine"],
"llm_engine": ["HfApiEngine", "TransformersEngine"],
"monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
}
@ -45,7 +45,7 @@ else:
if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .llm_engine import HfEngine
from .llm_engine import HfApiEngine, TransformersEngine
from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool

View File

@ -24,7 +24,7 @@ from ..utils import logging as transformers_logging
from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage, AgentText
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfEngine, MessageRole
from .llm_engine import HfApiEngine, MessageRole
from .prompts import (
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
@ -327,7 +327,7 @@ class Agent:
def __init__(
self,
tools: Union[List[Tool], Toolbox],
llm_engine: Callable = HfEngine(),
llm_engine: Callable = HfApiEngine(),
system_prompt=DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template=None,
additional_args={},
@ -532,7 +532,7 @@ class CodeAgent(Agent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfEngine(),
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
@ -655,7 +655,7 @@ class ReactAgent(Agent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfEngine(),
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
@ -886,7 +886,7 @@ class ReactJsonAgent(ReactAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfEngine(),
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,
@ -992,7 +992,7 @@ class ReactCodeAgent(ReactAgent):
def __init__(
self,
tools: List[Tool],
llm_engine: Callable = HfEngine(),
llm_engine: Callable = HfApiEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
grammar: Dict[str, str] = None,

View File

@ -20,6 +20,8 @@ from typing import Dict, List, Optional
from huggingface_hub import InferenceClient
from ..pipelines.base import Pipeline
class MessageRole(str, Enum):
USER = "user"
@ -65,7 +67,9 @@ llama_role_conversions = {
}
class HfEngine:
class HfApiEngine:
"""This engine leverages Hugging Face's Inference API service, either serverless or with a dedicated endpoint."""
def __init__(self, model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"):
self.model = model
self.client = InferenceClient(self.model, timeout=120)
@ -93,6 +97,36 @@ class HfEngine:
return response
class TransformersEngine:
"""This engine uses a pre-initialized local text-generation pipeline."""
def __init__(self, pipeline: Pipeline):
self.pipeline = pipeline
def __call__(
self, messages: List[Dict[str, str]], stop_sequences: Optional[List[str]] = None, grammar: Optional[str] = None
) -> str:
# Get clean message list
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
# Get LLM output
output = self.pipeline(
messages,
stop_strings=stop_sequences,
max_length=1500,
tokenizer=self.pipeline.tokenizer,
)
response = output[0]["generated_text"][-1]["content"]
# Remove stop sequences from LLM output
if stop_sequences is not None:
for stop_seq in stop_sequences:
if response[-len(stop_seq) :] == stop_seq:
response = response[: -len(stop_seq)]
return response
DEFAULT_JSONAGENT_REGEX_GRAMMAR = {
"type": "regex",
"value": 'Thought: .+?\\nAction:\\n\\{\\n\\s{4}"action":\\s"[^"\\n]+",\\n\\s{4}"action_input":\\s"[^"\\n]+"\\n\\}\\n<end_action>',