Add support to declare imports for code agent (#31355)

* Support import declaration in Code Agent
This commit is contained in:
Jason (Siyu) Zhu 2024-06-12 00:32:28 -07:00 committed by GitHub
parent 35a6d9d648
commit a2ede66674
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 91 additions and 44 deletions

View File

@ -17,7 +17,7 @@
import json
import logging
import re
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .. import is_torch_available
from ..utils import logging as transformers_logging
@ -256,15 +256,6 @@ class Toolbox:
return toolbox_description
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
if "<<tool_names>>" in prompt:
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
return prompt
class AgentError(Exception):
"""Base class for other agent-related exceptions"""
@ -297,6 +288,21 @@ class AgentGenerationError(AgentError):
pass
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
if "<<tool_names>>" in prompt:
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
return prompt
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
if "<<authorized_imports>>" not in prompt_template:
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
class Agent:
def __init__(
self,
@ -359,8 +365,14 @@ class Agent:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.system_prompt = format_prompt_with_tools(
self._toolbox, self.system_prompt_template, self.tool_description_template
self._toolbox,
self.system_prompt_template,
self.tool_description_template,
)
if hasattr(self, "authorized_imports"):
self.system_prompt = format_prompt_with_imports(
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
)
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
self.logger.warn("======== New task ========")
self.logger.log(33, self.task)
@ -496,7 +508,7 @@ class CodeAgent(Agent):
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: List[str] = [],
additional_authorized_imports: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
@ -515,7 +527,9 @@ class CodeAgent(Agent):
)
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
def parse_code_blob(self, result: str) -> str:
"""
@ -562,7 +576,13 @@ class CodeAgent(Agent):
return llm_output
# Parse
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
try:
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
)
code_action = llm_output
try:
code_action = self.parse_code_blob(code_action)
@ -579,7 +599,7 @@ class CodeAgent(Agent):
code_action,
available_tools,
state=self.state,
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
authorized_imports=self.authorized_imports,
)
self.logger.info(self.state["print_outputs"])
return output
@ -639,17 +659,12 @@ class ReactAgent(Agent):
def run(self, task: str, stream: bool = False, **kwargs):
"""
Runs the agent for the given task.
Args:
task (`str`): The task to perform
Example:
```py
from transformers.agents import ReactJsonAgent, PythonInterpreterTool
python_interpreter = PythonInterpreterTool()
agent = ReactJsonAgent(tools=[python_interpreter])
from transformers.agents import ReactCodeAgent
agent = ReactCodeAgent(tools=[])
agent.run("What is the result of 2 power 3.7384?")
```
"""
@ -820,7 +835,7 @@ class ReactCodeAgent(ReactAgent):
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: List[str] = [],
additional_authorized_imports: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
@ -839,7 +854,9 @@ class ReactCodeAgent(ReactAgent):
)
self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
def step(self):
"""
@ -871,7 +888,11 @@ class ReactCodeAgent(ReactAgent):
# Parse
self.logger.debug("===== Extracting action =====")
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
try:
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
rationale, raw_code_action = llm_output, llm_output
try:
code_action = parse_code_blob(raw_code_action)
@ -890,7 +911,7 @@ class ReactCodeAgent(ReactAgent):
code_action,
available_tools,
state=self.state,
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
authorized_imports=self.authorized_imports,
)
information = self.state["print_outputs"]
self.logger.warning("Print outputs:")

View File

@ -125,12 +125,13 @@ def setup_default_tools(logger):
for task_name, tool_class_name in TASK_MAPPING.items():
tool_class = getattr(tools_module, tool_class_name)
tool_instance = tool_class()
default_tools[tool_class.name] = PreTool(
name=tool_class.name,
inputs=tool_class.inputs,
output_type=tool_class.output_type,
name=tool_instance.name,
inputs=tool_instance.inputs,
output_type=tool_instance.output_type,
task=task_name,
description=tool_class.description,
description=tool_instance.description,
repo_id=None,
)
@ -141,18 +142,25 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations."
inputs = {
"code": {
"type": "text",
"description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {LIST_SAFE_MODULES}."
),
}
}
output_type = "text"
available_tools = BASE_PYTHON_TOOLS.copy()
def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None:
authorized_imports = list(set(LIST_SAFE_MODULES))
else:
authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
self.inputs = {
"code": {
"type": "text",
"description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
),
}
}
super().__init__(*args, **kwargs)
def forward(self, code):
output = str(evaluate_python_code(code, tools=self.available_tools))
return output

View File

@ -52,6 +52,7 @@ DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is t
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
Be sure to provide a 'Code:' token, else the system will be stuck in a loop.
Tools:
@ -263,7 +264,7 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code.
To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
@ -356,6 +357,7 @@ Here are the rules you should always follow to solve your task:
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
"""

View File

@ -141,15 +141,21 @@ Action:
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 1 # contains only final_answer tool
assert (
len(agent.toolbox.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 2 # added final_answer tool
assert (
len(agent.toolbox.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
toolset_3 = Toolbox(toolset_2)
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 2 # added final_answer tool
assert (
len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one
# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:

View File

@ -32,9 +32,19 @@ def add_two(x):
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("python_interpreter")
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
self.tool.setup()
def test_exact_match_input_spec(self):
inputs_spec = self.tool.inputs
expected_description = (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
"else you will get an error. This code can only import the following python libraries: "
"['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', "
"'random', 're']."
)
self.assertEqual(inputs_spec["code"]["description"], expected_description)
def test_exact_match_arg(self):
result = self.tool("(2 / 2) * 4")
self.assertEqual(result, "4.0")