mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Code agent: allow function persistence between steps (#31769)
* Code agent: allow function persistence between steps
This commit is contained in:
parent
eef0507f3d
commit
1556025271
@ -188,7 +188,7 @@ class AgentAudio(AgentType, str):
|
||||
self.samplerate = samplerate
|
||||
if isinstance(value, (str, pathlib.Path)):
|
||||
self._path = value
|
||||
elif isinstance(value, torch.Tensor):
|
||||
elif is_torch_available() and isinstance(value, torch.Tensor):
|
||||
self._tensor = value
|
||||
elif isinstance(value, tuple):
|
||||
self.samplerate = value[0]
|
||||
@ -232,7 +232,10 @@ class AgentAudio(AgentType, str):
|
||||
|
||||
|
||||
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
|
||||
INSTANCE_TYPE_MAPPING = {str: AgentText, float: AgentText, int: AgentText, Tensor: AgentAudio, ImageType: AgentImage}
|
||||
INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
|
||||
|
||||
if is_torch_available():
|
||||
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
|
||||
|
||||
|
||||
def handle_agent_inputs(*args, **kwargs):
|
||||
@ -251,4 +254,4 @@ def handle_agent_outputs(output, output_type=None):
|
||||
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||
if isinstance(output, _k):
|
||||
return _v(output)
|
||||
return AgentType(output)
|
||||
return output
|
||||
|
@ -856,6 +856,10 @@ class ReactCodeAgent(ReactAgent):
|
||||
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
|
||||
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
|
||||
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
|
||||
self.available_tools = {
|
||||
**BASE_PYTHON_TOOLS.copy(),
|
||||
**self.toolbox.tools,
|
||||
} # This list can be augmented by the code agent creating some new functions
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
@ -905,10 +909,9 @@ class ReactCodeAgent(ReactAgent):
|
||||
# Execute
|
||||
self.log_code_action(code_action)
|
||||
try:
|
||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||
result = self.python_evaluator(
|
||||
code_action,
|
||||
available_tools,
|
||||
tools=self.available_tools,
|
||||
state=self.state,
|
||||
authorized_imports=self.authorized_imports,
|
||||
)
|
||||
|
@ -778,7 +778,10 @@ def evaluate_ast(
|
||||
|
||||
|
||||
def evaluate_python_code(
|
||||
code: str, tools: Optional[Dict[str, Callable]] = {}, state=None, authorized_imports: List[str] = LIST_SAFE_MODULES
|
||||
code: str,
|
||||
tools: Optional[Dict[str, Callable]] = None,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||
):
|
||||
"""
|
||||
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
||||
@ -803,6 +806,8 @@ def evaluate_python_code(
|
||||
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
|
||||
if state is None:
|
||||
state = {}
|
||||
if tools is None:
|
||||
tools = {}
|
||||
result = None
|
||||
global PRINT_OUTPUTS
|
||||
PRINT_OUTPUTS = ""
|
||||
|
@ -94,12 +94,48 @@ final_answer("got an error")
|
||||
"""
|
||||
|
||||
|
||||
def fake_react_code_functiondef(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: Let's define the function. special_marker
|
||||
Code:
|
||||
```py
|
||||
import numpy as np
|
||||
|
||||
def moving_average(x, w):
|
||||
return np.convolve(x, np.ones(w), 'valid') / w
|
||||
```<end_code>
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
x, w = [0, 1, 2, 3, 4, 5], 2
|
||||
res = moving_average(x, w)
|
||||
final_answer(res)
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
|
||||
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
final_answer(result)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def fake_code_llm_no_return(messages, stop_sequences=None) -> str:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
print(result)
|
||||
```
|
||||
"""
|
||||
@ -135,8 +171,8 @@ Action:
|
||||
def test_fake_react_code_agent(self):
|
||||
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, AgentText)
|
||||
assert output == "7.2904"
|
||||
assert isinstance(output, float)
|
||||
assert output == 7.2904
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
|
||||
assert agent.logs[2]["tool_call"] == {
|
||||
@ -157,7 +193,7 @@ Action:
|
||||
def test_react_fails_max_iterations(self):
|
||||
agent = ReactCodeAgent(
|
||||
tools=[PythonInterpreterTool()],
|
||||
llm_engine=fake_code_llm_oneshot, # use this callable because it never ends
|
||||
llm_engine=fake_code_llm_no_return, # use this callable because it never ends
|
||||
max_iterations=5,
|
||||
)
|
||||
agent.run("What is 2 multiplied by 3.6452?")
|
||||
@ -192,3 +228,10 @@ Action:
|
||||
# check that python_interpreter base tool does not get added to code agents
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
|
||||
|
||||
def test_function_persistence_across_steps(self):
|
||||
agent = ReactCodeAgent(
|
||||
tools=[], llm_engine=fake_react_code_functiondef, max_iterations=2, additional_authorized_imports=["numpy"]
|
||||
)
|
||||
res = agent.run("ok")
|
||||
assert res[0] == 0.5
|
||||
|
@ -660,7 +660,6 @@ add_one(1, 1)
|
||||
"""
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
||||
print(state)
|
||||
assert result == 2
|
||||
|
||||
# test returning None
|
||||
@ -672,5 +671,4 @@ returns_none(1)
|
||||
"""
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
|
||||
print(state)
|
||||
assert result is None
|
||||
|
Loading…
Reference in New Issue
Block a user