Add stream messages from agent run for gradio chatbot (#32142)

* Add stream_to_gradio method for running agent in gradio demo
Aymeric Roucher 2024-07-29 20:12:44 +02:00 committed by GitHub
parent 811a9caa21
commit a24a9a66f4
5 changed files with 134 additions and 0 deletions

View File

@@ -509,3 +509,54 @@ agent = ReactCodeAgent(tools=[search_tool])
agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
```

## Gradio interface
You can leverage `gradio.Chatbot` to display your agent's thoughts using `stream_to_gradio`. Here is an example:
```py
import gradio as gr
from transformers import (
    load_tool,
    ReactCodeAgent,
    HfEngine,
    stream_to_gradio,
)

# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image")

llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")

# Initialize the agent with the image generation tool
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)


def interact_with_agent(task):
    messages = []
    messages.append(gr.ChatMessage(role="user", content=task))
    yield messages
    for msg in stream_to_gradio(agent, task):
        messages.append(msg)
        yield messages + [
            gr.ChatMessage(role="assistant", content="⏳ Task not finished yet!")
        ]
    yield messages


with gr.Blocks() as demo:
    text_input = gr.Textbox(lines=1, label="Chat Message", value="Make me a picture of the Statue of Liberty.")
    submit = gr.Button("Run illustrator agent!")
    chatbot = gr.Chatbot(
        label="Agent",
        type="messages",
        avatar_images=(
            None,
            "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
        ),
    )
    submit.click(interact_with_agent, [text_input], [chatbot])

if __name__ == "__main__":
    demo.launch()
```
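
If you just want to inspect what the agent streams without launching the interface, a minimal sketch along these lines (reusing the `agent` defined above; the task string is only an illustration) prints each streamed message as it arrives:

```py
from transformers import stream_to_gradio

# Iterate over the stream directly instead of wiring it into a Chatbot;
# each item is a gradio ChatMessage with a role, content, and optional metadata.
for message in stream_to_gradio(agent, "Make me a picture of the Statue of Liberty."):
    print(message.role, message.content)
```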

View File

@@ -72,6 +72,10 @@ We provide two types of agents, based on the main [`Agent`] class:
[[autodoc]] launch_gradio_demo

### stream_to_gradio

[[autodoc]] stream_to_gradio

### ToolCollection

[[autodoc]] ToolCollection

View File

@@ -67,6 +67,10 @@ _import_structure = {
        "ToolCollection",
        "launch_gradio_demo",
        "load_tool",
        "stream_to_gradio",
    ],
    "audio_utils": [],
    "benchmark": [],
@@ -4733,6 +4734,7 @@ if TYPE_CHECKING:
        ToolCollection,
        launch_gradio_demo,
        load_tool,
        stream_to_gradio,
    )
    from .configuration_utils import PretrainedConfig

View File

@@ -26,6 +26,7 @@ from ..utils import (
_import_structure = {
    "agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
    "llm_engine": ["HfEngine"],
    "monitoring": ["stream_to_gradio"],
    "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
}
@@ -45,6 +46,7 @@ else:
if TYPE_CHECKING:
    from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
    from .llm_engine import HfEngine
    from .monitoring import stream_to_gradio
    from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool

try:

View File

@@ -0,0 +1,75 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .agent_types import AgentAudio, AgentImage, AgentText
from .agents import ReactAgent


def pull_message(step_log: dict):
    try:
        from gradio import ChatMessage
    except ImportError:
        raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    if step_log.get("rationale"):
        yield ChatMessage(role="assistant", content=step_log["rationale"])
    if step_log.get("tool_call"):
        used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
        content = step_log["tool_call"]["tool_arguments"]
        if used_code:
            content = f"```py\n{content}\n```"
        yield ChatMessage(
            role="assistant",
            metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
            content=content,
        )
    if step_log.get("observation"):
        yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
    if step_log.get("error"):
        yield ChatMessage(
            role="assistant",
            content=str(step_log["error"]),
            metadata={"title": "💥 Error"},
        )


def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
    try:
        from gradio import ChatMessage
    except ImportError:
        raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    for step_log in agent.run(task, stream=True, **kwargs):
        if isinstance(step_log, dict):
            for message in pull_message(step_log):
                yield message

    # The last value yielded by agent.run is the final answer, handled once the loop ends
    if isinstance(step_log, AgentText):
        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
    elif isinstance(step_log, AgentImage):
        yield ChatMessage(
            role="assistant",
            content={"path": step_log.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(step_log, AgentAudio):
        yield ChatMessage(
            role="assistant",
            content={"path": step_log.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield ChatMessage(role="assistant", content=step_log)
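
As a quick sanity check of `pull_message` outside a real agent run, you can feed it a hand-built step log (gradio must be installed, since it constructs `gradio.ChatMessage` objects). The dictionary below is only an illustration of the keys the function inspects, not actual agent output:

```py
from transformers.agents.monitoring import pull_message

# Hypothetical step log, for illustration only: these are the keys pull_message looks for.
fake_step_log = {
    "rationale": "I should call the image generation tool.",
    "tool_call": {"tool_name": "code interpreter", "tool_arguments": "image_generator(prompt='Statue of Liberty')"},
    "observation": "Image saved to /tmp/statue.png",
}

for chat_message in pull_message(fake_step_log):
    print(chat_message.role, "->", chat_message.content)
```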