mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Add support to declare imports for code agent (#31355)
* Support import declaration in Code Agent
This commit is contained in:
parent
35a6d9d648
commit
a2ede66674
@ -17,7 +17,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from .. import is_torch_available
|
from .. import is_torch_available
|
||||||
from ..utils import logging as transformers_logging
|
from ..utils import logging as transformers_logging
|
||||||
@ -256,15 +256,6 @@ class Toolbox:
|
|||||||
return toolbox_description
|
return toolbox_description
|
||||||
|
|
||||||
|
|
||||||
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
|
|
||||||
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
|
||||||
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
|
|
||||||
if "<<tool_names>>" in prompt:
|
|
||||||
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
|
|
||||||
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
class AgentError(Exception):
|
class AgentError(Exception):
|
||||||
"""Base class for other agent-related exceptions"""
|
"""Base class for other agent-related exceptions"""
|
||||||
|
|
||||||
@ -297,6 +288,21 @@ class AgentGenerationError(AgentError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
|
||||||
|
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
||||||
|
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
|
||||||
|
if "<<tool_names>>" in prompt:
|
||||||
|
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
|
||||||
|
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def format_prompt_with_imports(prompt_template: str, authorized_imports: List[str]) -> str:
|
||||||
|
if "<<authorized_imports>>" not in prompt_template:
|
||||||
|
raise AgentError("Tag '<<authorized_imports>>' should be provided in the prompt.")
|
||||||
|
return prompt_template.replace("<<authorized_imports>>", str(authorized_imports))
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -359,8 +365,14 @@ class Agent:
|
|||||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||||
self.state = kwargs.copy()
|
self.state = kwargs.copy()
|
||||||
self.system_prompt = format_prompt_with_tools(
|
self.system_prompt = format_prompt_with_tools(
|
||||||
self._toolbox, self.system_prompt_template, self.tool_description_template
|
self._toolbox,
|
||||||
|
self.system_prompt_template,
|
||||||
|
self.tool_description_template,
|
||||||
)
|
)
|
||||||
|
if hasattr(self, "authorized_imports"):
|
||||||
|
self.system_prompt = format_prompt_with_imports(
|
||||||
|
self.system_prompt, list(set(LIST_SAFE_MODULES) | set(self.authorized_imports))
|
||||||
|
)
|
||||||
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
|
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
|
||||||
self.logger.warn("======== New task ========")
|
self.logger.warn("======== New task ========")
|
||||||
self.logger.log(33, self.task)
|
self.logger.log(33, self.task)
|
||||||
@ -496,7 +508,7 @@ class CodeAgent(Agent):
|
|||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
additional_authorized_imports: List[str] = [],
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -515,7 +527,9 @@ class CodeAgent(Agent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.python_evaluator = evaluate_python_code
|
self.python_evaluator = evaluate_python_code
|
||||||
self.additional_authorized_imports = additional_authorized_imports
|
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||||
|
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
||||||
|
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
||||||
|
|
||||||
def parse_code_blob(self, result: str) -> str:
|
def parse_code_blob(self, result: str) -> str:
|
||||||
"""
|
"""
|
||||||
@ -562,7 +576,13 @@ class CodeAgent(Agent):
|
|||||||
return llm_output
|
return llm_output
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
try:
|
||||||
|
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug(
|
||||||
|
f"Error in extracting action, trying to parse the whole output as code. Error trace: {e}"
|
||||||
|
)
|
||||||
|
code_action = llm_output
|
||||||
|
|
||||||
try:
|
try:
|
||||||
code_action = self.parse_code_blob(code_action)
|
code_action = self.parse_code_blob(code_action)
|
||||||
@ -579,7 +599,7 @@ class CodeAgent(Agent):
|
|||||||
code_action,
|
code_action,
|
||||||
available_tools,
|
available_tools,
|
||||||
state=self.state,
|
state=self.state,
|
||||||
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)
|
||||||
self.logger.info(self.state["print_outputs"])
|
self.logger.info(self.state["print_outputs"])
|
||||||
return output
|
return output
|
||||||
@ -639,17 +659,12 @@ class ReactAgent(Agent):
|
|||||||
def run(self, task: str, stream: bool = False, **kwargs):
|
def run(self, task: str, stream: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Runs the agent for the given task.
|
Runs the agent for the given task.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task (`str`): The task to perform
|
task (`str`): The task to perform
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
from transformers.agents import ReactJsonAgent, PythonInterpreterTool
|
from transformers.agents import ReactCodeAgent
|
||||||
|
agent = ReactCodeAgent(tools=[])
|
||||||
python_interpreter = PythonInterpreterTool()
|
|
||||||
agent = ReactJsonAgent(tools=[python_interpreter])
|
|
||||||
agent.run("What is the result of 2 power 3.7384?")
|
agent.run("What is the result of 2 power 3.7384?")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@ -820,7 +835,7 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
additional_authorized_imports: List[str] = [],
|
additional_authorized_imports: Optional[List[str]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -839,7 +854,9 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.python_evaluator = evaluate_python_code
|
self.python_evaluator = evaluate_python_code
|
||||||
self.additional_authorized_imports = additional_authorized_imports
|
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||||
|
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
||||||
|
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
"""
|
"""
|
||||||
@ -871,7 +888,11 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
self.logger.debug("===== Extracting action =====")
|
self.logger.debug("===== Extracting action =====")
|
||||||
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
try:
|
||||||
|
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug(f"Error in extracting action, trying to parse the whole output. Error trace: {e}")
|
||||||
|
rationale, raw_code_action = llm_output, llm_output
|
||||||
|
|
||||||
try:
|
try:
|
||||||
code_action = parse_code_blob(raw_code_action)
|
code_action = parse_code_blob(raw_code_action)
|
||||||
@ -890,7 +911,7 @@ class ReactCodeAgent(ReactAgent):
|
|||||||
code_action,
|
code_action,
|
||||||
available_tools,
|
available_tools,
|
||||||
state=self.state,
|
state=self.state,
|
||||||
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
|
authorized_imports=self.authorized_imports,
|
||||||
)
|
)
|
||||||
information = self.state["print_outputs"]
|
information = self.state["print_outputs"]
|
||||||
self.logger.warning("Print outputs:")
|
self.logger.warning("Print outputs:")
|
||||||
|
@ -125,12 +125,13 @@ def setup_default_tools(logger):
|
|||||||
|
|
||||||
for task_name, tool_class_name in TASK_MAPPING.items():
|
for task_name, tool_class_name in TASK_MAPPING.items():
|
||||||
tool_class = getattr(tools_module, tool_class_name)
|
tool_class = getattr(tools_module, tool_class_name)
|
||||||
|
tool_instance = tool_class()
|
||||||
default_tools[tool_class.name] = PreTool(
|
default_tools[tool_class.name] = PreTool(
|
||||||
name=tool_class.name,
|
name=tool_instance.name,
|
||||||
inputs=tool_class.inputs,
|
inputs=tool_instance.inputs,
|
||||||
output_type=tool_class.output_type,
|
output_type=tool_instance.output_type,
|
||||||
task=task_name,
|
task=task_name,
|
||||||
description=tool_class.description,
|
description=tool_instance.description,
|
||||||
repo_id=None,
|
repo_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -141,18 +142,25 @@ class PythonInterpreterTool(Tool):
|
|||||||
name = "python_interpreter"
|
name = "python_interpreter"
|
||||||
description = "This is a tool that evaluates python code. It can be used to perform calculations."
|
description = "This is a tool that evaluates python code. It can be used to perform calculations."
|
||||||
|
|
||||||
inputs = {
|
|
||||||
"code": {
|
|
||||||
"type": "text",
|
|
||||||
"description": (
|
|
||||||
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
|
|
||||||
f"else you will get an error. This code can only import the following python libraries: {LIST_SAFE_MODULES}."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
output_type = "text"
|
output_type = "text"
|
||||||
available_tools = BASE_PYTHON_TOOLS.copy()
|
available_tools = BASE_PYTHON_TOOLS.copy()
|
||||||
|
|
||||||
|
def __init__(self, *args, authorized_imports=None, **kwargs):
|
||||||
|
if authorized_imports is None:
|
||||||
|
authorized_imports = list(set(LIST_SAFE_MODULES))
|
||||||
|
else:
|
||||||
|
authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
|
||||||
|
self.inputs = {
|
||||||
|
"code": {
|
||||||
|
"type": "text",
|
||||||
|
"description": (
|
||||||
|
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
|
||||||
|
f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def forward(self, code):
|
def forward(self, code):
|
||||||
output = str(evaluate_python_code(code, tools=self.available_tools))
|
output = str(evaluate_python_code(code, tools=self.available_tools))
|
||||||
return output
|
return output
|
||||||
|
@ -52,6 +52,7 @@ DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is t
|
|||||||
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
|
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
|
||||||
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
|
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
|
||||||
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
|
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
|
||||||
|
You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
||||||
Be sure to provide a 'Code:' token, else the system will be stuck in a loop.
|
Be sure to provide a 'Code:' token, else the system will be stuck in a loop.
|
||||||
|
|
||||||
Tools:
|
Tools:
|
||||||
@ -263,7 +264,7 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
|
|||||||
|
|
||||||
|
|
||||||
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
|
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
|
||||||
To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code.
|
To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
|
||||||
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
|
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
|
||||||
|
|
||||||
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
|
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
|
||||||
@ -356,6 +357,7 @@ Here are the rules you should always follow to solve your task:
|
|||||||
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
||||||
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
||||||
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||||
|
7. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
||||||
|
|
||||||
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
||||||
"""
|
"""
|
||||||
|
@ -141,15 +141,21 @@ Action:
|
|||||||
def test_init_agent_with_different_toolsets(self):
|
def test_init_agent_with_different_toolsets(self):
|
||||||
toolset_1 = []
|
toolset_1 = []
|
||||||
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
||||||
assert len(agent.toolbox.tools) == 1 # contains only final_answer tool
|
assert (
|
||||||
|
len(agent.toolbox.tools) == 1
|
||||||
|
) # when no tools are provided, only the final_answer tool is added by default
|
||||||
|
|
||||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||||
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
||||||
assert len(agent.toolbox.tools) == 2 # added final_answer tool
|
assert (
|
||||||
|
len(agent.toolbox.tools) == 2
|
||||||
|
) # deduplication of tools, so only one python_interpreter tool is added in addition to final_answer
|
||||||
|
|
||||||
toolset_3 = Toolbox(toolset_2)
|
toolset_3 = Toolbox(toolset_2)
|
||||||
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
||||||
assert len(agent.toolbox.tools) == 2 # added final_answer tool
|
assert (
|
||||||
|
len(agent.toolbox.tools) == 2
|
||||||
|
) # same as previous one, where toolset_3 is an instantiation of previous one
|
||||||
|
|
||||||
# check that add_base_tools will not interfere with existing tools
|
# check that add_base_tools will not interfere with existing tools
|
||||||
with pytest.raises(KeyError) as e:
|
with pytest.raises(KeyError) as e:
|
||||||
|
@ -32,9 +32,19 @@ def add_two(x):
|
|||||||
|
|
||||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tool = load_tool("python_interpreter")
|
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
|
||||||
self.tool.setup()
|
self.tool.setup()
|
||||||
|
|
||||||
|
def test_exact_match_input_spec(self):
|
||||||
|
inputs_spec = self.tool.inputs
|
||||||
|
expected_description = (
|
||||||
|
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
|
||||||
|
"else you will get an error. This code can only import the following python libraries: "
|
||||||
|
"['math', 'statistics', 'time', 'itertools', 'stat', 'unicodedata', 'sqlite3', 'queue', 'collections', "
|
||||||
|
"'random', 're']."
|
||||||
|
)
|
||||||
|
self.assertEqual(inputs_spec["code"]["description"], expected_description)
|
||||||
|
|
||||||
def test_exact_match_arg(self):
|
def test_exact_match_arg(self):
|
||||||
result = self.tool("(2 / 2) * 4")
|
result = self.tool("(2 / 2) * 4")
|
||||||
self.assertEqual(result, "4.0")
|
self.assertEqual(result, "4.0")
|
||||||
|
Loading…
Reference in New Issue
Block a user