mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Add stream messages from agent run for gradio chatbot (#32142)
* Add stream_to_gradio method for running agent in gradio demo
This commit is contained in:
parent
811a9caa21
commit
a24a9a66f4
@ -509,3 +509,54 @@ agent = ReactCodeAgent(tools=[search_tool])
|
|||||||
|
|
||||||
agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
|
agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Gradio interface
|
||||||
|
|
||||||
|
You can leverage `gradio.Chatbot`to display your agent's thoughts using `stream_to_gradio`, here is an example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
import gradio as gr
|
||||||
|
from transformers import (
|
||||||
|
load_tool,
|
||||||
|
ReactCodeAgent,
|
||||||
|
HfEngine,
|
||||||
|
stream_to_gradio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Import tool from Hub
|
||||||
|
image_generation_tool = load_tool("m-ric/text-to-image")
|
||||||
|
|
||||||
|
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
|
||||||
|
|
||||||
|
# Initialize the agent with the image generation tool
|
||||||
|
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
|
||||||
|
|
||||||
|
|
||||||
|
def interact_with_agent(task):
|
||||||
|
messages = []
|
||||||
|
messages.append(gr.ChatMessage(role="user", content=task))
|
||||||
|
yield messages
|
||||||
|
for msg in stream_to_gradio(agent, task):
|
||||||
|
messages.append(msg)
|
||||||
|
yield messages + [
|
||||||
|
gr.ChatMessage(role="assistant", content="⏳ Task not finished yet!")
|
||||||
|
]
|
||||||
|
yield messages
|
||||||
|
|
||||||
|
|
||||||
|
with gr.Blocks() as demo:
|
||||||
|
text_input = gr.Textbox(lines=1, label="Chat Message", value="Make me a picture of the Statue of Liberty.")
|
||||||
|
submit = gr.Button("Run illustrator agent!")
|
||||||
|
chatbot = gr.Chatbot(
|
||||||
|
label="Agent",
|
||||||
|
type="messages",
|
||||||
|
avatar_images=(
|
||||||
|
None,
|
||||||
|
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
submit.click(interact_with_agent, [text_input], [chatbot])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo.launch()
|
||||||
|
```
|
@ -72,6 +72,10 @@ We provide two types of agents, based on the main [`Agent`] class:
|
|||||||
|
|
||||||
[[autodoc]] launch_gradio_demo
|
[[autodoc]] launch_gradio_demo
|
||||||
|
|
||||||
|
### stream_to_gradio
|
||||||
|
|
||||||
|
[[autodoc]] stream_to_gradio
|
||||||
|
|
||||||
### ToolCollection
|
### ToolCollection
|
||||||
|
|
||||||
[[autodoc]] ToolCollection
|
[[autodoc]] ToolCollection
|
||||||
|
@ -67,6 +67,7 @@ _import_structure = {
|
|||||||
"ToolCollection",
|
"ToolCollection",
|
||||||
"launch_gradio_demo",
|
"launch_gradio_demo",
|
||||||
"load_tool",
|
"load_tool",
|
||||||
|
"stream_to_gradio",
|
||||||
],
|
],
|
||||||
"audio_utils": [],
|
"audio_utils": [],
|
||||||
"benchmark": [],
|
"benchmark": [],
|
||||||
@ -4733,6 +4734,7 @@ if TYPE_CHECKING:
|
|||||||
ToolCollection,
|
ToolCollection,
|
||||||
launch_gradio_demo,
|
launch_gradio_demo,
|
||||||
load_tool,
|
load_tool,
|
||||||
|
stream_to_gradio,
|
||||||
)
|
)
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ from ..utils import (
|
|||||||
_import_structure = {
|
_import_structure = {
|
||||||
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
|
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
|
||||||
"llm_engine": ["HfEngine"],
|
"llm_engine": ["HfEngine"],
|
||||||
|
"monitoring": ["stream_to_gradio"],
|
||||||
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
|
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,6 +46,7 @@ else:
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||||
from .llm_engine import HfEngine
|
from .llm_engine import HfEngine
|
||||||
|
from .monitoring import stream_to_gradio
|
||||||
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
|
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
75
src/transformers/agents/monitoring.py
Normal file
75
src/transformers/agents/monitoring.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||||
|
from .agents import ReactAgent
|
||||||
|
|
||||||
|
|
||||||
|
def pull_message(step_log: dict):
|
||||||
|
try:
|
||||||
|
from gradio import ChatMessage
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
||||||
|
|
||||||
|
if step_log.get("rationale"):
|
||||||
|
yield ChatMessage(role="assistant", content=step_log["rationale"])
|
||||||
|
if step_log.get("tool_call"):
|
||||||
|
used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
|
||||||
|
content = step_log["tool_call"]["tool_arguments"]
|
||||||
|
if used_code:
|
||||||
|
content = f"```py\n{content}\n```"
|
||||||
|
yield ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
if step_log.get("observation"):
|
||||||
|
yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
|
||||||
|
if step_log.get("error"):
|
||||||
|
yield ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=str(step_log["error"]),
|
||||||
|
metadata={"title": "💥 Error"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
|
||||||
|
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from gradio import ChatMessage
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
||||||
|
|
||||||
|
for step_log in agent.run(task, stream=True, **kwargs):
|
||||||
|
if isinstance(step_log, dict):
|
||||||
|
for message in pull_message(step_log):
|
||||||
|
yield message
|
||||||
|
|
||||||
|
if isinstance(step_log, AgentText):
|
||||||
|
yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
|
||||||
|
elif isinstance(step_log, AgentImage):
|
||||||
|
yield ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content={"path": step_log.to_string(), "mime_type": "image/png"},
|
||||||
|
)
|
||||||
|
elif isinstance(step_log, AgentAudio):
|
||||||
|
yield ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content={"path": step_log.to_string(), "mime_type": "audio/wav"},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield ChatMessage(role="assistant", content=step_log)
|
Loading…
Reference in New Issue
Block a user