Agents planning (#31702)

* Allow planning for agents
This commit is contained in:
Aymeric Roucher 2024-07-22 10:49:57 +02:00 committed by GitHub
parent 0fdea8607d
commit b381880597
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 770 additions and 273 deletions

View File

@ -25,7 +25,19 @@ from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage, AgentText
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfEngine, MessageRole
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
from .prompts import (
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
PLAN_UPDATE_FINAL_PLAN_REDACTION,
SYSTEM_PROMPT_FACTS,
SYSTEM_PROMPT_FACTS_UPDATE,
SYSTEM_PROMPT_PLAN,
SYSTEM_PROMPT_PLAN_UPDATE,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN,
USER_PROMPT_PLAN_UPDATE,
)
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
@ -99,12 +111,19 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
def parse_code_blob(code_blob: str) -> str:
try:
pattern = r"```(?:py|python)?\n(.*?)```"
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL)
return match.group(1).strip()
except Exception as e:
raise ValueError(
f"The code blob you used is invalid: due to the following error: {e}. This means that the regex pattern {pattern} was not respected. Make sure to correct its formatting. Code blob was: {code_blob}"
f"""
The code blob you used is invalid: due to the following error: {e}
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
Code:
```py
# Your python code here
```<end_action>"""
)
@ -113,6 +132,8 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
tool_call = parse_json_blob(json_blob)
if "action" in tool_call and "action_input" in tool_call:
return tool_call["action"], tool_call["action_input"]
elif "action" in tool_call:
return tool_call["action"], None
else:
raise ValueError(
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
@ -208,7 +229,7 @@ class Toolbox:
The tool to add to the toolbox.
"""
if tool.name in self._tools:
raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.")
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
self._tools[tool.name] = tool
def remove_tool(self, tool_name: str):
@ -359,12 +380,8 @@ class Agent:
"""Get the toolbox currently available to the agent"""
return self._toolbox
def initialize_for_run(self, task: str, **kwargs):
def initialize_for_run(self):
self.token_count = 0
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.system_prompt = format_prompt_with_tools(
self._toolbox,
self.system_prompt_template,
@ -380,7 +397,7 @@ class Agent:
self.logger.debug("System prompt is as follows:")
self.logger.debug(self.system_prompt)
def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
"""
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
that can be used as input to the LLM.
@ -390,43 +407,51 @@ class Agent:
"role": MessageRole.USER,
"content": "Task: " + self.logs[0]["task"],
}
memory = [prompt_message, task_message]
if summary_mode:
memory = [task_message]
else:
memory = [prompt_message, task_message]
for i, step_log in enumerate(self.logs[1:]):
if "llm_output" in step_log:
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"}
if "llm_output" in step_log and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
memory.append(thought_message)
if "facts" in step_log:
thought_message = {
"role": MessageRole.ASSISTANT,
"content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
}
memory.append(thought_message)
if "error" in step_log:
message_content = (
"Error: "
+ str(step_log["error"])
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
elif "observation" in step_log:
message_content = f"Observation: {step_log['observation']}"
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
memory.append(tool_response_message)
if "plan" in step_log and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
memory.append(thought_message)
if "tool_call" in step_log and summary_mode:
tool_call_message = {
"role": MessageRole.ASSISTANT,
"content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
}
memory.append(tool_call_message)
if "task" in step_log:
tool_call_message = {
"role": MessageRole.USER,
"content": "New task:\n" + step_log["task"],
}
memory.append(tool_call_message)
if "error" in step_log or "observation" in step_log:
if "error" in step_log:
message_content = (
f"[OUTPUT OF STEP {i}] Error: "
+ str(step_log["error"])
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
elif "observation" in step_log:
message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}"
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
memory.append(tool_response_message)
if len(memory) % 3 == 0:
reminder_content = (
"Reminder: you are working towards solving the following task: " + self.logs[0]["task"]
)
reminder_content += "\nHere is a summary of your past tool calls and their results:"
for j in range(i + 1):
reminder_content += "\nStep " + str(j + 1)
if "tool_call" in self.logs[j]:
reminder_content += "\nTool call:" + str(self.logs[j]["tool_call"])
if self.memory_verbose:
if "observation" in self.logs[j]:
reminder_content += "\nObservation:" + str(self.logs[j]["observation"])
if "error" in self.logs[j]:
reminder_content += "\nError:" + str(self.logs[j]["error"])
memory.append(
{
"role": MessageRole.USER,
"content": reminder_content,
}
)
return memory
def get_succinct_logs(self):
@ -459,7 +484,7 @@ class Agent:
This method replaces arguments with the actual values from the state if they refer to state variables.
Args:
tool_name (`str`): Name of the Tool to execute (shoulde be one from self.toolbox).
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
if tool_name not in self.toolbox.tools:
@ -559,7 +584,11 @@ class CodeAgent(Agent):
agent.run("What is the result of 2 power 3.7384?")
```
"""
self.initialize_for_run(task, **kwargs)
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.initialize_for_run()
# Run LLM
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
@ -598,7 +627,8 @@ class CodeAgent(Agent):
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
output = self.python_evaluator(
code_action,
available_tools,
static_tools=available_tools,
custom_tools={},
state=self.state,
authorized_imports=self.authorized_imports,
)
@ -623,6 +653,7 @@ class ReactAgent(Agent):
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
planning_interval: Optional[int] = None,
**kwargs,
):
super().__init__(
@ -632,6 +663,7 @@ class ReactAgent(Agent):
tool_description_template=tool_description_template,
**kwargs,
)
self.planning_interval = planning_interval
def provide_final_answer(self, task) -> str:
"""
@ -655,11 +687,13 @@ class ReactAgent(Agent):
except Exception as e:
return f"Error in generating final llm output: {e}."
def run(self, task: str, stream: bool = False, **kwargs):
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
"""
Runs the agent for the given task.
Args:
task (`str`): The task to perform
Example:
```py
from transformers.agents import ReactCodeAgent
@ -667,14 +701,23 @@ class ReactAgent(Agent):
agent.run("What is the result of 2 power 3.7384?")
```
"""
if stream:
return self.stream_run(task, **kwargs)
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
if reset:
self.initialize_for_run()
else:
return self.direct_run(task, **kwargs)
def stream_run(self, task: str, **kwargs):
self.initialize_for_run(task, **kwargs)
self.logs.append({"task": task})
if stream:
return self.stream_run(task)
else:
return self.direct_run(task)
def stream_run(self, task: str):
"""
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
@ -700,13 +743,16 @@ class ReactAgent(Agent):
yield final_answer
def direct_run(self, task: str, **kwargs):
self.initialize_for_run(task, **kwargs)
def direct_run(self, task: str):
"""
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
@ -726,6 +772,96 @@ class ReactAgent(Agent):
return final_answer
def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
"""
Used periodically by the agent to plan the next steps to reach the objective.
Args:
task (`str`): The task to perform
is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
iteration (`int`): The number of the current step, used as an indication for the LLM.
"""
if is_first_step:
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
message_prompt_task = {
"role": MessageRole.USER,
"content": f"""Here is the task:
```
{task}
```
Now begin!""",
}
answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])
message_system_prompt_plan = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_PLAN}
message_user_prompt_plan = {
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
answer_facts=answer_facts,
),
}
answer_plan = self.llm_engine(
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
)
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
```
{answer_plan}
```"""
final_facts_redaction = f"""Here are the facts that I know so far:
```
{answer_facts}
```""".strip()
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.debug("===== Initial plan: =====")
self.logger.debug(final_plan_redaction)
else: # update plan
agent_memory = self.write_inner_memory_from_logs(
summary_mode=False
) # This will not log the plan but will log facts
# Redact updated facts
facts_update_system_prompt = {
"role": MessageRole.SYSTEM,
"content": SYSTEM_PROMPT_FACTS_UPDATE,
}
facts_update_message = {
"role": MessageRole.USER,
"content": USER_PROMPT_FACTS_UPDATE,
}
facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
# Redact updated plan
plan_update_message = {
"role": MessageRole.SYSTEM,
"content": SYSTEM_PROMPT_PLAN_UPDATE.format(task=task),
}
plan_update_message_user = {
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN_UPDATE.format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
facts_update=facts_update,
remaining_steps=(self.max_iterations - iteration),
),
}
plan_update = self.llm_engine(
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
)
# Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
final_facts_redaction = f"""Here is the updated list of the facts that I know:
```
{facts_update}
```"""
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.debug("===== Updated plan: =====")
self.logger.debug(final_plan_redaction)
class ReactJsonAgent(ReactAgent):
"""
@ -740,6 +876,7 @@ class ReactJsonAgent(ReactAgent):
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
planning_interval: Optional[int] = None,
**kwargs,
):
super().__init__(
@ -747,6 +884,7 @@ class ReactJsonAgent(ReactAgent):
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
planning_interval=planning_interval,
**kwargs,
)
@ -792,11 +930,16 @@ class ReactJsonAgent(ReactAgent):
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
if tool_name == "final_answer":
if isinstance(arguments, dict):
answer = arguments["answer"]
if "answer" in arguments:
answer = arguments["answer"]
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
answer = self.state[answer]
else:
answer = arguments
else:
answer = arguments
if answer in self.state: # if the answer is a state variable, return the value
answer = self.state[answer]
current_step_logs["final_answer"] = answer
return current_step_logs
else:
@ -835,6 +978,7 @@ class ReactCodeAgent(ReactAgent):
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
super().__init__(
@ -842,6 +986,7 @@ class ReactCodeAgent(ReactAgent):
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
planning_interval=planning_interval,
**kwargs,
)
@ -856,10 +1001,7 @@ class ReactCodeAgent(ReactAgent):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.available_tools = {
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
} # This list can be augmented by the code agent creating some new functions
self.custom_tools = {}
def step(self):
"""
@ -911,7 +1053,11 @@ class ReactCodeAgent(ReactAgent):
try:
result = self.python_evaluator(
code_action,
tools=self.available_tools,
static_tools={
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
},
custom_tools=self.custom_tools,
state=self.state,
authorized_imports=self.authorized_imports,
)
@ -920,7 +1066,7 @@ class ReactCodeAgent(ReactAgent):
self.logger.log(32, information)
current_step_logs["observation"] = information
except Exception as e:
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e):
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
raise AgentExecutionError(error_msg)

View File

@ -173,7 +173,7 @@ class PythonInterpreterTool(Tool):
def forward(self, code):
output = str(
evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports)
evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
)
return output

View File

@ -365,7 +365,118 @@ Here are the rules you should always follow to solve your task:
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
9. Don't give up! You're in charge of solving the task, not providing directions to solve it.
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
"""
SYSTEM_PROMPT_FACTS = """Below I will present you a task.
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
---
### 1. Facts given in the task
List here the specific facts given in the task that could help you (there might be nothing here).
### 2. Facts to look up
List here any facts that we may need to look up.
Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
### 3. Facts to derive
List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
### 1. Facts given in the task
### 2. Facts to look up
### 3. Facts to derive
Do not add anything else."""
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
USER_PROMPT_PLAN = """
Here is your task:
Task:
```
{task}
```
Your plan can leverage any of these tools:
{tool_descriptions}
List of facts that you know:
```
{answer_facts}
```
Now begin! Write your plan below."""
SYSTEM_PROMPT_FACTS_UPDATE = """
You are a world expert at gathering known and unknown facts based on a conversation.
Below you will find a task, and ahistory of attempts made to solve the task. You will have to produce a list of these:
### 1. Facts given in the task
### 2. Facts that we have learned
### 3. Facts still to look up
### 4. Facts still to derive
Find the task and history below."""
USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
But since in your previous steps you may have learned useful new facts or invalidated some false ones.
Please update your list of facts based on the previous history, and provide these headings:
### 1. Facts given in the task
### 2. Facts that we have learned
### 3. Facts still to look up
### 4. Facts still to derive
Now write your new list of facts below."""
SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
You have been given a task:
```
{task}
```
Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
If the previous tries so far have met some success, you can make an updated plan based on these actions.
If you are stalled, you can make a completely new plan starting from scratch.
"""
USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
```
{task}
```
You have access to these tools:
{tool_descriptions}
Here is the up to date list of facts that you know:
```
{facts_update}
```
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
Beware that you have {remaining_steps} steps remaining.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
Now write your new plan below."""
PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
```
{task}
```
Here is my new/updated plan of action to solve the task:
```
{plan_update}
```"""

View File

@ -18,8 +18,17 @@ import ast
import builtins
import difflib
from collections.abc import Mapping
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional
import numpy as np
from ..utils import is_pandas_available
if is_pandas_available():
import pandas as pd
class InterpreterError(ValueError):
"""
@ -50,7 +59,8 @@ LIST_SAFE_MODULES = [
"unicodedata",
]
PRINT_OUTPUTS = ""
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
class BreakException(Exception):
@ -75,8 +85,8 @@ def get_iterable(obj):
raise InterpreterError("Object is not iterable")
def evaluate_unaryop(expression, state, tools):
operand = evaluate_ast(expression.operand, state, tools)
def evaluate_unaryop(expression, state, static_tools, custom_tools):
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
if isinstance(expression.op, ast.USub):
return -operand
elif isinstance(expression.op, ast.UAdd):
@ -89,25 +99,25 @@ def evaluate_unaryop(expression, state, tools):
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
def evaluate_lambda(lambda_expression, state, tools):
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
args = [arg.arg for arg in lambda_expression.args.args]
def lambda_func(*values):
new_state = state.copy()
for arg, value in zip(args, values):
new_state[arg] = value
return evaluate_ast(lambda_expression.body, new_state, tools)
return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)
return lambda_func
def evaluate_while(while_loop, state, tools):
def evaluate_while(while_loop, state, static_tools, custom_tools):
max_iterations = 1000
iterations = 0
while evaluate_ast(while_loop.test, state, tools):
while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
for node in while_loop.body:
try:
evaluate_ast(node, state, tools)
evaluate_ast(node, state, static_tools, custom_tools)
except BreakException:
return None
except ContinueException:
@ -118,11 +128,11 @@ def evaluate_while(while_loop, state, tools):
return None
def create_function(func_def, state, tools):
def create_function(func_def, state, static_tools, custom_tools):
def new_func(*args, **kwargs):
func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args]
default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults]
default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
# Apply default values
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
@ -158,7 +168,7 @@ def create_function(func_def, state, tools):
result = None
try:
for stmt in func_def.body:
result = evaluate_ast(stmt, func_state, tools)
result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
except ReturnException as e:
result = e.value
return result
@ -173,25 +183,25 @@ def create_class(class_name, class_bases, class_body):
return type(class_name, tuple(class_bases), class_dict)
def evaluate_function_def(func_def, state, tools):
tools[func_def.name] = create_function(func_def, state, tools)
return tools[func_def.name]
def evaluate_function_def(func_def, state, static_tools, custom_tools):
custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
return custom_tools[func_def.name]
def evaluate_class_def(class_def, state, tools):
def evaluate_class_def(class_def, state, static_tools, custom_tools):
class_name = class_def.name
bases = [evaluate_ast(base, state, tools) for base in class_def.bases]
bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
class_dict = {}
for stmt in class_def.body:
if isinstance(stmt, ast.FunctionDef):
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools)
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
elif isinstance(stmt, ast.Assign):
for target in stmt.targets:
if isinstance(target, ast.Name):
class_dict[target.id] = evaluate_ast(stmt.value, state, tools)
class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
elif isinstance(target, ast.Attribute):
class_dict[target.attr] = evaluate_ast(stmt.value, state, tools)
class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
else:
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
@ -200,17 +210,17 @@ def evaluate_class_def(class_def, state, tools):
return new_class
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
def evaluate_augassign(expression, state, static_tools, custom_tools):
# Helper function to get current value and set new value based on the target type
def get_current_value(target):
if isinstance(target, ast.Name):
return state.get(target.id, 0)
elif isinstance(target, ast.Subscript):
obj = evaluate_ast(target.value, state, tools)
key = evaluate_ast(target.slice, state, tools)
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
return obj[key]
elif isinstance(target, ast.Attribute):
obj = evaluate_ast(target.value, state, tools)
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
return getattr(obj, target.attr)
elif isinstance(target, ast.Tuple):
return tuple(get_current_value(elt) for elt in target.elts)
@ -220,7 +230,7 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
current_value = get_current_value(expression.target)
value_to_add = evaluate_ast(expression.value, state, tools)
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
# Determine the operation and apply it
if isinstance(expression.op, ast.Add):
@ -256,28 +266,28 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
# Update the state
set_value(expression.target, updated_value, state, tools)
set_value(expression.target, updated_value, state, static_tools, custom_tools)
return updated_value
def evaluate_boolop(node, state, tools):
def evaluate_boolop(node, state, static_tools, custom_tools):
if isinstance(node.op, ast.And):
for value in node.values:
if not evaluate_ast(value, state, tools):
if not evaluate_ast(value, state, static_tools, custom_tools):
return False
return True
elif isinstance(node.op, ast.Or):
for value in node.values:
if evaluate_ast(value, state, tools):
if evaluate_ast(value, state, static_tools, custom_tools):
return True
return False
def evaluate_binop(binop, state, tools):
def evaluate_binop(binop, state, static_tools, custom_tools):
# Recursively evaluate the left and right operands
left_val = evaluate_ast(binop.left, state, tools)
right_val = evaluate_ast(binop.right, state, tools)
left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
# Determine the operation based on the type of the operator in the BinOp
if isinstance(binop.op, ast.Add):
@ -308,66 +318,92 @@ def evaluate_binop(binop, state, tools):
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
def evaluate_assign(assign, state, tools):
result = evaluate_ast(assign.value, state, tools)
def evaluate_assign(assign, state, static_tools, custom_tools):
result = evaluate_ast(assign.value, state, static_tools, custom_tools)
if len(assign.targets) == 1:
target = assign.targets[0]
set_value(target, result, state, tools)
set_value(target, result, state, static_tools, custom_tools)
else:
if len(assign.targets) != len(result):
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
for tgt, val in zip(assign.targets, result):
set_value(tgt, val, state, tools)
expanded_values = []
for tgt in assign.targets:
if isinstance(tgt, ast.Starred):
expanded_values.extend(result)
else:
expanded_values.append(result)
for tgt, val in zip(assign.targets, expanded_values):
set_value(tgt, val, state, static_tools, custom_tools)
return result
def set_value(target, value, state, tools):
def set_value(target, value, state, static_tools, custom_tools):
if isinstance(target, ast.Name):
if target.id in tools:
if target.id in static_tools:
raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
state[target.id] = value
elif isinstance(target, ast.Tuple):
if not isinstance(value, tuple):
raise InterpreterError("Cannot unpack non-tuple value")
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
value = tuple(value)
else:
raise InterpreterError("Cannot unpack non-tuple value")
if len(target.elts) != len(value):
raise InterpreterError("Cannot unpack tuple of wrong size")
for i, elem in enumerate(target.elts):
set_value(elem, value[i], state, tools)
set_value(elem, value[i], state, static_tools, custom_tools)
elif isinstance(target, ast.Subscript):
obj = evaluate_ast(target.value, state, tools)
key = evaluate_ast(target.slice, state, tools)
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
obj[key] = value
elif isinstance(target, ast.Attribute):
obj = evaluate_ast(target.value, state, tools)
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
setattr(obj, target.attr, value)
def evaluate_call(call, state, tools):
def evaluate_call(call, state, static_tools, custom_tools):
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
)
raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute):
obj = evaluate_ast(call.func.value, state, tools)
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
func_name = call.func.attr
if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
func = getattr(obj, func_name)
elif isinstance(call.func, ast.Name):
func_name = call.func.id
if func_name in state:
func = state[func_name]
elif func_name in tools:
func = tools[func_name]
elif func_name in static_tools:
func = static_tools[func_name]
elif func_name in custom_tools:
func = custom_tools[func_name]
elif func_name in ERRORS:
func = ERRORS[func_name]
else:
raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
)
args = [evaluate_ast(arg, state, tools) for arg in call.args]
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
args = []
for arg in call.args:
if isinstance(arg, ast.Starred):
args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools))
else:
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
args = []
for arg in call.args:
if isinstance(arg, ast.Starred):
unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
args.extend(unpacked)
else:
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
# Instantiate the class using its constructor
@ -397,24 +433,31 @@ def evaluate_call(call, state, tools):
output = " ".join(map(str, args))
global PRINT_OUTPUTS
PRINT_OUTPUTS += output + "\n"
# cap the number of lines
return output
else: # Assume it's a callable object
output = func(*args, **kwargs)
return output
def evaluate_subscript(subscript, state, tools):
index = evaluate_ast(subscript.slice, state, tools)
value = evaluate_ast(subscript.value, state, tools)
if isinstance(index, slice):
def evaluate_subscript(subscript, state, static_tools, custom_tools):
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj
return parent_object.loc[index]
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
return value[index]
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
return value[index]
elif isinstance(index, slice):
return value[index]
elif isinstance(value, (list, tuple)):
# Ensure the index is within bounds
if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
return value[int(index)]
elif isinstance(value, str):
# Ensure the index is within bounds
if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
return value[index]
@ -427,11 +470,11 @@ def evaluate_subscript(subscript, state, tools):
raise InterpreterError(f"Could not index {value} with '{index}'.")
def evaluate_name(name, state, tools):
def evaluate_name(name, state, static_tools, custom_tools):
if name.id in state:
return state[name.id]
elif name.id in tools:
return tools[name.id]
elif name.id in static_tools:
return static_tools[name.id]
elif name.id in ERRORS:
return ERRORS[name.id]
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
@ -440,9 +483,9 @@ def evaluate_name(name, state, tools):
raise InterpreterError(f"The variable `{name.id}` is not defined.")
def evaluate_condition(condition, state, tools):
left = evaluate_ast(condition.left, state, tools)
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
def evaluate_condition(condition, state, static_tools, custom_tools):
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
ops = [type(op) for op in condition.ops]
result = True
@ -450,63 +493,61 @@ def evaluate_condition(condition, state, tools):
for op, comparator in zip(ops, comparators):
if op == ast.Eq:
result = result and (current_left == comparator)
current_result = current_left == comparator
elif op == ast.NotEq:
result = result and (current_left != comparator)
current_result = current_left != comparator
elif op == ast.Lt:
result = result and (current_left < comparator)
current_result = current_left < comparator
elif op == ast.LtE:
result = result and (current_left <= comparator)
current_result = current_left <= comparator
elif op == ast.Gt:
result = result and (current_left > comparator)
current_result = current_left > comparator
elif op == ast.GtE:
result = result and (current_left >= comparator)
current_result = current_left >= comparator
elif op == ast.Is:
result = result and (current_left is comparator)
current_result = current_left is comparator
elif op == ast.IsNot:
result = result and (current_left is not comparator)
current_result = current_left is not comparator
elif op == ast.In:
result = result and (current_left in comparator)
current_result = current_left in comparator
elif op == ast.NotIn:
result = result and (current_left not in comparator)
current_result = current_left not in comparator
else:
raise InterpreterError(f"Operator not supported: {op}")
result = result & current_result
current_left = comparator
if not result:
if isinstance(result, bool) and not result:
break
return result
return result if isinstance(result, (bool, pd.Series)) else result.all()
def evaluate_if(if_statement, state, tools):
def evaluate_if(if_statement, state, static_tools, custom_tools):
result = None
test_result = evaluate_ast(if_statement.test, state, tools)
test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
if test_result:
for line in if_statement.body:
line_result = evaluate_ast(line, state, tools)
line_result = evaluate_ast(line, state, static_tools, custom_tools)
if line_result is not None:
result = line_result
else:
for line in if_statement.orelse:
line_result = evaluate_ast(line, state, tools)
line_result = evaluate_ast(line, state, static_tools, custom_tools)
if line_result is not None:
result = line_result
return result
def evaluate_for(for_loop, state, tools):
def evaluate_for(for_loop, state, static_tools, custom_tools):
result = None
iterator = evaluate_ast(for_loop.iter, state, tools)
iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
for counter in iterator:
if isinstance(for_loop.target, ast.Tuple):
for i, elem in enumerate(for_loop.target.elts):
state[elem.id] = counter[i]
else:
state[for_loop.target.id] = counter
set_value(for_loop.target, counter, state, static_tools, custom_tools)
for node in for_loop.body:
try:
line_result = evaluate_ast(node, state, tools)
line_result = evaluate_ast(node, state, static_tools, custom_tools)
if line_result is not None:
result = line_result
except BreakException:
@ -519,55 +560,60 @@ def evaluate_for(for_loop, state, tools):
return result
def evaluate_listcomp(listcomp, state, tools):
result = []
for generator in listcomp.generators:
iter_value = evaluate_ast(generator.iter, state, tools)
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
def inner_evaluate(generators, index, current_state):
if index >= len(generators):
return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
generator = generators[index]
iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
result = []
for value in iter_value:
new_state = state.copy()
new_state = current_state.copy()
if isinstance(generator.target, ast.Tuple):
for idx, elem in enumerate(generator.target.elts):
new_state[elem.id] = value[idx]
else:
new_state[generator.target.id] = value
if all(evaluate_ast(if_clause, new_state, tools) for if_clause in generator.ifs):
result.append(evaluate_ast(listcomp.elt, new_state, tools))
return result
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
result.extend(inner_evaluate(generators, index + 1, new_state))
return result
return inner_evaluate(listcomp.generators, 0, state)
def evaluate_try(try_node, state, tools):
def evaluate_try(try_node, state, static_tools, custom_tools):
try:
for stmt in try_node.body:
evaluate_ast(stmt, state, tools)
evaluate_ast(stmt, state, static_tools, custom_tools)
except Exception as e:
matched = False
for handler in try_node.handlers:
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, tools)):
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
matched = True
if handler.name:
state[handler.name] = e
for stmt in handler.body:
evaluate_ast(stmt, state, tools)
evaluate_ast(stmt, state, static_tools, custom_tools)
break
if not matched:
raise e
else:
if try_node.orelse:
for stmt in try_node.orelse:
evaluate_ast(stmt, state, tools)
evaluate_ast(stmt, state, static_tools, custom_tools)
finally:
if try_node.finalbody:
for stmt in try_node.finalbody:
evaluate_ast(stmt, state, tools)
evaluate_ast(stmt, state, static_tools, custom_tools)
def evaluate_raise(raise_node, state, tools):
def evaluate_raise(raise_node, state, static_tools, custom_tools):
if raise_node.exc is not None:
exc = evaluate_ast(raise_node.exc, state, tools)
exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
else:
exc = None
if raise_node.cause is not None:
cause = evaluate_ast(raise_node.cause, state, tools)
cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
else:
cause = None
if exc is not None:
@ -579,11 +625,11 @@ def evaluate_raise(raise_node, state, tools):
raise InterpreterError("Re-raise is not supported without an active exception")
def evaluate_assert(assert_node, state, tools):
test_result = evaluate_ast(assert_node.test, state, tools)
def evaluate_assert(assert_node, state, static_tools, custom_tools):
test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
if not test_result:
if assert_node.msg:
msg = evaluate_ast(assert_node.msg, state, tools)
msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
raise AssertionError(msg)
else:
# Include the failing condition in the assertion message
@ -591,10 +637,10 @@ def evaluate_assert(assert_node, state, tools):
raise AssertionError(f"Assertion failed: {test_code}")
def evaluate_with(with_node, state, tools):
def evaluate_with(with_node, state, static_tools, custom_tools):
contexts = []
for item in with_node.items:
context_expr = evaluate_ast(item.context_expr, state, tools)
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
if item.optional_vars:
state[item.optional_vars.id] = context_expr.__enter__()
contexts.append(state[item.optional_vars.id])
@ -604,7 +650,7 @@ def evaluate_with(with_node, state, tools):
try:
for stmt in with_node.body:
evaluate_ast(stmt, state, tools)
evaluate_ast(stmt, state, static_tools, custom_tools)
except Exception as e:
for context in reversed(contexts):
context.__exit__(type(e), e, e.__traceback__)
@ -614,10 +660,51 @@ def evaluate_with(with_node, state, tools):
context.__exit__(None, None, None)
def import_modules(expression, state, authorized_imports):
def check_module_authorized(module_name):
module_path = module_name.split(".")
module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name):
module = import_module(alias.name)
state[alias.asname or alias.name] = module
else:
raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
)
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name)
else:
raise InterpreterError(f"Import from {expression.module} is not allowed.")
return None
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
result = {}
for gen in dictcomp.generators:
iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools)
for value in iter_value:
new_state = state.copy()
set_value(gen.target, value, new_state, static_tools, custom_tools)
if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
val = evaluate_ast(dictcomp.value, new_state, static_tools, custom_tools)
result[key] = val
return result
def evaluate_ast(
expression: ast.AST,
state: Dict[str, Any],
tools: Dict[str, Callable],
static_tools: Dict[str, Callable],
custom_tools: Dict[str, Callable],
authorized_imports: List[str] = LIST_SAFE_MODULES,
):
"""
@ -632,146 +719,128 @@ def evaluate_ast(
state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
encounters assignements.
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpreterError`.
static_tools (`Dict[str, Callable]`):
Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
custom_tools (`Dict[str, Callable]`):
Functions that may be called during the evaluation. These static_tools can be overwritten.
authorized_imports (`List[str]`):
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
Add more at your own risk!
"""
global OPERATIONS_COUNT
if OPERATIONS_COUNT >= MAX_OPERATIONS:
raise InterpreterError(
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
)
OPERATIONS_COUNT += 1
if isinstance(expression, ast.Assign):
# Assignement -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, state, tools)
return evaluate_assign(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(expression, state, tools)
return evaluate_augassign(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(expression, state, tools)
return evaluate_call(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts)
elif isinstance(expression, ast.ListComp):
return evaluate_listcomp(expression, state, tools)
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(expression, state, tools)
return evaluate_unaryop(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Starred):
return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation
return evaluate_boolop(expression, state, tools)
return evaluate_boolop(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation
return evaluate_binop(expression, state, tools)
return evaluate_binop(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison
return evaluate_condition(expression, state, tools)
return evaluate_condition(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Lambda):
return evaluate_lambda(expression, state, tools)
return evaluate_lambda(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(expression, state, tools)
return evaluate_function_def(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
values = [evaluate_ast(v, state, tools) for v in expression.values]
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
return evaluate_ast(expression.value, state, tools)
return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.For):
# For loop -> execute the loop
return evaluate_for(expression, state, tools)
return evaluate_for(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, state, tools)
return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return evaluate_if(expression, state, tools)
return evaluate_if(expression, state, static_tools, custom_tools)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, state, tools)
return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.JoinedStr):
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(expression, state, tools)
return evaluate_name(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(expression, state, tools)
return evaluate_subscript(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(expression.test, state, tools)
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
if test_val:
return evaluate_ast(expression.body, state, tools)
return evaluate_ast(expression.body, state, static_tools, custom_tools)
else:
return evaluate_ast(expression.orelse, state, tools)
return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
elif isinstance(expression, ast.Attribute):
obj = evaluate_ast(expression.value, state, tools)
return getattr(obj, expression.attr)
value = evaluate_ast(expression.value, state, static_tools, custom_tools)
return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None,
evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None,
evaluate_ast(expression.step, state, tools) if expression.step is not None else None,
evaluate_ast(expression.lower, state, static_tools, custom_tools)
if expression.lower is not None
else None,
evaluate_ast(expression.upper, state, static_tools, custom_tools)
if expression.upper is not None
else None,
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
)
elif isinstance(expression, ast.ListComp) or isinstance(expression, ast.GeneratorExp):
result = []
vars = {}
for generator in expression.generators:
var_name = generator.target.id
iter_value = evaluate_ast(generator.iter, state, tools)
for value in iter_value:
vars[var_name] = value
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
elem = evaluate_ast(expression.elt, {**state, **vars}, tools)
result.append(elem)
return result
elif isinstance(expression, ast.DictComp):
result = {}
for gen in expression.generators:
for container in get_iterable(evaluate_ast(gen.iter, state, tools)):
state[gen.target.id] = container
key = evaluate_ast(expression.key, state, tools)
value = evaluate_ast(expression.value, state, tools)
result[key] = value
return result
elif isinstance(expression, ast.Import):
for alias in expression.names:
if alias.name in authorized_imports:
module = __import__(alias.name)
state[alias.asname or alias.name] = module
else:
raise InterpreterError(f"Import of {alias.name} is not allowed.")
return None
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.While):
return evaluate_while(expression, state, tools)
elif isinstance(expression, ast.ImportFrom):
if expression.module in authorized_imports:
module = __import__(expression.module)
for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name)
else:
raise InterpreterError(f"Import from {expression.module} is not allowed.")
return None
return evaluate_while(expression, state, static_tools, custom_tools)
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(expression, state, tools)
return evaluate_class_def(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Try):
return evaluate_try(expression, state, tools)
return evaluate_try(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Raise):
return evaluate_raise(expression, state, tools)
return evaluate_raise(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Assert):
return evaluate_assert(expression, state, tools)
return evaluate_assert(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.With):
return evaluate_with(expression, state, tools)
return evaluate_with(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Set):
return {evaluate_ast(elt, state, tools) for elt in expression.elts}
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
elif isinstance(expression, ast.Return):
raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None)
raise ReturnException(
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
)
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
@ -779,7 +848,8 @@ def evaluate_ast(
def evaluate_python_code(
code: str,
tools: Optional[Dict[str, Callable]] = None,
static_tools: Optional[Dict[str, Callable]] = None,
custom_tools: Optional[Dict[str, Callable]] = None,
state: Optional[Dict[str, Any]] = None,
authorized_imports: List[str] = LIST_SAFE_MODULES,
):
@ -792,9 +862,12 @@ def evaluate_python_code(
Args:
code (`str`):
The code to evaluate.
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpreterError`.
static_tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation.
These tools cannot be overwritten in the code: any assignment to their name will raise an error.
custom_tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation.
These tools can be overwritten in the code: any assignment to their name will overwrite them.
state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
updated by this function to contain all variables as they are evaluated.
@ -806,20 +879,34 @@ def evaluate_python_code(
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
if state is None:
state = {}
if tools is None:
tools = {}
if static_tools is None:
static_tools = {}
if custom_tools is None:
custom_tools = {}
result = None
global PRINT_OUTPUTS
PRINT_OUTPUTS = ""
global OPERATIONS_COUNT
OPERATIONS_COUNT = 0
for node in expression.body:
try:
result = evaluate_ast(node, state, tools, authorized_imports)
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
except InterpreterError as e:
msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
msg = ""
if len(PRINT_OUTPUTS) > 0:
msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n"
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
else:
msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg)
finally:
state["print_outputs"] = PRINT_OUTPUTS
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
state["print_outputs"] = PRINT_OUTPUTS
else:
state["print_outputs"] = (
PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
+ f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
)
return result

View File

@ -223,7 +223,7 @@ Action:
# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
assert "python_interpreter already exists in the toolbox" in str(e)
assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)

View File

@ -15,6 +15,7 @@
import unittest
import numpy as np
import pytest
from transformers import load_tool
@ -241,8 +242,41 @@ for block in text_block:
code = """
digits, i = [1, 2, 3], 1
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
code = """
def calculate_isbn_10_check_digit(number):
total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
remainder = total % 11
check_digit = 11 - remainder
if check_digit == 10:
return 'X'
elif check_digit == 11:
return '0'
else:
return str(check_digit)
# Given 9-digit numbers
numbers = [
"478225952",
"643485613",
"739394228",
"291726859",
"875262394",
"542617795",
"031810713",
"957007669",
"871467426"
]
# Calculate check digits for each number
check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
print(check_digits)
"""
state = {}
evaluate_python_code(code, {"range": range, "print": print, "int": int}, state)
evaluate_python_code(
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
)
def test_listcomp(self):
code = "x = [i for i in range(3)]"
@ -273,6 +307,17 @@ digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
result = evaluate_python_code(code, {"range": range}, state={})
assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert result == {102: "b"}
code = """
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
"""
result = evaluate_python_code(code, {}, state={})
assert result == {"A": ("a", "b"), "B": ("a", "b")}
def test_tuple_assignment(self):
code = "a, b = 0, 1\nb"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
@ -341,7 +386,7 @@ if char.isalpha():
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose"
code = "import time\ntime.sleep(0.1)"
code = "import time, re\ntime.sleep(0.1)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result is None
@ -369,6 +414,23 @@ if char.isalpha():
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "LATIN CAPITAL LETTER A"
# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
def test_additional_imports(self):
code = "import numpy as np"
evaluate_python_code(code, authorized_imports=["numpy"], state={})
code = "import numpy.random as rd"
evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
evaluate_python_code(code, authorized_imports=["numpy"], state={})
with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["random"], state={})
def test_multiple_comparators(self):
code = "0 <= -1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
@ -400,7 +462,7 @@ def function():
print("2")
function()"""
state = {}
evaluate_python_code(code, {"print": print}, state)
evaluate_python_code(code, {"print": print}, state=state)
assert state["print_outputs"] == "1\n2\n"
def test_tuple_target_in_iterator(self):
@ -612,7 +674,7 @@ assert lock.locked == False
"""
state = {}
tools = {}
evaluate_python_code(code, tools, state)
evaluate_python_code(code, tools, state=state)
def test_default_arg_in_function(self):
code = """
@ -672,3 +734,94 @@ returns_none(1)
state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
assert result is None
def test_nested_for_loop(self):
code = """
all_res = []
for i in range(10):
subres = []
for j in range(i):
subres.append(j)
all_res.append(subres)
out = [i for sublist in all_res for i in sublist]
out[:10]
"""
state = {}
result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
def test_pandas(self):
code = """
import pandas as pd
df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})
df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')
parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
state = {}
result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
assert np.array_equal(result, [-1, 5])
code = """
import pandas as pd
df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
print("HH0")
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert np.array_equal(result.values[0], [104, 1])
code = """import pandas as pd
data = pd.DataFrame.from_dict([
{"Pclass": 1, "Survived": 1},
{"Pclass": 2, "Survived": 0},
{"Pclass": 2, "Survived": 1}
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
assert result.values[1] == 0.5
def test_starred(self):
code = """
from math import radians, sin, cos, sqrt, atan2
def haversine(lat1, lon1, lat2, lon2):
R = 6371000 # Radius of the Earth in meters
lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
dlat = lat2 - lat1
dlon = lon2 - lon1
a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
c = 2 * atan2(sqrt(a), sqrt(1 - a))
distance = R * c
return distance
coords_geneva = (46.1978, 6.1342)
coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
assert round(result, 1) == 622395.4
def test_for(self):
code = """
shifts = {
"Worker A": ("6:45 pm", "8:00 pm"),
"Worker B": ("10:00 am", "11:45 am")
}
shift_intervals = {}
for worker, (start, end) in shifts.items():
shift_intervals[worker] = end
shift_intervals
"""
result = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}