From a2ede66674d919fe507cc27db191f3314fd29948 Mon Sep 17 00:00:00 2001 From: "Jason (Siyu) Zhu" Date: Wed, 12 Jun 2024 00:32:28 -0700 Subject: [PATCH] Add support to declare imports for code agent (#31355) * Support import declaration in Code Agent --- src/transformers/agents/agents.py | 73 +++++++++++++++--------- src/transformers/agents/default_tools.py | 34 ++++++----- src/transformers/agents/prompts.py | 4 +- tests/agents/test_agents.py | 12 +++- tests/agents/test_python_interpreter.py | 12 +++- 5 files changed, 91 insertions(+), 44 deletions(-) diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py index ad0b9fecc3e..63a2c3889ba 100644 --- a/src/transformers/agents/agents.py +++ b/src/transformers/agents/agents.py @@ -17,7 +17,7 @@ import json import logging import re -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from .. import is_torch_available from ..utils import logging as transformers_logging @@ -256,15 +256,6 @@ class Toolbox: return toolbox_description -def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: - tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) - prompt = prompt_template.replace("<>", tool_descriptions) - if "<>" in prompt: - tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()] - prompt = prompt.replace("<>", ", ".join(tool_names)) - return prompt - - class AgentError(Exception): """Base class for other agent-related exceptions""" @@ -297,6 +288,21 @@ class AgentGenerationError(AgentError): pass +def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str: + tool_descriptions = toolbox.show_tool_descriptions(tool_description_template) + prompt = prompt_template.replace("<>", tool_descriptions) + if "<>" in prompt: + tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()] + prompt = prompt.replace("<>", ", ".join(tool_names)) + return prompt + + +def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str: + if "<>" not in prompt_template: + raise AgentError("Tag '<>' should be provided in the prompt.") + return prompt_template.replace("<>", str(authorized_imports)) + + class Agent: def __init__( self, @@ -359,8 +365,14 @@ class Agent: self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}." self.state = kwargs.copy() self.system_prompt = format_prompt_with_tools( - self._toolbox, self.system_prompt_template, self.tool_description_template + self._toolbox, + self.system_prompt_template, + self.tool_description_template, ) + if hasattr(self, "authorized_imports"): + self.system_prompt = format_prompt_with_imports( + self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports)) + ) self.logs = [{"system_prompt": self.system_prompt, "task": self.task}] self.logger.warn("======== New task ========") self.logger.log(33, self.task) @@ -496,7 +508,7 @@ class CodeAgent(Agent): llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, - additional_authorized_imports: List[str] = [], + additional_authorized_imports: Optional[List[str]] = None, **kwargs, ): super().__init__( @@ -515,7 +527,9 @@ class CodeAgent(Agent): ) self.python_evaluator = evaluate_python_code - self.additional_authorized_imports = additional_authorized_imports + self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] + self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) + self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) def parse_code_blob(self, result: str) -> str: """ @@ -562,7 +576,13 @@ class CodeAgent(Agent): return llm_output # Parse - _, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + try: + _, code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + except Exception as e: + self.logger.debug( + f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}" + ) + code_action = llm_output try: code_action = self.parse_code_blob(code_action) @@ -579,7 +599,7 @@ class CodeAgent(Agent): code_action, available_tools, state=self.state, - authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, + authorized_imports=self.authorized_imports, ) self.logger.info(self.state["print_outputs"]) return output @@ -639,17 +659,12 @@ class ReactAgent(Agent): def run(self, task: str, stream: bool = False, **kwargs): """ Runs the agent for the given task. - Args: task (`str`): The task to perform - Example: - ```py - from transformers.agents import ReactJsonAgent, PythonInterpreterTool - - python_interpreter = PythonInterpreterTool() - agent = ReactJsonAgent(tools=[python_interpreter]) + from transformers.agents import ReactCodeAgent + agent = ReactCodeAgent(tools=[]) agent.run("What is the result of 2 power 3.7384?") ``` """ @@ -820,7 +835,7 @@ class ReactCodeAgent(ReactAgent): llm_engine: Callable = HfEngine(), system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, - additional_authorized_imports: List[str] = [], + additional_authorized_imports: Optional[List[str]] = None, **kwargs, ): super().__init__( @@ -839,7 +854,9 @@ class ReactCodeAgent(ReactAgent): ) self.python_evaluator = evaluate_python_code - self.additional_authorized_imports = additional_authorized_imports + self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] + self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) + self.system_prompt = self.system_prompt.replace("<>", str(self.authorized_imports)) def step(self): """ @@ -871,7 +888,11 @@ class ReactCodeAgent(ReactAgent): # Parse self.logger.debug("===== Extracting action =====") - rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + try: + rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:") + except Exception as e: + self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}") + rationale, raw_code_action = llm_output, llm_output try: code_action = parse_code_blob(raw_code_action) @@ -890,7 +911,7 @@ class ReactCodeAgent(ReactAgent): code_action, available_tools, state=self.state, - authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports, + authorized_imports=self.authorized_imports, ) information = self.state["print_outputs"] self.logger.warning("Print outputs:") diff --git a/src/transformers/agents/default_tools.py b/src/transformers/agents/default_tools.py index 7187422dc06..9adf55289d0 100644 --- a/src/transformers/agents/default_tools.py +++ b/src/transformers/agents/default_tools.py @@ -125,12 +125,13 @@ def setup_default_tools(logger): for task_name, tool_class_name in TASK_MAPPING.items(): tool_class = getattr(tools_module, tool_class_name) + tool_instance = tool_class() default_tools[tool_class.name] = PreTool( - name=tool_class.name, - inputs=tool_class.inputs, - output_type=tool_class.output_type, + name=tool_instance.name, + inputs=tool_instance.inputs, + output_type=tool_instance.output_type, task=task_name, - description=tool_class.description, + description=tool_instance.description, repo_id=None, ) @@ -141,18 +142,25 @@ class PythonInterpreterTool(Tool): name = "python_interpreter" description = "This is a tool that evaluates python code. It can be used to perform calculations." - inputs = { - "code": { - "type": "text", - "description": ( - "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " - f"else you will get an error. This code can only import the following python libraries: {LIST_SAFE_MODULES}." - ), - } - } output_type = "text" available_tools = BASE_PYTHON_TOOLS.copy() + def __init__(self, *args, authorized_imports=None, **kwargs): + if authorized_imports is None: + authorized_imports = list(set(LIST_SAFE_MODULES)) + else: + authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) + self.inputs = { + "code": { + "type": "text", + "description": ( + "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " + f"else you will get an error. This code can only import the following python libraries: {authorized_imports}." + ), + } + } + super().__init__(*args, **kwargs) + def forward(self, code): output = str(evaluate_python_code(code, tools=self.available_tools)) return output diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py index 4e5ff997081..661df9bd24e 100644 --- a/src/transformers/agents/prompts.py +++ b/src/transformers/agents/prompts.py @@ -52,6 +52,7 @@ DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is t To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns. You should first explain which tool you will use to perform the task and for what reason, then write the code in Python. Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so. +You can use imports in your code, but only from the following list of modules: <> Be sure to provide a 'Code:' token, else the system will be stuck in a loop. Tools: @@ -263,7 +264,7 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000, DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can. -To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code. +To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code. To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences. At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use. @@ -356,6 +357,7 @@ Here are the rules you should always follow to solve your task: 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. +7. You can use imports in your code, but only from the following list of modules: <> Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. """ diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index 8dc535e63c7..79e55bf6523 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -141,15 +141,21 @@ Action: def test_init_agent_with_different_toolsets(self): toolset_1 = [] agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm) - assert len(agent.toolbox.tools) == 1 # contains only final_answer tool + assert ( + len(agent.toolbox.tools) == 1 + ) # when no tools are provided, only the final_answer tool is added by default toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()] agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm) - assert len(agent.toolbox.tools) == 2 # added final_answer tool + assert ( + len(agent.toolbox.tools) == 2 + ) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer toolset_3 = Toolbox(toolset_2) agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm) - assert len(agent.toolbox.tools) == 2 # added final_answer tool + assert ( + len(agent.toolbox.tools) == 2 + ) # same as previous one, where toolset_3 is an instantiation of previous one # check that add_base_tools will not interfere with existing tools with pytest.raises(KeyError) as e: diff --git a/tests/agents/test_python_interpreter.py b/tests/agents/test_python_interpreter.py index dbe6c90a9ea..51775e31e76 100644 --- a/tests/agents/test_python_interpreter.py +++ b/tests/agents/test_python_interpreter.py @@ -32,9 +32,19 @@ def add_two(x): class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): def setUp(self): - self.tool = load_tool("python_interpreter") + self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"]) self.tool.setup() + def test_exact_match_input_spec(self): + inputs_spec = self.tool.inputs + expected_description = ( + "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, " + "else you will get an error. This code can only import the following python libraries: " + "['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', " + "'random', 're']." + ) + self.assertEqual(inputs_spec["code"]["description"], expected_description) + def test_exact_match_arg(self): result = self.tool("(2 / 2) * 4") self.assertEqual(result, "4.0")