mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
0fdea8607d
commit
b381880597
@ -25,7 +25,19 @@ from ..utils.import_utils import is_pygments_available
|
||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||
from .llm_engine import HfEngine, MessageRole
|
||||
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
|
||||
from .prompts import (
|
||||
DEFAULT_CODE_SYSTEM_PROMPT,
|
||||
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||
PLAN_UPDATE_FINAL_PLAN_REDACTION,
|
||||
SYSTEM_PROMPT_FACTS,
|
||||
SYSTEM_PROMPT_FACTS_UPDATE,
|
||||
SYSTEM_PROMPT_PLAN,
|
||||
SYSTEM_PROMPT_PLAN_UPDATE,
|
||||
USER_PROMPT_FACTS_UPDATE,
|
||||
USER_PROMPT_PLAN,
|
||||
USER_PROMPT_PLAN_UPDATE,
|
||||
)
|
||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tools import (
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
@ -99,12 +111,19 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||
|
||||
def parse_code_blob(code_blob: str) -> str:
    """
    Extract the source code held in the first fenced code block of `code_blob`.

    Args:
        code_blob (`str`): Text expected to contain a block opened by "```", "```py" or "```python"
            followed by a newline, and closed by a newline followed by "```".

    Returns:
        `str`: The content found between the fences, stripped of surrounding whitespace.

    Raises:
        ValueError: If `code_blob` is not a string or contains no block matching the pattern.
    """
    pattern = r"```(?:py|python)?\n(.*?)\n```"
    try:
        match = re.search(pattern, code_blob, re.DOTALL)
    except TypeError as e:
        # Non-string input: keep surfacing it as a ValueError like any other malformed blob.
        reason = str(e)
    else:
        if match is not None:
            return match.group(1).strip()
        # Name the real problem instead of leaking "'NoneType' object has no attribute 'group'".
        reason = "no code block matching the pattern was found"
    raise ValueError(
        f"""
The code blob you used is invalid: due to the following error: {reason}
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
Code:
```py
# Your python code here
```<end_action>"""
    )
|
||||
|
||||
|
||||
@ -113,6 +132,8 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
|
||||
tool_call = parse_json_blob(json_blob)
|
||||
if "action" in tool_call and "action_input" in tool_call:
|
||||
return tool_call["action"], tool_call["action_input"]
|
||||
elif "action" in tool_call:
|
||||
return tool_call["action"], None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
|
||||
@ -208,7 +229,7 @@ class Toolbox:
|
||||
The tool to add to the toolbox.
|
||||
"""
|
||||
if tool.name in self._tools:
|
||||
raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.")
|
||||
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def remove_tool(self, tool_name: str):
|
||||
@ -359,12 +380,8 @@ class Agent:
|
||||
"""Get the toolbox currently available to the agent"""
|
||||
return self._toolbox
|
||||
|
||||
def initialize_for_run(self, task: str, **kwargs):
|
||||
def initialize_for_run(self):
|
||||
self.token_count = 0
|
||||
self.task = task
|
||||
if len(kwargs) > 0:
|
||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||
self.state = kwargs.copy()
|
||||
self.system_prompt = format_prompt_with_tools(
|
||||
self._toolbox,
|
||||
self.system_prompt_template,
|
||||
@ -380,7 +397,7 @@ class Agent:
|
||||
self.logger.debug("System prompt is as follows:")
|
||||
self.logger.debug(self.system_prompt)
|
||||
|
||||
def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
|
||||
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
||||
that can be used as input to the LLM.
|
||||
@ -390,43 +407,51 @@ class Agent:
|
||||
"role": MessageRole.USER,
|
||||
"content": "Task: " + self.logs[0]["task"],
|
||||
}
|
||||
memory = [prompt_message, task_message]
|
||||
if summary_mode:
|
||||
memory = [task_message]
|
||||
else:
|
||||
memory = [prompt_message, task_message]
|
||||
for i, step_log in enumerate(self.logs[1:]):
|
||||
if "llm_output" in step_log:
|
||||
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"}
|
||||
if "llm_output" in step_log and not summary_mode:
|
||||
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
|
||||
memory.append(thought_message)
|
||||
if "facts" in step_log:
|
||||
thought_message = {
|
||||
"role": MessageRole.ASSISTANT,
|
||||
"content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
|
||||
}
|
||||
memory.append(thought_message)
|
||||
|
||||
if "error" in step_log:
|
||||
message_content = (
|
||||
"Error: "
|
||||
+ str(step_log["error"])
|
||||
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
||||
)
|
||||
elif "observation" in step_log:
|
||||
message_content = f"Observation: {step_log['observation']}"
|
||||
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
||||
memory.append(tool_response_message)
|
||||
if "plan" in step_log and not summary_mode:
|
||||
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
|
||||
memory.append(thought_message)
|
||||
|
||||
if "tool_call" in step_log and summary_mode:
|
||||
tool_call_message = {
|
||||
"role": MessageRole.ASSISTANT,
|
||||
"content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
|
||||
}
|
||||
memory.append(tool_call_message)
|
||||
|
||||
if "task" in step_log:
|
||||
tool_call_message = {
|
||||
"role": MessageRole.USER,
|
||||
"content": "New task:\n" + step_log["task"],
|
||||
}
|
||||
memory.append(tool_call_message)
|
||||
|
||||
if "error" in step_log or "observation" in step_log:
|
||||
if "error" in step_log:
|
||||
message_content = (
|
||||
f"[OUTPUT OF STEP {i}] Error: "
|
||||
+ str(step_log["error"])
|
||||
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
||||
)
|
||||
elif "observation" in step_log:
|
||||
message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}"
|
||||
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
||||
memory.append(tool_response_message)
|
||||
|
||||
if len(memory) % 3 == 0:
|
||||
reminder_content = (
|
||||
"Reminder: you are working towards solving the following task: " + self.logs[0]["task"]
|
||||
)
|
||||
reminder_content += "\nHere is a summary of your past tool calls and their results:"
|
||||
for j in range(i + 1):
|
||||
reminder_content += "\nStep " + str(j + 1)
|
||||
if "tool_call" in self.logs[j]:
|
||||
reminder_content += "\nTool call:" + str(self.logs[j]["tool_call"])
|
||||
if self.memory_verbose:
|
||||
if "observation" in self.logs[j]:
|
||||
reminder_content += "\nObservation:" + str(self.logs[j]["observation"])
|
||||
if "error" in self.logs[j]:
|
||||
reminder_content += "\nError:" + str(self.logs[j]["error"])
|
||||
memory.append(
|
||||
{
|
||||
"role": MessageRole.USER,
|
||||
"content": reminder_content,
|
||||
}
|
||||
)
|
||||
return memory
|
||||
|
||||
def get_succinct_logs(self):
|
||||
@ -459,7 +484,7 @@ class Agent:
|
||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||
|
||||
Args:
|
||||
tool_name (`str`): Name of the Tool to execute (shoulde be one from self.toolbox).
|
||||
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
|
||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||
"""
|
||||
if tool_name not in self.toolbox.tools:
|
||||
@ -559,7 +584,11 @@ class CodeAgent(Agent):
|
||||
agent.run("What is the result of 2 power 3.7384?")
|
||||
```
|
||||
"""
|
||||
self.initialize_for_run(task, **kwargs)
|
||||
self.task = task
|
||||
if len(kwargs) > 0:
|
||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||
self.state = kwargs.copy()
|
||||
self.initialize_for_run()
|
||||
|
||||
# Run LLM
|
||||
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
|
||||
@ -598,7 +627,8 @@ class CodeAgent(Agent):
|
||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||
output = self.python_evaluator(
|
||||
code_action,
|
||||
available_tools,
|
||||
static_tools=available_tools,
|
||||
custom_tools={},
|
||||
state=self.state,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)
|
||||
@ -623,6 +653,7 @@ class ReactAgent(Agent):
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -632,6 +663,7 @@ class ReactAgent(Agent):
|
||||
tool_description_template=tool_description_template,
|
||||
**kwargs,
|
||||
)
|
||||
self.planning_interval = planning_interval
|
||||
|
||||
def provide_final_answer(self, task) -> str:
|
||||
"""
|
||||
@ -655,11 +687,13 @@ class ReactAgent(Agent):
|
||||
except Exception as e:
|
||||
return f"Error in generating final llm output: {e}."
|
||||
|
||||
def run(self, task: str, stream: bool = False, **kwargs):
|
||||
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
|
||||
"""
|
||||
Runs the agent for the given task.
|
||||
|
||||
Args:
|
||||
task (`str`): The task to perform
|
||||
|
||||
Example:
|
||||
```py
|
||||
from transformers.agents import ReactCodeAgent
|
||||
@ -667,14 +701,23 @@ class ReactAgent(Agent):
|
||||
agent.run("What is the result of 2 power 3.7384?")
|
||||
```
|
||||
"""
|
||||
if stream:
|
||||
return self.stream_run(task, **kwargs)
|
||||
self.task = task
|
||||
if len(kwargs) > 0:
|
||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||
self.state = kwargs.copy()
|
||||
if reset:
|
||||
self.initialize_for_run()
|
||||
else:
|
||||
return self.direct_run(task, **kwargs)
|
||||
|
||||
def stream_run(self, task: str, **kwargs):
|
||||
self.initialize_for_run(task, **kwargs)
|
||||
self.logs.append({"task": task})
|
||||
if stream:
|
||||
return self.stream_run(task)
|
||||
else:
|
||||
return self.direct_run(task)
|
||||
|
||||
def stream_run(self, task: str):
|
||||
"""
|
||||
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
|
||||
"""
|
||||
final_answer = None
|
||||
iteration = 0
|
||||
while final_answer is None and iteration < self.max_iterations:
|
||||
@ -700,13 +743,16 @@ class ReactAgent(Agent):
|
||||
|
||||
yield final_answer
|
||||
|
||||
def direct_run(self, task: str, **kwargs):
|
||||
self.initialize_for_run(task, **kwargs)
|
||||
|
||||
def direct_run(self, task: str):
|
||||
"""
|
||||
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
|
||||
"""
|
||||
final_answer = None
|
||||
iteration = 0
|
||||
while final_answer is None and iteration < self.max_iterations:
|
||||
try:
|
||||
if self.planning_interval is not None and iteration % self.planning_interval == 0:
|
||||
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
|
||||
step_logs = self.step()
|
||||
if "final_answer" in step_logs:
|
||||
final_answer = step_logs["final_answer"]
|
||||
@ -726,6 +772,96 @@ class ReactAgent(Agent):
|
||||
|
||||
return final_answer
|
||||
|
||||
def planning_step(self, task, is_first_step: bool = False, iteration: Optional[int] = None):
    """
    Used periodically by the agent to plan the next steps to reach the objective.

    On the first step, the LLM is asked for a survey of facts and then for a plan. On later steps, it is
    asked to update both, based on the memory rebuilt from `self.logs`. In both cases the resulting plan
    and facts are appended to `self.logs` as a single `{"plan": ..., "facts": ...}` entry.

    Args:
        task (`str`): The task to perform
        is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
        iteration (`int`, *optional*): The number of the current step, used as an indication for the LLM.
            When omitted, all `self.max_iterations` steps are reported as remaining.
    """
    if is_first_step:
        message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
        message_prompt_task = {
            "role": MessageRole.USER,
            "content": f"""Here is the task:
```
{task}
```
Now begin!""",
        }

        answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])

        message_system_prompt_plan = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_PLAN}
        message_user_prompt_plan = {
            "role": MessageRole.USER,
            "content": USER_PROMPT_PLAN.format(
                task=task,
                tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
                answer_facts=answer_facts,
            ),
        }
        # The planning prompts ask the model to finish with '<end_plan>', so generation stops there.
        answer_plan = self.llm_engine(
            [message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
        )

        final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
```
{answer_plan}
```"""
        final_facts_redaction = f"""Here are the facts that I know so far:
```
{answer_facts}
```""".strip()
        self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
        self.logger.debug("===== Initial plan: =====")
        self.logger.debug(final_plan_redaction)
    else:  # update plan
        # With summary_mode=False the memory keeps the system prompt, the LLM outputs and the previously
        # logged facts and plans.
        agent_memory = self.write_inner_memory_from_logs(summary_mode=False)

        # Redact updated facts
        facts_update_system_prompt = {
            "role": MessageRole.SYSTEM,
            "content": SYSTEM_PROMPT_FACTS_UPDATE,
        }
        facts_update_message = {
            "role": MessageRole.USER,
            "content": USER_PROMPT_FACTS_UPDATE,
        }
        facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])

        # `iteration` defaults to None: count it as step 0 instead of failing on the subtraction.
        remaining_steps = self.max_iterations - (iteration if iteration is not None else 0)

        # Redact updated plan
        plan_update_message = {
            "role": MessageRole.SYSTEM,
            "content": SYSTEM_PROMPT_PLAN_UPDATE.format(task=task),
        }
        plan_update_message_user = {
            "role": MessageRole.USER,
            "content": USER_PROMPT_PLAN_UPDATE.format(
                task=task,
                tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
                facts_update=facts_update,
                remaining_steps=remaining_steps,
            ),
        }
        plan_update = self.llm_engine(
            [plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
        )

        # Log final facts and plan
        final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
        final_facts_redaction = f"""Here is the updated list of the facts that I know:
```
{facts_update}
```"""
        self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
        self.logger.debug("===== Updated plan: =====")
        self.logger.debug(final_plan_redaction)
|
||||
|
||||
|
||||
class ReactJsonAgent(ReactAgent):
|
||||
"""
|
||||
@ -740,6 +876,7 @@ class ReactJsonAgent(ReactAgent):
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -747,6 +884,7 @@ class ReactJsonAgent(ReactAgent):
|
||||
llm_engine=llm_engine,
|
||||
system_prompt=system_prompt,
|
||||
tool_description_template=tool_description_template,
|
||||
planning_interval=planning_interval,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -792,11 +930,16 @@ class ReactJsonAgent(ReactAgent):
|
||||
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
|
||||
if tool_name == "final_answer":
|
||||
if isinstance(arguments, dict):
|
||||
answer = arguments["answer"]
|
||||
if "answer" in arguments:
|
||||
answer = arguments["answer"]
|
||||
if (
|
||||
isinstance(answer, str) and answer in self.state.keys()
|
||||
): # if the answer is a state variable, return the value
|
||||
answer = self.state[answer]
|
||||
else:
|
||||
answer = arguments
|
||||
else:
|
||||
answer = arguments
|
||||
if answer in self.state: # if the answer is a state variable, return the value
|
||||
answer = self.state[answer]
|
||||
current_step_logs["final_answer"] = answer
|
||||
return current_step_logs
|
||||
else:
|
||||
@ -835,6 +978,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
additional_authorized_imports: Optional[List[str]] = None,
|
||||
planning_interval: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -842,6 +986,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
llm_engine=llm_engine,
|
||||
system_prompt=system_prompt,
|
||||
tool_description_template=tool_description_template,
|
||||
planning_interval=planning_interval,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -856,10 +1001,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
||||
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
||||
self.available_tools = {
|
||||
**BASE_PYTHON_TOOLS.copy(),
|
||||
**self.toolbox.tools,
|
||||
} # This list can be augmented by the code agent creating some new functions
|
||||
self.custom_tools = {}
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
@ -911,7 +1053,11 @@ class ReactCodeAgent(ReactAgent):
|
||||
try:
|
||||
result = self.python_evaluator(
|
||||
code_action,
|
||||
tools=self.available_tools,
|
||||
static_tools={
|
||||
**BASE_PYTHON_TOOLS.copy(),
|
||||
**self.toolbox.tools,
|
||||
},
|
||||
custom_tools=self.custom_tools,
|
||||
state=self.state,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)
|
||||
@ -920,7 +1066,7 @@ class ReactCodeAgent(ReactAgent):
|
||||
self.logger.log(32, information)
|
||||
current_step_logs["observation"] = information
|
||||
except Exception as e:
|
||||
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
|
||||
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
|
||||
if "'dict' object has no attribute 'read'" in str(e):
|
||||
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
|
||||
raise AgentExecutionError(error_msg)
|
||||
|
@ -173,7 +173,7 @@ class PythonInterpreterTool(Tool):
|
||||
|
||||
def forward(self, code):
    """
    Run `code` through `evaluate_python_code` and return the result converted to a string.

    The tool's `available_tools` are passed as `static_tools` and its `authorized_imports`
    as the import allow-list.
    """
    output = str(
        evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
    )
    return output
|
||||
|
||||
|
@ -365,7 +365,118 @@ Here are the rules you should always follow to solve your task:
|
||||
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
||||
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
||||
9. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
||||
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
||||
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
||||
|
||||
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT_FACTS = """Below I will present you a task.
|
||||
|
||||
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
|
||||
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
|
||||
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
|
||||
|
||||
---
|
||||
### 1. Facts given in the task
|
||||
List here the specific facts given in the task that could help you (there might be nothing here).
|
||||
|
||||
### 2. Facts to look up
|
||||
List here any facts that we may need to look up.
|
||||
Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
|
||||
|
||||
### 3. Facts to derive
|
||||
List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
|
||||
|
||||
Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
|
||||
### 1. Facts given in the task
|
||||
### 2. Facts to look up
|
||||
### 3. Facts to derive
|
||||
Do not add anything else."""
|
||||
|
||||
# System message of the initial planning call in `ReactAgent.planning_step`. It instructs the model to
# end its plan with the '<end_plan>' tag, which the caller also passes as a stop sequence.
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.

Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
|
||||
|
||||
USER_PROMPT_PLAN = """
|
||||
Here is your task:
|
||||
|
||||
Task:
|
||||
```
|
||||
{task}
|
||||
```
|
||||
|
||||
Your plan can leverage any of these tools:
|
||||
{tool_descriptions}
|
||||
|
||||
List of facts that you know:
|
||||
```
|
||||
{answer_facts}
|
||||
```
|
||||
|
||||
Now begin! Write your plan below."""
|
||||
|
||||
# System message of the facts-update call in `ReactAgent.planning_step`; it is followed by the agent
# memory and then USER_PROMPT_FACTS_UPDATE. Contains no format placeholders.
SYSTEM_PROMPT_FACTS_UPDATE = """
You are a world expert at gathering known and unknown facts based on a conversation.
Below you will find a task, and a history of attempts made to solve the task. You will have to produce a list of these:
### 1. Facts given in the task
### 2. Facts that we have learned
### 3. Facts still to look up
### 4. Facts still to derive
Find the task and history below."""
|
||||
|
||||
# Closing user message of the facts-update call in `ReactAgent.planning_step`, sent after the agent
# memory. Contains no format placeholders.
USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
But since then, in your previous steps, you may have learned useful new facts or invalidated some false ones.
Please update your list of facts based on the previous history, and provide these headings:
### 1. Facts given in the task
### 2. Facts that we have learned
### 3. Facts still to look up
### 4. Facts still to derive

Now write your new list of facts below."""
|
||||
|
||||
SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
|
||||
|
||||
You have been given a task:
|
||||
```
|
||||
{task}
|
||||
```
|
||||
|
||||
Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
|
||||
If the previous tries so far have met some success, you can make an updated plan based on these actions.
|
||||
If you are stalled, you can make a completely new plan starting from scratch.
|
||||
"""
|
||||
|
||||
# Closing user message of the plan-update call in `ReactAgent.planning_step`.
# Filled with `str.format`; placeholders: {task}, {tool_descriptions}, {facts_update}, {remaining_steps}.
USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
```
{task}
```

You have access to these tools:
{tool_descriptions}

Here is the up to date list of facts that you know:
```
{facts_update}
```

Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
Beware that you have {remaining_steps} steps remaining.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.

Now write your new plan below."""
|
||||
|
||||
PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
|
||||
```
|
||||
{task}
|
||||
```
|
||||
|
||||
Here is my new/updated plan of action to solve the task:
|
||||
```
|
||||
{plan_update}
|
||||
```"""
|
||||
|
@ -18,8 +18,17 @@ import ast
|
||||
import builtins
|
||||
import difflib
|
||||
from collections.abc import Mapping
|
||||
from importlib import import_module
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import is_pandas_available
|
||||
|
||||
|
||||
if is_pandas_available():
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class InterpreterError(ValueError):
|
||||
"""
|
||||
@ -50,7 +59,8 @@ LIST_SAFE_MODULES = [
|
||||
"unicodedata",
|
||||
]
|
||||
|
||||
PRINT_OUTPUTS = ""
|
||||
PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
|
||||
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
|
||||
|
||||
|
||||
class BreakException(Exception):
|
||||
@ -75,8 +85,8 @@ def get_iterable(obj):
|
||||
raise InterpreterError("Object is not iterable")
|
||||
|
||||
|
||||
def evaluate_unaryop(expression, state, tools):
|
||||
operand = evaluate_ast(expression.operand, state, tools)
|
||||
def evaluate_unaryop(expression, state, static_tools, custom_tools):
|
||||
operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
|
||||
if isinstance(expression.op, ast.USub):
|
||||
return -operand
|
||||
elif isinstance(expression.op, ast.UAdd):
|
||||
@ -89,25 +99,25 @@ def evaluate_unaryop(expression, state, tools):
|
||||
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
||||
|
||||
|
||||
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
    """
    Turn a lambda AST node into a Python callable.

    Args:
        lambda_expression: The lambda node; its `args.args` names and its `body` are used.
        state: Variable mapping the lambda body is evaluated against.
        static_tools: Forwarded unchanged to `evaluate_ast`.
        custom_tools: Forwarded unchanged to `evaluate_ast`.

    Returns:
        A function accepting positional values, which are bound to the lambda's argument names in order.
    """
    args = [arg.arg for arg in lambda_expression.args.args]

    def lambda_func(*values):
        # Copy at call time so argument bindings never leak into the enclosing state.
        new_state = state.copy()
        for arg, value in zip(args, values):
            new_state[arg] = value
        return evaluate_ast(lambda_expression.body, new_state, static_tools, custom_tools)

    return lambda_func
|
||||
|
||||
|
||||
def evaluate_while(while_loop, state, tools):
|
||||
def evaluate_while(while_loop, state, static_tools, custom_tools):
|
||||
max_iterations = 1000
|
||||
iterations = 0
|
||||
while evaluate_ast(while_loop.test, state, tools):
|
||||
while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
|
||||
for node in while_loop.body:
|
||||
try:
|
||||
evaluate_ast(node, state, tools)
|
||||
evaluate_ast(node, state, static_tools, custom_tools)
|
||||
except BreakException:
|
||||
return None
|
||||
except ContinueException:
|
||||
@ -118,11 +128,11 @@ def evaluate_while(while_loop, state, tools):
|
||||
return None
|
||||
|
||||
|
||||
def create_function(func_def, state, tools):
|
||||
def create_function(func_def, state, static_tools, custom_tools):
|
||||
def new_func(*args, **kwargs):
|
||||
func_state = state.copy()
|
||||
arg_names = [arg.arg for arg in func_def.args.args]
|
||||
default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults]
|
||||
default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
|
||||
|
||||
# Apply default values
|
||||
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
|
||||
@ -158,7 +168,7 @@ def create_function(func_def, state, tools):
|
||||
result = None
|
||||
try:
|
||||
for stmt in func_def.body:
|
||||
result = evaluate_ast(stmt, func_state, tools)
|
||||
result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
|
||||
except ReturnException as e:
|
||||
result = e.value
|
||||
return result
|
||||
@ -173,25 +183,25 @@ def create_class(class_name, class_bases, class_body):
|
||||
return type(class_name, tuple(class_bases), class_dict)
|
||||
|
||||
|
||||
def evaluate_function_def(func_def, state, static_tools, custom_tools):
    """
    Build a callable from a function definition node and register it in `custom_tools`.

    The callable is created by `create_function` and stored under `func_def.name`, replacing any
    entry already registered under that name in `custom_tools`.

    Returns:
        The newly created callable.
    """
    custom_tools[func_def.name] = create_function(func_def, state, static_tools, custom_tools)
    return custom_tools[func_def.name]
|
||||
|
||||
|
||||
def evaluate_class_def(class_def, state, tools):
|
||||
def evaluate_class_def(class_def, state, static_tools, custom_tools):
|
||||
class_name = class_def.name
|
||||
bases = [evaluate_ast(base, state, tools) for base in class_def.bases]
|
||||
bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
|
||||
class_dict = {}
|
||||
|
||||
for stmt in class_def.body:
|
||||
if isinstance(stmt, ast.FunctionDef):
|
||||
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools)
|
||||
class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
for target in stmt.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
class_dict[target.id] = evaluate_ast(stmt.value, state, tools)
|
||||
class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
||||
elif isinstance(target, ast.Attribute):
|
||||
class_dict[target.attr] = evaluate_ast(stmt.value, state, tools)
|
||||
class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
|
||||
else:
|
||||
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
||||
|
||||
@ -200,17 +210,17 @@ def evaluate_class_def(class_def, state, tools):
|
||||
return new_class
|
||||
|
||||
|
||||
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
|
||||
def evaluate_augassign(expression, state, static_tools, custom_tools):
|
||||
# Helper function to get current value and set new value based on the target type
|
||||
def get_current_value(target):
|
||||
if isinstance(target, ast.Name):
|
||||
return state.get(target.id, 0)
|
||||
elif isinstance(target, ast.Subscript):
|
||||
obj = evaluate_ast(target.value, state, tools)
|
||||
key = evaluate_ast(target.slice, state, tools)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||
key = evaluate_ast(target.slice, state, static_tools, custom_tools)
|
||||
return obj[key]
|
||||
elif isinstance(target, ast.Attribute):
|
||||
obj = evaluate_ast(target.value, state, tools)
|
||||
obj = evaluate_ast(target.value, state, static_tools, custom_tools)
|
||||
return getattr(obj, target.attr)
|
||||
elif isinstance(target, ast.Tuple):
|
||||
return tuple(get_current_value(elt) for elt in target.elts)
|
||||
@ -220,7 +230,7 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
|
||||
raise InterpreterError("AugAssign not supported for {type(target)} targets.")
|
||||
|
||||
current_value = get_current_value(expression.target)
|
||||
value_to_add = evaluate_ast(expression.value, state, tools)
|
||||
value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
|
||||
# Determine the operation and apply it
|
||||
if isinstance(expression.op, ast.Add):
|
||||
@ -256,28 +266,28 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
|
||||
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
|
||||
|
||||
# Update the state
|
||||
set_value(expression.target, updated_value, state, tools)
|
||||
set_value(expression.target, updated_value, state, static_tools, custom_tools)
|
||||
|
||||
return updated_value
|
||||
|
||||
|
||||
def evaluate_boolop(node, state, static_tools, custom_tools):
    """
    Evaluate an `ast.BoolOp` node (`and` / `or`) with short-circuit semantics.

    Operands are evaluated from left to right and evaluation stops as soon as one operand decides the outcome.
    As in regular Python, the value returned is the deciding operand itself rather than a coerced boolean, so
    that idioms such as `name = given_name or "default"` give the expected value.

    Args:
        node (`ast.BoolOp`):
            The boolean operation to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and cannot be overwritten.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and can be overwritten.

    Returns:
        The first falsy operand (for `and`) or the first truthy operand (for `or`); the last operand if no
        operand short-circuits the evaluation.
    """
    if isinstance(node.op, ast.And):
        result = True
        for value in node.values:
            result = evaluate_ast(value, state, static_tools, custom_tools)
            # First falsy operand decides an `and`: the remaining operands must not be evaluated.
            if not result:
                return result
        return result
    elif isinstance(node.op, ast.Or):
        result = False
        for value in node.values:
            result = evaluate_ast(value, state, static_tools, custom_tools)
            # First truthy operand decides an `or`: the remaining operands must not be evaluated.
            if result:
                return result
        return result
|
||||
|
||||
|
||||
def evaluate_binop(binop, state, tools):
|
||||
def evaluate_binop(binop, state, static_tools, custom_tools):
|
||||
# Recursively evaluate the left and right operands
|
||||
left_val = evaluate_ast(binop.left, state, tools)
|
||||
right_val = evaluate_ast(binop.right, state, tools)
|
||||
left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
|
||||
right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
|
||||
|
||||
# Determine the operation based on the type of the operator in the BinOp
|
||||
if isinstance(binop.op, ast.Add):
|
||||
@ -308,66 +318,92 @@ def evaluate_binop(binop, state, tools):
|
||||
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
||||
|
||||
|
||||
def evaluate_assign(assign, state, static_tools, custom_tools):
    """
    Evaluate an `ast.Assign` node and store the assigned value in `state`.

    An `ast.Assign` node carries several targets only for chained assignments (`a = b = value`), in which case
    every target receives the same value. Unpacking (`a, b = value`) is a single `ast.Tuple` / `ast.List`
    target and is delegated to `set_value`.

    Args:
        assign (`ast.Assign`):
            The assignment to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values, updated in place.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and cannot be overwritten.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and can be overwritten.

    Returns:
        The evaluated right-hand side value.
    """
    result = evaluate_ast(assign.value, state, static_tools, custom_tools)
    # The right-hand side is evaluated once, then bound to each target from left to right.
    for target in assign.targets:
        set_value(target, result, state, static_tools, custom_tools)
    return result
|
||||
|
||||
|
||||
def set_value(target, value, state, static_tools, custom_tools):
    """
    Bind `value` to an assignment target, updating `state` (or the targeted object) in place.

    Supported targets are plain names, tuple / list patterns (optionally containing one starred element, as in
    `a, *b = values`), subscripts (`obj[key] = value`) and attributes (`obj.attr = value`).

    Args:
        target (`ast.AST`):
            The assignment target node.
        value (`Any`):
            The value to bind.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values, updated in place.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. Assigning to one of their names is refused.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and can be overwritten.

    Raises:
        `InterpreterError`: If the target name is a static tool, if the value cannot be unpacked into the
        target pattern, or if the target type is not supported.
    """
    if isinstance(target, ast.Name):
        if target.id in static_tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, (ast.Tuple, ast.List)):
        if not isinstance(value, tuple):
            # Any non-string iterable may be unpacked; it is materialized so that it can be sized and sliced.
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = tuple(value)
            else:
                raise InterpreterError("Cannot unpack non-tuple value")
        elts = target.elts
        starred_positions = [i for i, elt in enumerate(elts) if isinstance(elt, ast.Starred)]
        if starred_positions:
            # `a, *b, c = values`: the starred element takes, as a list, whatever the other elements leave.
            star = starred_positions[0]
            n_after = len(elts) - star - 1
            if len(value) < len(elts) - 1:
                raise InterpreterError("Cannot unpack tuple of wrong size")
            star_end = len(value) - n_after
            pairs = list(zip(elts[:star], value[:star]))
            pairs.append((elts[star].value, list(value[star:star_end])))
            pairs.extend(zip(elts[star + 1 :], value[star_end:]))
        else:
            if len(elts) != len(value):
                raise InterpreterError("Cannot unpack tuple of wrong size")
            pairs = list(zip(elts, value))
        for elem, item in pairs:
            set_value(elem, item, state, static_tools, custom_tools)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        setattr(obj, target.attr, value)
    else:
        # Dropping the assignment silently would leave the name undefined and fail later, far from the cause.
        raise InterpreterError(f"Cannot assign to target of type {target.__class__.__name__}.")
|
||||
|
||||
|
||||
def evaluate_call(call, state, tools):
|
||||
def evaluate_call(call, state, static_tools, custom_tools):
|
||||
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
||||
raise InterpreterError(
|
||||
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
|
||||
)
|
||||
raise InterpreterError(f"This is not a correct function: {call.func}).")
|
||||
if isinstance(call.func, ast.Attribute):
|
||||
obj = evaluate_ast(call.func.value, state, tools)
|
||||
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
|
||||
func_name = call.func.attr
|
||||
if not hasattr(obj, func_name):
|
||||
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
|
||||
func = getattr(obj, func_name)
|
||||
|
||||
elif isinstance(call.func, ast.Name):
|
||||
func_name = call.func.id
|
||||
if func_name in state:
|
||||
func = state[func_name]
|
||||
elif func_name in tools:
|
||||
func = tools[func_name]
|
||||
elif func_name in static_tools:
|
||||
func = static_tools[func_name]
|
||||
elif func_name in custom_tools:
|
||||
func = custom_tools[func_name]
|
||||
elif func_name in ERRORS:
|
||||
func = ERRORS[func_name]
|
||||
else:
|
||||
raise InterpreterError(
|
||||
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
|
||||
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
|
||||
)
|
||||
|
||||
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
||||
args = []
|
||||
for arg in call.args:
|
||||
if isinstance(arg, ast.Starred):
|
||||
args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools))
|
||||
else:
|
||||
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
||||
|
||||
args = []
|
||||
for arg in call.args:
|
||||
if isinstance(arg, ast.Starred):
|
||||
unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
|
||||
if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
|
||||
raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
|
||||
args.extend(unpacked)
|
||||
else:
|
||||
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
|
||||
|
||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
|
||||
|
||||
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
|
||||
# Instantiate the class using its constructor
|
||||
@ -397,24 +433,31 @@ def evaluate_call(call, state, tools):
|
||||
output = " ".join(map(str, args))
|
||||
global PRINT_OUTPUTS
|
||||
PRINT_OUTPUTS += output + "\n"
|
||||
# cap the number of lines
|
||||
return output
|
||||
else: # Assume it's a callable object
|
||||
output = func(*args, **kwargs)
|
||||
return output
|
||||
|
||||
|
||||
def evaluate_subscript(subscript, state, tools):
|
||||
index = evaluate_ast(subscript.slice, state, tools)
|
||||
value = evaluate_ast(subscript.value, state, tools)
|
||||
if isinstance(index, slice):
|
||||
def evaluate_subscript(subscript, state, static_tools, custom_tools):
|
||||
index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
|
||||
value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
|
||||
|
||||
if isinstance(value, pd.core.indexing._LocIndexer):
|
||||
parent_object = value.obj
|
||||
return parent_object.loc[index]
|
||||
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
|
||||
return value[index]
|
||||
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
|
||||
return value[index]
|
||||
elif isinstance(index, slice):
|
||||
return value[index]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Ensure the index is within bounds
|
||||
if not (-len(value) <= index < len(value)):
|
||||
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
|
||||
return value[int(index)]
|
||||
elif isinstance(value, str):
|
||||
# Ensure the index is within bounds
|
||||
if not (-len(value) <= index < len(value)):
|
||||
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
|
||||
return value[index]
|
||||
@ -427,11 +470,11 @@ def evaluate_subscript(subscript, state, tools):
|
||||
raise InterpreterError(f"Could not index {value} with '{index}'.")
|
||||
|
||||
|
||||
def evaluate_name(name, state, tools):
|
||||
def evaluate_name(name, state, static_tools, custom_tools):
|
||||
if name.id in state:
|
||||
return state[name.id]
|
||||
elif name.id in tools:
|
||||
return tools[name.id]
|
||||
elif name.id in static_tools:
|
||||
return static_tools[name.id]
|
||||
elif name.id in ERRORS:
|
||||
return ERRORS[name.id]
|
||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||
@ -440,9 +483,9 @@ def evaluate_name(name, state, tools):
|
||||
raise InterpreterError(f"The variable `{name.id}` is not defined.")
|
||||
|
||||
|
||||
def evaluate_condition(condition, state, tools):
|
||||
left = evaluate_ast(condition.left, state, tools)
|
||||
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
|
||||
def evaluate_condition(condition, state, static_tools, custom_tools):
|
||||
left = evaluate_ast(condition.left, state, static_tools, custom_tools)
|
||||
comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
|
||||
ops = [type(op) for op in condition.ops]
|
||||
|
||||
result = True
|
||||
@ -450,63 +493,61 @@ def evaluate_condition(condition, state, tools):
|
||||
|
||||
for op, comparator in zip(ops, comparators):
|
||||
if op == ast.Eq:
|
||||
result = result and (current_left == comparator)
|
||||
current_result = current_left == comparator
|
||||
elif op == ast.NotEq:
|
||||
result = result and (current_left != comparator)
|
||||
current_result = current_left != comparator
|
||||
elif op == ast.Lt:
|
||||
result = result and (current_left < comparator)
|
||||
current_result = current_left < comparator
|
||||
elif op == ast.LtE:
|
||||
result = result and (current_left <= comparator)
|
||||
current_result = current_left <= comparator
|
||||
elif op == ast.Gt:
|
||||
result = result and (current_left > comparator)
|
||||
current_result = current_left > comparator
|
||||
elif op == ast.GtE:
|
||||
result = result and (current_left >= comparator)
|
||||
current_result = current_left >= comparator
|
||||
elif op == ast.Is:
|
||||
result = result and (current_left is comparator)
|
||||
current_result = current_left is comparator
|
||||
elif op == ast.IsNot:
|
||||
result = result and (current_left is not comparator)
|
||||
current_result = current_left is not comparator
|
||||
elif op == ast.In:
|
||||
result = result and (current_left in comparator)
|
||||
current_result = current_left in comparator
|
||||
elif op == ast.NotIn:
|
||||
result = result and (current_left not in comparator)
|
||||
current_result = current_left not in comparator
|
||||
else:
|
||||
raise InterpreterError(f"Operator not supported: {op}")
|
||||
|
||||
result = result & current_result
|
||||
current_left = comparator
|
||||
if not result:
|
||||
|
||||
if isinstance(result, bool) and not result:
|
||||
break
|
||||
|
||||
return result
|
||||
return result if isinstance(result, (bool, pd.Series)) else result.all()
|
||||
|
||||
|
||||
def evaluate_if(if_statement, state, static_tools, custom_tools):
    """
    Evaluate an `ast.If` node by running the statements of the branch selected by its test.

    Args:
        if_statement (`ast.If`):
            The conditional statement to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and cannot be overwritten.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and can be overwritten.

    Returns:
        The last non-`None` value produced by a statement of the executed branch, or `None` if there is none.
    """
    condition = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
    branch = if_statement.body if condition else if_statement.orelse

    last_value = None
    for statement in branch:
        statement_value = evaluate_ast(statement, state, static_tools, custom_tools)
        if statement_value is not None:
            last_value = statement_value
    return last_value
|
||||
|
||||
|
||||
def evaluate_for(for_loop, state, tools):
|
||||
def evaluate_for(for_loop, state, static_tools, custom_tools):
|
||||
result = None
|
||||
iterator = evaluate_ast(for_loop.iter, state, tools)
|
||||
iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
|
||||
for counter in iterator:
|
||||
if isinstance(for_loop.target, ast.Tuple):
|
||||
for i, elem in enumerate(for_loop.target.elts):
|
||||
state[elem.id] = counter[i]
|
||||
else:
|
||||
state[for_loop.target.id] = counter
|
||||
set_value(for_loop.target, counter, state, static_tools, custom_tools)
|
||||
for node in for_loop.body:
|
||||
try:
|
||||
line_result = evaluate_ast(node, state, tools)
|
||||
line_result = evaluate_ast(node, state, static_tools, custom_tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
except BreakException:
|
||||
@ -519,55 +560,60 @@ def evaluate_for(for_loop, state, tools):
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
    """
    Evaluate a comprehension node with an `elt` and a list of `generators` into a list.

    The `for` clauses are nested from left to right: each clause iterates in a copy of the scope built by the
    clauses before it, so the caller's `state` is never modified by the loop variables.

    Args:
        listcomp (`ast.ListComp` or `ast.GeneratorExp`):
            The comprehension to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and cannot be overwritten.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and can be overwritten.

    Returns:
        `list`: The evaluated elements, in iteration order.
    """
    clauses = listcomp.generators

    def expand(depth, scope):
        # Past the innermost clause: every loop variable is bound, the element can be produced.
        if depth >= len(clauses):
            return [evaluate_ast(listcomp.elt, scope, static_tools, custom_tools)]

        clause = clauses[depth]
        collected = []
        for item in evaluate_ast(clause.iter, scope, static_tools, custom_tools):
            inner_scope = scope.copy()
            if isinstance(clause.target, ast.Tuple):
                for position, name_node in enumerate(clause.target.elts):
                    inner_scope[name_node.id] = item[position]
            else:
                inner_scope[clause.target.id] = item

            # Conditions are checked in order and checking stops at the first one that fails.
            keep = True
            for guard in clause.ifs:
                if not evaluate_ast(guard, inner_scope, static_tools, custom_tools):
                    keep = False
                    break
            if keep:
                collected.extend(expand(depth + 1, inner_scope))
        return collected

    return expand(0, state)
|
||||
|
||||
|
||||
def evaluate_try(try_node, state, static_tools, custom_tools):
    """
    Evaluate an `ast.Try` node: body, first matching `except` handler, `else` and `finally` clauses.

    Args:
        try_node (`ast.Try`):
            The try statement to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The caught exception is stored under the handler's
            name when the handler declares one.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and cannot be overwritten.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and can be overwritten.

    Raises:
        The exception raised by the body when no handler matches it.
    """
    try:
        for statement in try_node.body:
            evaluate_ast(statement, state, static_tools, custom_tools)
    except Exception as error:
        # Handlers are tried in order; a handler without a type catches everything.
        selected = None
        for handler in try_node.handlers:
            if handler.type is None or isinstance(
                error, evaluate_ast(handler.type, state, static_tools, custom_tools)
            ):
                selected = handler
                break
        if selected is None:
            raise error
        if selected.name:
            state[selected.name] = error
        for statement in selected.body:
            evaluate_ast(statement, state, static_tools, custom_tools)
    else:
        for statement in try_node.orelse or ():
            evaluate_ast(statement, state, static_tools, custom_tools)
    finally:
        for statement in try_node.finalbody or ():
            evaluate_ast(statement, state, static_tools, custom_tools)
|
||||
|
||||
|
||||
def evaluate_raise(raise_node, state, tools):
|
||||
def evaluate_raise(raise_node, state, static_tools, custom_tools):
|
||||
if raise_node.exc is not None:
|
||||
exc = evaluate_ast(raise_node.exc, state, tools)
|
||||
exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
|
||||
else:
|
||||
exc = None
|
||||
if raise_node.cause is not None:
|
||||
cause = evaluate_ast(raise_node.cause, state, tools)
|
||||
cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
|
||||
else:
|
||||
cause = None
|
||||
if exc is not None:
|
||||
@ -579,11 +625,11 @@ def evaluate_raise(raise_node, state, tools):
|
||||
raise InterpreterError("Re-raise is not supported without an active exception")
|
||||
|
||||
|
||||
def evaluate_assert(assert_node, state, tools):
|
||||
test_result = evaluate_ast(assert_node.test, state, tools)
|
||||
def evaluate_assert(assert_node, state, static_tools, custom_tools):
|
||||
test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
|
||||
if not test_result:
|
||||
if assert_node.msg:
|
||||
msg = evaluate_ast(assert_node.msg, state, tools)
|
||||
msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
|
||||
raise AssertionError(msg)
|
||||
else:
|
||||
# Include the failing condition in the assertion message
|
||||
@ -591,10 +637,10 @@ def evaluate_assert(assert_node, state, tools):
|
||||
raise AssertionError(f"Assertion failed: {test_code}")
|
||||
|
||||
|
||||
def evaluate_with(with_node, state, tools):
|
||||
def evaluate_with(with_node, state, static_tools, custom_tools):
|
||||
contexts = []
|
||||
for item in with_node.items:
|
||||
context_expr = evaluate_ast(item.context_expr, state, tools)
|
||||
context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
|
||||
if item.optional_vars:
|
||||
state[item.optional_vars.id] = context_expr.__enter__()
|
||||
contexts.append(state[item.optional_vars.id])
|
||||
@ -604,7 +650,7 @@ def evaluate_with(with_node, state, tools):
|
||||
|
||||
try:
|
||||
for stmt in with_node.body:
|
||||
evaluate_ast(stmt, state, tools)
|
||||
evaluate_ast(stmt, state, static_tools, custom_tools)
|
||||
except Exception as e:
|
||||
for context in reversed(contexts):
|
||||
context.__exit__(type(e), e, e.__traceback__)
|
||||
@ -614,10 +660,51 @@ def evaluate_with(with_node, state, tools):
|
||||
context.__exit__(None, None, None)
|
||||
|
||||
|
||||
def import_modules(expression, state, authorized_imports):
    """
    Evaluate an `import ...` or `from ... import ...` statement, binding the imported objects in `state`.

    A module is authorized when it, or any of its parent packages, is listed in `authorized_imports`
    (authorizing `"os"` also authorizes `"os.path"`).

    Args:
        expression (`ast.Import` or `ast.ImportFrom`):
            The import statement to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values, updated in place with the imported names.
        authorized_imports (`List[str]`):
            The list of modules that can be imported by the code.

    Returns:
        `None`

    Raises:
        `InterpreterError`: If the module is not authorized, or if the import is a relative import.
    """

    def check_module_authorized(module_name):
        module_path = module_name.split(".")
        module_subpaths = (".".join(module_path[:i]) for i in range(1, len(module_path) + 1))
        return any(subpath in authorized_imports for subpath in module_subpaths)

    if isinstance(expression, ast.Import):
        for alias in expression.names:
            if not check_module_authorized(alias.name):
                raise InterpreterError(
                    f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
                )
            state[alias.asname or alias.name] = import_module(alias.name)
            # `import a.b` binds the top-level package `a` in regular Python; without this, the imported
            # submodule cannot be reached by name. The package is only exposed if it is itself authorized,
            # so that authorizing a single submodule does not expose its whole parent package.
            top_level = alias.name.partition(".")[0]
            if alias.asname is None and top_level != alias.name and top_level in authorized_imports:
                state[top_level] = import_module(top_level)
        return None
    elif isinstance(expression, ast.ImportFrom):
        # Interpreted code has no enclosing package: a relative import would either crash on a missing
        # module name or be silently resolved as an absolute import.
        if getattr(expression, "level", 0) or expression.module is None:
            raise InterpreterError("Relative imports are not supported.")
        if not check_module_authorized(expression.module):
            raise InterpreterError(f"Import from {expression.module} is not allowed.")
        # A non-empty `fromlist` makes `__import__` return the leaf module and import requested submodules.
        module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
        for alias in expression.names:
            state[alias.asname or alias.name] = getattr(module, alias.name)
        return None
|
||||
|
||||
|
||||
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
    """
    Evaluate an `ast.DictComp` node into a dictionary.

    The `for` clauses are nested from left to right, as in `evaluate_listcomp`: each clause iterates in a copy
    of the scope built by the clauses before it, and the key / value expressions are evaluated once every loop
    variable is bound. The caller's `state` is not modified by the loop variables.

    Args:
        dictcomp (`ast.DictComp`):
            The dictionary comprehension to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and cannot be overwritten.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation and can be overwritten.

    Returns:
        `dict`: The evaluated key / value pairs; later pairs overwrite earlier ones with an equal key.
    """
    result = {}
    generators = dictcomp.generators

    def fill(index, current_state):
        # Past the innermost clause: every loop variable is bound, the entry can be produced.
        if index >= len(generators):
            key = evaluate_ast(dictcomp.key, current_state, static_tools, custom_tools)
            val = evaluate_ast(dictcomp.value, current_state, static_tools, custom_tools)
            result[key] = val
            return
        gen = generators[index]
        iter_value = evaluate_ast(gen.iter, current_state, static_tools, custom_tools)
        for value in iter_value:
            new_state = current_state.copy()
            set_value(gen.target, value, new_state, static_tools, custom_tools)
            if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
                fill(index + 1, new_state)

    fill(0, state)
    return result
|
||||
|
||||
|
||||
def evaluate_ast(
|
||||
expression: ast.AST,
|
||||
state: Dict[str, Any],
|
||||
tools: Dict[str, Callable],
|
||||
static_tools: Dict[str, Callable],
|
||||
custom_tools: Dict[str, Callable],
|
||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||
):
|
||||
"""
|
||||
@ -632,146 +719,128 @@ def evaluate_ast(
|
||||
state (`Dict[str, Any]`):
|
||||
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
|
||||
encounters assignements.
|
||||
tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||
`InterpreterError`.
|
||||
static_tools (`Dict[str, Callable]`):
|
||||
Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
|
||||
custom_tools (`Dict[str, Callable]`):
|
||||
Functions that may be called during the evaluation. These static_tools can be overwritten.
|
||||
authorized_imports (`List[str]`):
|
||||
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
||||
Add more at your own risk!
|
||||
"""
|
||||
global OPERATIONS_COUNT
|
||||
if OPERATIONS_COUNT >= MAX_OPERATIONS:
|
||||
raise InterpreterError(
|
||||
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
|
||||
)
|
||||
OPERATIONS_COUNT += 1
|
||||
if isinstance(expression, ast.Assign):
|
||||
# Assignement -> we evaluate the assignment which should update the state
|
||||
# We return the variable assigned as it may be used to determine the final result.
|
||||
return evaluate_assign(expression, state, tools)
|
||||
return evaluate_assign(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.AugAssign):
|
||||
return evaluate_augassign(expression, state, tools)
|
||||
return evaluate_augassign(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Call):
|
||||
# Function call -> we return the value of the function call
|
||||
return evaluate_call(expression, state, tools)
|
||||
return evaluate_call(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Constant):
|
||||
# Constant -> just return the value
|
||||
return expression.value
|
||||
elif isinstance(expression, ast.Tuple):
|
||||
return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts)
|
||||
elif isinstance(expression, ast.ListComp):
|
||||
return evaluate_listcomp(expression, state, tools)
|
||||
return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
|
||||
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
|
||||
return evaluate_listcomp(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.UnaryOp):
|
||||
return evaluate_unaryop(expression, state, tools)
|
||||
return evaluate_unaryop(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Starred):
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.BoolOp):
|
||||
# Boolean operation -> evaluate the operation
|
||||
return evaluate_boolop(expression, state, tools)
|
||||
return evaluate_boolop(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Break):
|
||||
raise BreakException()
|
||||
elif isinstance(expression, ast.Continue):
|
||||
raise ContinueException()
|
||||
elif isinstance(expression, ast.BinOp):
|
||||
# Binary operation -> execute operation
|
||||
return evaluate_binop(expression, state, tools)
|
||||
return evaluate_binop(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Compare):
|
||||
# Comparison -> evaluate the comparison
|
||||
return evaluate_condition(expression, state, tools)
|
||||
return evaluate_condition(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Lambda):
|
||||
return evaluate_lambda(expression, state, tools)
|
||||
return evaluate_lambda(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.FunctionDef):
|
||||
return evaluate_function_def(expression, state, tools)
|
||||
return evaluate_function_def(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Dict):
|
||||
# Dict -> evaluate all keys and values
|
||||
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
|
||||
values = [evaluate_ast(v, state, tools) for v in expression.values]
|
||||
keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
|
||||
values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
|
||||
return dict(zip(keys, values))
|
||||
elif isinstance(expression, ast.Expr):
|
||||
# Expression -> evaluate the content
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.For):
|
||||
# For loop -> execute the loop
|
||||
return evaluate_for(expression, state, tools)
|
||||
return evaluate_for(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.FormattedValue):
|
||||
# Formatted value (part of f-string) -> evaluate the content and return
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.If):
|
||||
# If -> execute the right branch
|
||||
return evaluate_if(expression, state, tools)
|
||||
return evaluate_if(expression, state, static_tools, custom_tools)
|
||||
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
return evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.JoinedStr):
|
||||
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
|
||||
return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
|
||||
elif isinstance(expression, ast.List):
|
||||
# List -> evaluate all elements
|
||||
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
|
||||
return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
|
||||
elif isinstance(expression, ast.Name):
|
||||
# Name -> pick up the value in the state
|
||||
return evaluate_name(expression, state, tools)
|
||||
return evaluate_name(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Subscript):
|
||||
# Subscript -> return the value of the indexing
|
||||
return evaluate_subscript(expression, state, tools)
|
||||
return evaluate_subscript(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.IfExp):
|
||||
test_val = evaluate_ast(expression.test, state, tools)
|
||||
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
|
||||
if test_val:
|
||||
return evaluate_ast(expression.body, state, tools)
|
||||
return evaluate_ast(expression.body, state, static_tools, custom_tools)
|
||||
else:
|
||||
return evaluate_ast(expression.orelse, state, tools)
|
||||
return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Attribute):
|
||||
obj = evaluate_ast(expression.value, state, tools)
|
||||
return getattr(obj, expression.attr)
|
||||
value = evaluate_ast(expression.value, state, static_tools, custom_tools)
|
||||
return getattr(value, expression.attr)
|
||||
elif isinstance(expression, ast.Slice):
|
||||
return slice(
|
||||
evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None,
|
||||
evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None,
|
||||
evaluate_ast(expression.step, state, tools) if expression.step is not None else None,
|
||||
evaluate_ast(expression.lower, state, static_tools, custom_tools)
|
||||
if expression.lower is not None
|
||||
else None,
|
||||
evaluate_ast(expression.upper, state, static_tools, custom_tools)
|
||||
if expression.upper is not None
|
||||
else None,
|
||||
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
|
||||
)
|
||||
elif isinstance(expression, ast.ListComp) or isinstance(expression, ast.GeneratorExp):
|
||||
result = []
|
||||
vars = {}
|
||||
for generator in expression.generators:
|
||||
var_name = generator.target.id
|
||||
iter_value = evaluate_ast(generator.iter, state, tools)
|
||||
for value in iter_value:
|
||||
vars[var_name] = value
|
||||
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
|
||||
elem = evaluate_ast(expression.elt, {**state, **vars}, tools)
|
||||
result.append(elem)
|
||||
return result
|
||||
elif isinstance(expression, ast.DictComp):
|
||||
result = {}
|
||||
for gen in expression.generators:
|
||||
for container in get_iterable(evaluate_ast(gen.iter, state, tools)):
|
||||
state[gen.target.id] = container
|
||||
key = evaluate_ast(expression.key, state, tools)
|
||||
value = evaluate_ast(expression.value, state, tools)
|
||||
result[key] = value
|
||||
return result
|
||||
elif isinstance(expression, ast.Import):
|
||||
for alias in expression.names:
|
||||
if alias.name in authorized_imports:
|
||||
module = __import__(alias.name)
|
||||
state[alias.asname or alias.name] = module
|
||||
else:
|
||||
raise InterpreterError(f"Import of {alias.name} is not allowed.")
|
||||
return None
|
||||
return evaluate_dictcomp(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.While):
|
||||
return evaluate_while(expression, state, tools)
|
||||
elif isinstance(expression, ast.ImportFrom):
|
||||
if expression.module in authorized_imports:
|
||||
module = __import__(expression.module)
|
||||
for alias in expression.names:
|
||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||
else:
|
||||
raise InterpreterError(f"Import from {expression.module} is not allowed.")
|
||||
return None
|
||||
return evaluate_while(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
|
||||
return import_modules(expression, state, authorized_imports)
|
||||
elif isinstance(expression, ast.ClassDef):
|
||||
return evaluate_class_def(expression, state, tools)
|
||||
return evaluate_class_def(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Try):
|
||||
return evaluate_try(expression, state, tools)
|
||||
return evaluate_try(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Raise):
|
||||
return evaluate_raise(expression, state, tools)
|
||||
return evaluate_raise(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Assert):
|
||||
return evaluate_assert(expression, state, tools)
|
||||
return evaluate_assert(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.With):
|
||||
return evaluate_with(expression, state, tools)
|
||||
return evaluate_with(expression, state, static_tools, custom_tools)
|
||||
elif isinstance(expression, ast.Set):
|
||||
return {evaluate_ast(elt, state, tools) for elt in expression.elts}
|
||||
return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
|
||||
elif isinstance(expression, ast.Return):
|
||||
raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None)
|
||||
raise ReturnException(
|
||||
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
|
||||
)
|
||||
else:
|
||||
# For now we refuse anything else. Let's add things as we need them.
|
||||
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
|
||||
@ -779,7 +848,8 @@ def evaluate_ast(
|
||||
|
||||
def evaluate_python_code(
|
||||
code: str,
|
||||
tools: Optional[Dict[str, Callable]] = None,
|
||||
static_tools: Optional[Dict[str, Callable]] = None,
|
||||
custom_tools: Optional[Dict[str, Callable]] = None,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||
):
|
||||
@ -792,9 +862,12 @@ def evaluate_python_code(
|
||||
Args:
|
||||
code (`str`):
|
||||
The code to evaluate.
|
||||
tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||
`InterpreterError`.
|
||||
static_tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation.
|
||||
These tools cannot be overwritten in the code: any assignment to their name will raise an error.
|
||||
custom_tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation.
|
||||
These tools can be overwritten in the code: any assignment to their name will overwrite them.
|
||||
state (`Dict[str, Any]`):
|
||||
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
||||
updated by this function to contain all variables as they are evaluated.
|
||||
@ -806,20 +879,34 @@ def evaluate_python_code(
|
||||
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
|
||||
if state is None:
|
||||
state = {}
|
||||
if tools is None:
|
||||
tools = {}
|
||||
if static_tools is None:
|
||||
static_tools = {}
|
||||
if custom_tools is None:
|
||||
custom_tools = {}
|
||||
result = None
|
||||
global PRINT_OUTPUTS
|
||||
PRINT_OUTPUTS = ""
|
||||
global OPERATIONS_COUNT
|
||||
OPERATIONS_COUNT = 0
|
||||
for node in expression.body:
|
||||
try:
|
||||
result = evaluate_ast(node, state, tools, authorized_imports)
|
||||
result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
|
||||
except InterpreterError as e:
|
||||
msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||
msg = ""
|
||||
if len(PRINT_OUTPUTS) > 0:
|
||||
msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n"
|
||||
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
|
||||
msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
|
||||
else:
|
||||
msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
|
||||
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
|
||||
raise InterpreterError(msg)
|
||||
finally:
|
||||
state["print_outputs"] = PRINT_OUTPUTS
|
||||
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
|
||||
state["print_outputs"] = PRINT_OUTPUTS
|
||||
else:
|
||||
state["print_outputs"] = (
|
||||
PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
|
||||
+ f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
|
||||
)
|
||||
|
||||
return result
|
||||
|
@ -223,7 +223,7 @@ Action:
|
||||
# check that add_base_tools will not interfere with existing tools
|
||||
with pytest.raises(KeyError) as e:
|
||||
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
|
||||
assert "python_interpreter already exists in the toolbox" in str(e)
|
||||
assert "already exists in the toolbox" in str(e)
|
||||
|
||||
# check that python_interpreter base tool does not get added to code agents
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import load_tool
|
||||
@ -241,8 +242,41 @@ for block in text_block:
|
||||
code = """
|
||||
digits, i = [1, 2, 3], 1
|
||||
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
||||
evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
|
||||
|
||||
code = """
|
||||
def calculate_isbn_10_check_digit(number):
|
||||
total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
|
||||
remainder = total % 11
|
||||
check_digit = 11 - remainder
|
||||
if check_digit == 10:
|
||||
return 'X'
|
||||
elif check_digit == 11:
|
||||
return '0'
|
||||
else:
|
||||
return str(check_digit)
|
||||
|
||||
# Given 9-digit numbers
|
||||
numbers = [
|
||||
"478225952",
|
||||
"643485613",
|
||||
"739394228",
|
||||
"291726859",
|
||||
"875262394",
|
||||
"542617795",
|
||||
"031810713",
|
||||
"957007669",
|
||||
"871467426"
|
||||
]
|
||||
|
||||
# Calculate check digits for each number
|
||||
check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
|
||||
print(check_digits)
|
||||
"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {"range": range, "print": print, "int": int}, state)
|
||||
evaluate_python_code(
|
||||
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
|
||||
)
|
||||
|
||||
def test_listcomp(self):
|
||||
code = "x = [i for i in range(3)]"
|
||||
@ -273,6 +307,17 @@ digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == {0: 0, 1: 1, 2: 4}
|
||||
|
||||
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
|
||||
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
|
||||
assert result == {102: "b"}
|
||||
|
||||
code = """
|
||||
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
|
||||
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
|
||||
"""
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == {"A": ("a", "b"), "B": ("a", "b")}
|
||||
|
||||
def test_tuple_assignment(self):
|
||||
code = "a, b = 0, 1\nb"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
@ -341,7 +386,7 @@ if char.isalpha():
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "lose"
|
||||
|
||||
code = "import time\ntime.sleep(0.1)"
|
||||
code = "import time, re\ntime.sleep(0.1)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result is None
|
||||
|
||||
@ -369,6 +414,23 @@ if char.isalpha():
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "LATIN CAPITAL LETTER A"
|
||||
|
||||
# Test submodules are handled properly, thus not raising error
|
||||
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
||||
|
||||
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
|
||||
|
||||
def test_additional_imports(self):
    """Check how `authorized_imports` gates plain and submodule imports.

    A submodule import must be accepted when either the submodule itself
    or its parent package is authorized, and must raise `InterpreterError`
    when only an unrelated module name is authorized.
    """
    code = "import numpy as np"
    evaluate_python_code(code, authorized_imports=["numpy"], state={})

    code = "import numpy.random as rd"
    # Authorizing the exact submodule is enough ...
    evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
    # ... and so is authorizing its parent package.
    evaluate_python_code(code, authorized_imports=["numpy"], state={})
    # "random" alone must not be mistaken for "numpy.random".
    with pytest.raises(InterpreterError):
        evaluate_python_code(code, authorized_imports=["random"], state={})
|
||||
|
||||
def test_multiple_comparators(self):
|
||||
code = "0 <= -1 < 4 and 0 <= -5 < 4"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
@ -400,7 +462,7 @@ def function():
|
||||
print("2")
|
||||
function()"""
|
||||
state = {}
|
||||
evaluate_python_code(code, {"print": print}, state)
|
||||
evaluate_python_code(code, {"print": print}, state=state)
|
||||
assert state["print_outputs"] == "1\n2\n"
|
||||
|
||||
def test_tuple_target_in_iterator(self):
|
||||
@ -612,7 +674,7 @@ assert lock.locked == False
|
||||
"""
|
||||
state = {}
|
||||
tools = {}
|
||||
evaluate_python_code(code, tools, state)
|
||||
evaluate_python_code(code, tools, state=state)
|
||||
|
||||
def test_default_arg_in_function(self):
|
||||
code = """
|
||||
@ -672,3 +734,94 @@ returns_none(1)
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
||||
assert result is None
|
||||
|
||||
def test_nested_for_loop(self):
    """Check nested `for` loops followed by a two-level list comprehension.

    The snippet builds `all_res = [[], [0], [0, 1], ...]` with nested loops,
    flattens it with a comprehension using two `for` clauses, and returns
    the first ten flattened elements.
    """
    code = """
all_res = []
for i in range(10):
    subres = []
    for j in range(i):
        subres.append(j)
    all_res.append(subres)

out = [i for sublist in all_res for i in sublist]
out[:10]
"""
    state = {}
    result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
    assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
|
||||
|
||||
def test_pandas(self):
    """Run three pandas snippets through the interpreter with pandas authorized.

    Covers boolean-mask filtering plus column selection, `.loc` with `isin`,
    and `groupby(...).mean()`. In each case the value of the last evaluated
    statement is what `evaluate_python_code` returns.
    """
    # Boolean-mask filtering followed by multi-column selection.
    code = """
import pandas as pd

df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})

df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')

parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
    state = {}
    result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
    assert np.array_equal(result, [-1, 5])

    # `.loc` indexing with an `isin` mask; the final assignment's value is returned.
    code = """
import pandas as pd

df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
print("HH0")

# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
    result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
    assert np.array_equal(result.values[0], [104, 1])

    # Group-by aggregation on a DataFrame built from a list of records.
    code = """import pandas as pd
data = pd.DataFrame.from_dict([
    {"Pclass": 1, "Survived": 1},
    {"Pclass": 2, "Survived": 0},
    {"Pclass": 2, "Survived": 1}
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
    result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
    assert result.values[1] == 0.5
|
||||
|
||||
def test_starred(self):
    """Check that starred arguments (`f(*a, *b)`) are unpacked in calls.

    The snippet defines a haversine distance function and calls it with two
    coordinate tuples expanded via `*`; the returned distance is compared
    after rounding to one decimal place.
    """
    code = """
from math import radians, sin, cos, sqrt, atan2

def haversine(lat1, lon1, lat2, lon2):
    R = 6371000  # Radius of the Earth in meters
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    distance = R * c
    return distance

coords_geneva = (46.1978, 6.1342)
coords_barcelona = (41.3869, 2.1660)

distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
    result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
    assert round(result, 1) == 622395.4
|
||||
|
||||
def test_for(self):
    """Check a `for` loop whose target is a nested tuple: `worker, (start, end)`.

    The loop iterates over `dict.items()` where every value is itself a
    2-tuple, and stores the second element of each value under its key.
    """
    code = """
shifts = {
    "Worker A": ("6:45 pm", "8:00 pm"),
    "Worker B": ("10:00 am", "11:45 am")
}

shift_intervals = {}
for worker, (start, end) in shifts.items():
    shift_intervals[worker] = end
shift_intervals
"""
    result = evaluate_python_code(code, {"print": print, "map": map}, state={})
    assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
|
||||
|
Loading…
Reference in New Issue
Block a user