Add support to declare imports for code agent (#31355)

* Support import declaration in Code Agent
This commit is contained in:
Jason (Siyu) Zhu 2024-06-12 00:32:28 -07:00 committed by GitHub
parent 35a6d9d648
commit a2ede66674
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 91 additions and 44 deletions

View File

@ -17,7 +17,7 @@
import json import json
import logging import logging
import re import re
from typing import Any, Callable, Dict, List, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .. import is_torch_available from .. import is_torch_available
from ..utils import logging as transformers_logging from ..utils import logging as transformers_logging
@ -256,15 +256,6 @@ class Toolbox:
return toolbox_description return toolbox_description
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
if "<<tool_names>>" in prompt:
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
return prompt
class AgentError(Exception): class AgentError(Exception):
"""Base class for other agent-related exceptions""" """Base class for other agent-related exceptions"""
@ -297,6 +288,21 @@ class AgentGenerationError(AgentError):
pass pass
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
if "<<tool_names>>" in prompt:
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
return prompt
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
if "<<authorized_imports>>" not in prompt_template:
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
class Agent: class Agent:
def __init__( def __init__(
self, self,
@ -359,8 +365,14 @@ class Agent:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy() self.state = kwargs.copy()
self.system_prompt = format_prompt_with_tools( self.system_prompt = format_prompt_with_tools(
self._toolbox, self.system_prompt_template, self.tool_description_template self._toolbox,
self.system_prompt_template,
self.tool_description_template,
) )
if hasattr(self, "authorized_imports"):
self.system_prompt = format_prompt_with_imports(
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
)
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}] self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
self.logger.warn("======== New task ========") self.logger.warn("======== New task ========")
self.logger.log(33, self.task) self.logger.log(33, self.task)
@ -496,7 +508,7 @@ class CodeAgent(Agent):
llm_engine: Callable = HfEngine(), llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT, system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: List[str] = [], additional_authorized_imports: Optional[List[str]] = None,
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
@ -515,7 +527,9 @@ class CodeAgent(Agent):
) )
self.python_evaluator = evaluate_python_code self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
def parse_code_blob(self, result: str) -> str: def parse_code_blob(self, result: str) -> str:
""" """
@ -562,7 +576,13 @@ class CodeAgent(Agent):
return llm_output return llm_output
# Parse # Parse
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") try:
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
)
code_action = llm_output
try: try:
code_action = self.parse_code_blob(code_action) code_action = self.parse_code_blob(code_action)
@ -579,7 +599,7 @@ class CodeAgent(Agent):
code_action, code_action,
available_tools, available_tools,
state=self.state, state=self.state,
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, authorized_imports=self.authorized_imports,
) )
self.logger.info(self.state["print_outputs"]) self.logger.info(self.state["print_outputs"])
return output return output
@ -639,17 +659,12 @@ class ReactAgent(Agent):
def run(self, task: str, stream: bool = False, **kwargs): def run(self, task: str, stream: bool = False, **kwargs):
""" """
Runs the agent for the given task. Runs the agent for the given task.
Args: Args:
task (`str`): The task to perform task (`str`): The task to perform
Example: Example:
```py ```py
from transformers.agents import ReactJsonAgent, PythonInterpreterTool from transformers.agents import ReactCodeAgent
agent = ReactCodeAgent(tools=[])
python_interpreter = PythonInterpreterTool()
agent = ReactJsonAgent(tools=[python_interpreter])
agent.run("What is the result of 2 power 3.7384?") agent.run("What is the result of 2 power 3.7384?")
``` ```
""" """
@ -820,7 +835,7 @@ class ReactCodeAgent(ReactAgent):
llm_engine: Callable = HfEngine(), llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: List[str] = [], additional_authorized_imports: Optional[List[str]] = None,
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
@ -839,7 +854,9 @@ class ReactCodeAgent(ReactAgent):
) )
self.python_evaluator = evaluate_python_code self.python_evaluator = evaluate_python_code
self.additional_authorized_imports = additional_authorized_imports self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
def step(self): def step(self):
""" """
@ -871,7 +888,11 @@ class ReactCodeAgent(ReactAgent):
# Parse # Parse
self.logger.debug("===== Extracting action =====") self.logger.debug("===== Extracting action =====")
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") try:
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
except Exception as e:
self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
rationale, raw_code_action = llm_output, llm_output
try: try:
code_action = parse_code_blob(raw_code_action) code_action = parse_code_blob(raw_code_action)
@ -890,7 +911,7 @@ class ReactCodeAgent(ReactAgent):
code_action, code_action,
available_tools, available_tools,
state=self.state, state=self.state,
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, authorized_imports=self.authorized_imports,
) )
information = self.state["print_outputs"] information = self.state["print_outputs"]
self.logger.warning("Print outputs:") self.logger.warning("Print outputs:")

View File

@ -125,12 +125,13 @@ def setup_default_tools(logger):
for task_name, tool_class_name in TASK_MAPPING.items(): for task_name, tool_class_name in TASK_MAPPING.items():
tool_class = getattr(tools_module, tool_class_name) tool_class = getattr(tools_module, tool_class_name)
tool_instance = tool_class()
default_tools[tool_class.name] = PreTool( default_tools[tool_class.name] = PreTool(
name=tool_class.name, name=tool_instance.name,
inputs=tool_class.inputs, inputs=tool_instance.inputs,
output_type=tool_class.output_type, output_type=tool_instance.output_type,
task=task_name, task=task_name,
description=tool_class.description, description=tool_instance.description,
repo_id=None, repo_id=None,
) )
@ -141,18 +142,25 @@ class PythonInterpreterTool(Tool):
name = "python_interpreter" name = "python_interpreter"
description = "This is a tool that evaluates python code. It can be used to perform calculations." description = "This is a tool that evaluates python code. It can be used to perform calculations."
inputs = {
"code": {
"type": "text",
"description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {LIST_SAFE_MODULES}."
),
}
}
output_type = "text" output_type = "text"
available_tools = BASE_PYTHON_TOOLS.copy() available_tools = BASE_PYTHON_TOOLS.copy()
def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None:
authorized_imports = list(set(LIST_SAFE_MODULES))
else:
authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
self.inputs = {
"code": {
"type": "text",
"description": (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
),
}
}
super().__init__(*args, **kwargs)
def forward(self, code): def forward(self, code):
output = str(evaluate_python_code(code, tools=self.available_tools)) output = str(evaluate_python_code(code, tools=self.available_tools))
return output return output

View File

@ -52,6 +52,7 @@ DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is t
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns. To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python. You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so. Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
Be sure to provide a 'Code:' token, else the system will be stuck in a loop. Be sure to provide a 'Code:' token, else the system will be stuck in a loop.
Tools: Tools:
@ -263,7 +264,7 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can. DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code. To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences. To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use. At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
@ -356,6 +357,7 @@ Here are the rules you should always follow to solve your task:
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
""" """

View File

@ -141,15 +141,21 @@ Action:
def test_init_agent_with_different_toolsets(self): def test_init_agent_with_different_toolsets(self):
toolset_1 = [] toolset_1 = []
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 1 # contains only final_answer tool assert (
len(agent.toolbox.tools) == 1
) # when no tools are provided, only the final_answer tool is added by default
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 2 # added final_answer tool assert (
len(agent.toolbox.tools) == 2
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
toolset_3 = Toolbox(toolset_2) toolset_3 = Toolbox(toolset_2)
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 2 # added final_answer tool assert (
len(agent.toolbox.tools) == 2
) # same as previous one, where toolset_3 is an instantiation of previous one
# check that add_base_tools will not interfere with existing tools # check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e: with pytest.raises(KeyError) as e:

View File

@ -32,9 +32,19 @@ def add_two(x):
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self): def setUp(self):
self.tool = load_tool("python_interpreter") self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
self.tool.setup() self.tool.setup()
def test_exact_match_input_spec(self):
inputs_spec = self.tool.inputs
expected_description = (
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
"else you will get an error. This code can only import the following python libraries: "
"['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', "
"'random', 're']."
)
self.assertEqual(inputs_spec["code"]["description"], expected_description)
def test_exact_match_arg(self): def test_exact_match_arg(self):
result = self.tool("(2 / 2) * 4") result = self.tool("(2 / 2) * 4")
self.assertEqual(result, "4.0") self.assertEqual(result, "4.0")