mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Add stream messages from agent run for gradio chatbot (#32142)
* Add stream_to_gradio method for running agent in gradio demo
This commit is contained in:
parent
811a9caa21
commit
a24a9a66f4
@ -509,3 +509,54 @@ agent = ReactCodeAgent(tools=[search_tool])
|
||||
|
||||
agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
|
||||
```
|
||||
|
||||
## Gradio interface
|
||||
|
||||
You can leverage `gradio.Chatbot`to display your agent's thoughts using `stream_to_gradio`, here is an example:
|
||||
|
||||
```py
|
||||
import gradio as gr
|
||||
from transformers import (
|
||||
load_tool,
|
||||
ReactCodeAgent,
|
||||
HfEngine,
|
||||
stream_to_gradio,
|
||||
)
|
||||
|
||||
# Import tool from Hub
|
||||
image_generation_tool = load_tool("m-ric/text-to-image")
|
||||
|
||||
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
|
||||
# Initialize the agent with the image generation tool
|
||||
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
|
||||
|
||||
|
||||
def interact_with_agent(task):
|
||||
messages = []
|
||||
messages.append(gr.ChatMessage(role="user", content=task))
|
||||
yield messages
|
||||
for msg in stream_to_gradio(agent, task):
|
||||
messages.append(msg)
|
||||
yield messages + [
|
||||
gr.ChatMessage(role="assistant", content="⏳ Task not finished yet!")
|
||||
]
|
||||
yield messages
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
text_input = gr.Textbox(lines=1, label="Chat Message", value="Make me a picture of the Statue of Liberty.")
|
||||
submit = gr.Button("Run illustrator agent!")
|
||||
chatbot = gr.Chatbot(
|
||||
label="Agent",
|
||||
type="messages",
|
||||
avatar_images=(
|
||||
None,
|
||||
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
|
||||
),
|
||||
)
|
||||
submit.click(interact_with_agent, [text_input], [chatbot])
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
```
|
@ -72,6 +72,10 @@ We provide two types of agents, based on the main [`Agent`] class:
|
||||
|
||||
[[autodoc]] launch_gradio_demo
|
||||
|
||||
### stream_to_gradio
|
||||
|
||||
[[autodoc]] stream_to_gradio
|
||||
|
||||
### ToolCollection
|
||||
|
||||
[[autodoc]] ToolCollection
|
||||
|
@ -67,6 +67,7 @@ _import_structure = {
|
||||
"ToolCollection",
|
||||
"launch_gradio_demo",
|
||||
"load_tool",
|
||||
"stream_to_gradio",
|
||||
],
|
||||
"audio_utils": [],
|
||||
"benchmark": [],
|
||||
@ -4733,6 +4734,7 @@ if TYPE_CHECKING:
|
||||
ToolCollection,
|
||||
launch_gradio_demo,
|
||||
load_tool,
|
||||
stream_to_gradio,
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
@ -26,6 +26,7 @@ from ..utils import (
|
||||
_import_structure = {
|
||||
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
|
||||
"llm_engine": ["HfEngine"],
|
||||
"monitoring": ["stream_to_gradio"],
|
||||
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
|
||||
}
|
||||
|
||||
@ -45,6 +46,7 @@ else:
|
||||
if TYPE_CHECKING:
|
||||
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||
from .llm_engine import HfEngine
|
||||
from .monitoring import stream_to_gradio
|
||||
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
|
||||
|
||||
try:
|
||||
|
75
src/transformers/agents/monitoring.py
Normal file
75
src/transformers/agents/monitoring.py
Normal file
@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||
from .agents import ReactAgent
|
||||
|
||||
|
||||
def pull_message(step_log: dict):
|
||||
try:
|
||||
from gradio import ChatMessage
|
||||
except ImportError:
|
||||
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
||||
|
||||
if step_log.get("rationale"):
|
||||
yield ChatMessage(role="assistant", content=step_log["rationale"])
|
||||
if step_log.get("tool_call"):
|
||||
used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
|
||||
content = step_log["tool_call"]["tool_arguments"]
|
||||
if used_code:
|
||||
content = f"```py\n{content}\n```"
|
||||
yield ChatMessage(
|
||||
role="assistant",
|
||||
metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
|
||||
content=content,
|
||||
)
|
||||
if step_log.get("observation"):
|
||||
yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
|
||||
if step_log.get("error"):
|
||||
yield ChatMessage(
|
||||
role="assistant",
|
||||
content=str(step_log["error"]),
|
||||
metadata={"title": "💥 Error"},
|
||||
)
|
||||
|
||||
|
||||
def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
|
||||
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
|
||||
|
||||
try:
|
||||
from gradio import ChatMessage
|
||||
except ImportError:
|
||||
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
|
||||
|
||||
for step_log in agent.run(task, stream=True, **kwargs):
|
||||
if isinstance(step_log, dict):
|
||||
for message in pull_message(step_log):
|
||||
yield message
|
||||
|
||||
if isinstance(step_log, AgentText):
|
||||
yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
|
||||
elif isinstance(step_log, AgentImage):
|
||||
yield ChatMessage(
|
||||
role="assistant",
|
||||
content={"path": step_log.to_string(), "mime_type": "image/png"},
|
||||
)
|
||||
elif isinstance(step_log, AgentAudio):
|
||||
yield ChatMessage(
|
||||
role="assistant",
|
||||
content={"path": step_log.to_string(), "mime_type": "audio/wav"},
|
||||
)
|
||||
else:
|
||||
yield ChatMessage(role="assistant", content=step_log)
|
Loading…
Reference in New Issue
Block a user