mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Reboot Agents (#30387)
* Create CodeAgent and ReactAgent * Fix formatting errors * Update documentation for agents * Add custom errors, improve logging * Support variable usage in ReactAgent * add messages * Add message passing format * Create React Code Agent * Update * Refactoring * Fix errors * Improve python interpreter * Only non-tensor inputs should be sent to device * Calculator tool slight refactor * Improve docstrings * Refactor * Fix tests * Fix more tests * Fix even more tests * Fix tests by replacing output and input types * Fix operand type issue * two small fixes * EM TTS * Fix agent running type errors * Change text to speech tests to allow changed outputs * Update doc with new agent types * Improve code interpreter * If max iterations reached, provide a real answer instead of an error * Add edge case in interpreter * Add safe imports to the interpreter * Interpreter tweaks: tuples and listcomp * Make style * Make quality * Add dictcomp to interpreter * Rename ReactJSONAgent to ReactJsonAgent * Misc changes * ToolCollection * Rename agent's logger to self.logger * Add while loops to interpreter * Update doc with new tools. still need to mention collections * Add collections to the doc * Small fixes on logs and interpretor * Fix toolbox return type * Docs + fixup * Skip doctests * Correct prompts with improved examples and formatting * Update prompt * Remove outdated docs * Change agent to accept Toolbox object for tools * Remove calculator tool * Propagate removal of calculator in doc * Fix 2 failing workflows * Simplify additional argument passing * AgentType audio * Minor changes: function name, types * Remove calculator tests * Fix test * Fix torch requirement * Fix final answer tests * Style fixes * Fix tests * Update docstrings with calculator removal * Small type hint fixes * Update tests/agents/test_translation.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_python_interpreter.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/default_tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_agents.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/bert/configuration_bert.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/tools.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/speech_to_text.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_speech_to_text.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/agents/test_tools_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * pygments * Answer comments * Cleaning up * Simplifying init for all agents * Improving prompts and making code nicer * Style fixes * Add multiple comparator test in interpreter * Style fixes * Improve BERT example in documentation * Add examples to doc * Fix python interpreter quality * Logging improvements * Change test flag to agents * Quality fix * Add example for HfEngine * Improve conversation example for HfEngine * typo fix * Verify doc * Update docs/source/en/agents.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/agents.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/prompts.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/agents/python_interpreter.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/agents.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix style issues * local s2t tool --------- Co-authored-by: Cyril Kondratenko <kkn1993@gmail.com> Co-authored-by: Lysandre <lysandre@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
3733391c53
commit
0ba15cedbc
@ -71,7 +71,7 @@ NOT_DEVICE_TESTS = {
|
||||
"ModelTester::test_pipeline_",
|
||||
"/repo_utils/",
|
||||
"/utils/",
|
||||
"/tools/",
|
||||
"/agents/",
|
||||
}
|
||||
|
||||
# allow having multiple repository checkouts and not needing to remember to rerun
|
||||
@ -94,7 +94,7 @@ def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
|
||||
config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
|
||||
config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
|
||||
config.addinivalue_line("markers", "tool_tests: mark the tool tests that are run on their specific schedule")
|
||||
config.addinivalue_line("markers", "agent_tests: mark the agent tests that are run on their specific schedule")
|
||||
config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu")
|
||||
|
||||
|
||||
|
@ -23,7 +23,7 @@
|
||||
title: Load and train adapters with 🤗 PEFT
|
||||
- local: model_sharing
|
||||
title: Share your model
|
||||
- local: transformers_agents
|
||||
- local: agents
|
||||
title: Agents
|
||||
- local: llm_tutorial
|
||||
title: Generation with LLMs
|
||||
@ -133,8 +133,6 @@
|
||||
title: Notebooks with examples
|
||||
- local: community
|
||||
title: Community resources
|
||||
- local: custom_tools
|
||||
title: Custom Tools and Prompts
|
||||
- local: troubleshooting
|
||||
title: Troubleshoot
|
||||
- local: hf_quantizer
|
||||
|
490
docs/source/en/agents.md
Normal file
490
docs/source/en/agents.md
Normal file
@ -0,0 +1,490 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
# Agents and tools
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
### What is an agent?
|
||||
|
||||
Large Language Models (LLMs) trained to perform [causal language modeling](./tasks/language_modeling.) can tackle a wide range of tasks, but they often struggle with basic tasks like logic, calculation, and search. When prompted in domains in which they do not perform well, they often fail to generate the answer we expect them to.
|
||||
|
||||
One approach to overcome this weakness is to create an *agent*.
|
||||
|
||||
An agent is a system that uses an LLM as its engine, and it has access to functions called *tools*.
|
||||
|
||||
These *tools* are functions for performing a task, and they contain all necessary description for the agent to properly use them.
|
||||
|
||||
The agent can be programmed to:
|
||||
- devise a series of actions/tools and run them all at once like the `CodeAgent` for example
|
||||
- plan and execute actions/tools one by one and wait for the outcome of each action before launching the next one like the `ReactJsonAgent` for example
|
||||
|
||||
### Types of agents
|
||||
|
||||
#### Code agent
|
||||
|
||||
This agent has a planning step, then generates python code to execute all its actions at once. It natively handles different input and output types for its tools, thus it is the recommended choice for multimodal tasks.
|
||||
|
||||
#### React agents
|
||||
|
||||
This is the go-to agent to solve reasoning tasks, since the ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) makes it really efficient to think on the basis of its previous observations.
|
||||
|
||||
We implement two versions of ReactJsonAgent:
|
||||
- [`~ReactJsonAgent`] generates tool calls as a JSON in its output.
|
||||
- [`~ReactCodeAgent`] is a new type of ReactJsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
|
||||
|
||||
> [!TIP]
|
||||
> Read [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more the ReAct agent.
|
||||
|
||||

|
||||
|
||||
For example, here is how a ReAct agent would work its way through the following question.
|
||||
|
||||
```py3
|
||||
>>> agent.run(
|
||||
... "How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?",
|
||||
... )
|
||||
=====New task=====
|
||||
How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?
|
||||
====Agent is executing the code below:
|
||||
bert_blocks = search(query="number of blocks in BERT base encoder")
|
||||
print("BERT blocks:", bert_blocks)
|
||||
====
|
||||
Print outputs:
|
||||
BERT blocks: twelve encoder blocks
|
||||
|
||||
====Agent is executing the code below:
|
||||
attention_layer = search(query="number of layers in Attention is All You Need")
|
||||
print("Attention layers:", attention_layer)
|
||||
====
|
||||
Print outputs:
|
||||
Attention layers: Encoder: The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position- 2 Page 3 Figure 1: The Transformer - model architecture.
|
||||
|
||||
====Agent is executing the code below:
|
||||
bert_blocks = 12
|
||||
attention_layers = 6
|
||||
diff = bert_blocks - attention_layers
|
||||
print("Difference in blocks:", diff)
|
||||
final_answer(diff)
|
||||
====
|
||||
|
||||
Print outputs:
|
||||
Difference in blocks: 6
|
||||
|
||||
Final answer: 6
|
||||
```
|
||||
|
||||
### How can I build an agent?
|
||||
|
||||
To initialize an agent, you need these arguments:
|
||||
|
||||
- an LLM to power your agent - the agent is not exactly the LLM, it’s more like the agent is a program that uses an LLM as its engine.
|
||||
- a system prompt: what the LLM engine will be prompted with to generate its output
|
||||
- a toolbox from which the agent pick tools to execute
|
||||
- a parser to extract from the LLM output which tools are to call and with which arguments
|
||||
|
||||
Upon initialization of the agent system, the tool attributes are used to generate a tool description, then baked into the agent’s `system_prompt` to let it know which tools it can use and why.
|
||||
|
||||
To start with, please install the `agents` extras in order to install all default dependencies.
|
||||
|
||||
```bash
|
||||
pip install transformers[agents]
|
||||
```
|
||||
|
||||
Build your LLM engine by defining a `llm_engine` method which accepts a list of [messages](./chat_templating.) and returns text. This callable also needs to accept a `stop` argument that indicates when to stop generating.
|
||||
|
||||
```python
|
||||
from huggingface_hub import login, InferenceClient
|
||||
|
||||
login("<YOUR_HUGGINGFACEHUB_API_TOKEN>")
|
||||
|
||||
client = InferenceClient(model="meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
|
||||
def llm_engine(messages, stop_sequences=["Task"]) -> str:
|
||||
response = client.chat_completion(messages, stop=stop_sequences, max_tokens=1000)
|
||||
answer = response.choices[0].message.content
|
||||
return answer
|
||||
```
|
||||
|
||||
You could use any `llm_engine` method as long as:
|
||||
1. it follows the [messages format](./chat_templating.md) for its input (`List[Dict[str, str]]`) and returns a `str`
|
||||
2. it stops generating outputs at the sequences passed in the argument `stop`
|
||||
|
||||
You also need a `tools` argument which accepts a list of `Tools`. You can provide an empty list for `tools`, but use the default toolbox with the optional argument `add_base_tools=True`.
|
||||
|
||||
Now you can create an agent, like `CodeAgent`, and run it. For convenience, we also provide the `HfEngine` class that uses `huggingface_hub.InferenceClient` under the hood.
|
||||
|
||||
```python
|
||||
from transformers import CodeAgent, HfEngine
|
||||
|
||||
llm_engine = HfEngine(model="meta-llama/Meta-Llama-3-70B-Instruct")
|
||||
agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
|
||||
|
||||
agent.run(
|
||||
"Could you translate this sentence from French, say it out loud and return the audio.",
|
||||
sentence="Où est la boulangerie la plus proche?",
|
||||
)
|
||||
```
|
||||
|
||||
This will be handy in case of emergency baguette need!
|
||||
You can even leave the argument `llm_engine` undefined, and an [~HfEngine] will be created by default.
|
||||
|
||||
```python
|
||||
from transformers import CodeAgent
|
||||
|
||||
agent = CodeAgent(tools=[], add_base_tools=True)
|
||||
|
||||
agent.run(
|
||||
"Could you translate this sentence from French, say it out loud and give me the audio.",
|
||||
sentence="Où est la boulangerie la plus proche?",
|
||||
)
|
||||
```
|
||||
|
||||
Note that we used an additional `sentence` argument: you can pass text as additional arguments to the model.
|
||||
|
||||
You can also use this to indicate the path to local or remote files for the model to use:
|
||||
|
||||
```py
|
||||
from transformers import ReactCodeAgent
|
||||
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
|
||||
|
||||
agent.run("Why does Mike not know many people in New York?", audio="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3")
|
||||
```
|
||||
|
||||
|
||||
The prompt and output parser were automatically defined, but you can easily inspect them by calling the `system_prompt_template` on your agent.
|
||||
|
||||
```python
|
||||
print(agent.system_prompt_template)
|
||||
```
|
||||
|
||||
It's important to explain as clearly as possible the task you want to perform.
|
||||
Every [`~Agent.run`] operation is independent, and since an agent is powered by an LLM, minor variations in your prompt might yield completely different results.
|
||||
You can also run an agent consecutively for different tasks: each time the attributes `agent.task` and `agent.logs` will be re-initialized.
|
||||
|
||||
|
||||
#### Code execution
|
||||
|
||||
A Python interpreter executes the code on a set of inputs passed along with your tools.
|
||||
This should be safe because the only functions that can be called are the tools you provided (especially if it's only tools by Hugging Face) and the print function, so you're already limited in what can be executed.
|
||||
|
||||
The Python interpreter also doesn't allow any attribute lookup or imports (which shouldn't be needed for passing inputs/outputs to a small set of functions) so all the most obvious attacks shouldn't be an issue.
|
||||
|
||||
The execution will stop at any code trying to perform an illegal operation or if there is a regular Python error with the code generated by the agent.
|
||||
|
||||
### The system prompt
|
||||
|
||||
An agent, or rather the LLM that drives the agent, generates an output based on the system prompt. The system prompt can be customized and tailored to the intended task. For example, check the system prompt for the `ReactCodeAgent` (below version is slightly simplified).
|
||||
|
||||
```text
|
||||
You will be given a task to solve as best you can.
|
||||
You have access to the following tools:
|
||||
<<tool_descriptions>>
|
||||
|
||||
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
|
||||
|
||||
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task, then the tools that you want to use.
|
||||
Then in the 'Code:' sequence, you shold write the code in simple Python. The code sequence must end with '/End code' sequence.
|
||||
During each intermediate step, you can use 'print()' to save whatever important information you will then need.
|
||||
These print outputs will then be available in the 'Observation:' field, for using this information as input for the next step.
|
||||
|
||||
In the end you have to return a final answer using the `final_answer` tool.
|
||||
|
||||
Here are a few examples using notional tools:
|
||||
---
|
||||
{examples}
|
||||
|
||||
Above example were using notional tools that might not exist for you. You only have acces to those tools:
|
||||
<<tool_names>>
|
||||
You also can perform computations in the python code you generate.
|
||||
|
||||
Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```<end_code>' sequence. You MUST provide at least the 'Code:' sequence to move forward.
|
||||
|
||||
Remember to not perform too many operations in a single code block! You should split the task into intermediate code blocks.
|
||||
Print results at the end of each step to save the intermediate results. Then use final_answer() to return the final result.
|
||||
|
||||
Remember to make sure that variables you use are all defined.
|
||||
|
||||
Now Begin!
|
||||
```
|
||||
|
||||
The system prompt includes:
|
||||
- An *introduction* that explains how the agent should behave and what tools are.
|
||||
- A description of all the tools that is defined by a `<<tool_descriptions>>` token that is dynamically replaced at runtime with the tools defined/chosen by the user.
|
||||
- The tool description comes from the tool attributes, `name`, `description`, `inputs` and `output_type`, and a simple `jinja2` template that you can refine.
|
||||
- The expected output format.
|
||||
|
||||
You could improve the system prompt, for example, by adding an explanation of the output format.
|
||||
|
||||
For maximum flexibility, you can overwrite the whole system prompt template by passing your custom prompt as an argument to the `system_prompt` parameter.
|
||||
|
||||
```python
|
||||
from transformers import ReactJsonAgent
|
||||
from transformers.agents import PythonInterpreterTool
|
||||
|
||||
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], system_prompt="{your_custom_prompt}")
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Please make sure to define the `<<tool_descriptions>>` string somewhere in the `template` so the agent is aware
|
||||
of the available tools.
|
||||
|
||||
## Tools
|
||||
|
||||
A tool is an atomic function to be used by an agent.
|
||||
|
||||
You can for instance check the [~PythonInterpreterTool]: it has a name, a description, input descriptions, an output type, and a `__call__` method to perform the action.
|
||||
|
||||
When the agent is initialized, the tool attributes are used to generate a tool description which is baked into the agent's system prompt. This lets the agent know which tools it can use and why.
|
||||
|
||||
### Default toolbox
|
||||
|
||||
Transformers comes with a default toolbox for empowering agents, that you can add to your agent upon initialization with argument `add_base_tools = True`:
|
||||
|
||||
- **Document question answering**: given a document (such as a PDF) in image format, answer a question on this document ([Donut](./model_doc/donut))
|
||||
- **Image question answering**: given an image, answer a question on this image ([VILT](./model_doc/vilt))
|
||||
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
|
||||
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
|
||||
- **Translation**: translates a given sentence from source language to target language.
|
||||
- **Python code interpreter**: runs your the LLM generated Python code in a secure environment. This tool will only be added to [~ReactJsonAgent] if you use `add_base_tools=True`, since code-based tools can already execute Python code
|
||||
|
||||
|
||||
You can manually use a tool by calling the [`load_tool`] function and a task to perform.
|
||||
|
||||
|
||||
```python
|
||||
from transformers import load_tool
|
||||
|
||||
tool = load_tool("text-to-speech")
|
||||
audio = tool("This is a text to speech tool")
|
||||
```
|
||||
|
||||
|
||||
### Create a new tool
|
||||
|
||||
You can create your own tool for use cases not covered by the default tools from Hugging Face.
|
||||
For example, let's create a tool that returns the most downloaded model for a given task from the Hub.
|
||||
|
||||
You'll start with the code below.
|
||||
|
||||
```python
|
||||
from huggingface_hub import list_models
|
||||
|
||||
task = "text-classification"
|
||||
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
print(model.id)
|
||||
```
|
||||
|
||||
This code can be converted into a class that inherits from the [`Tool`] superclass.
|
||||
|
||||
|
||||
The custom tool needs:
|
||||
- An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name is `model_download_counter`.
|
||||
- An attribute `description` is used to populate the agent's system prompt.
|
||||
- An `inputs` attribute, which is a dictionary with keys `"type"` and `"description"`. It contains information that helps the Python interpreter make educated choices about the input.
|
||||
- An `output_type` attribute, which specifies the output type.
|
||||
- A `forward` method which contains the inference code to be executed.
|
||||
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
from huggingface_hub import list_models
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
description = (
|
||||
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
|
||||
"It returns the name of the checkpoint."
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"task": {
|
||||
"type": "text",
|
||||
"description": "the task category (such as text-classification, depth-estimation, etc)",
|
||||
}
|
||||
}
|
||||
output_type = "text"
|
||||
|
||||
def forward(self, task: str):
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
return model.id
|
||||
```
|
||||
|
||||
Now that the custom `HfModelDownloadsTool` class is ready, you can save it to a file named `model_downloads.py` and import it for use.
|
||||
|
||||
|
||||
```python
|
||||
from model_downloads import HFModelDownloadsTool
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
```
|
||||
|
||||
You can also share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with read access.
|
||||
|
||||
```python
|
||||
tool.push_to_hub("{your_username}/hf-model-downloads")
|
||||
```
|
||||
|
||||
Load the tool with the [`~Tool.load_tool`] function and pass it to the `tools` parameter in your agent.
|
||||
|
||||
```python
|
||||
from transformers import load_tool, CodeAgent
|
||||
|
||||
model_download_tool = load_tool("m-ric/hf-model-downloads")
|
||||
agent = CodeAgent(tools=[model_download_tool], llm_engine=llm_engine)
|
||||
agent.run(
|
||||
"Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
|
||||
)
|
||||
```
|
||||
|
||||
You get the following:
|
||||
```text
|
||||
======== New task ========
|
||||
Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?
|
||||
==== Agent is executing the code below:
|
||||
most_downloaded_model = model_download_counter(task="text-to-video")
|
||||
print(f"The most downloaded model for the 'text-to-video' task is {most_downloaded_model}.")
|
||||
====
|
||||
```
|
||||
|
||||
And the output:
|
||||
`"The most downloaded model for the 'text-to-video' task is ByteDance/AnimateDiff-Lightning."`
|
||||
|
||||
|
||||
### Manage agent toolbox
|
||||
|
||||
If you have already initialized an agent, it is inconvenient to reinitialize it from scratch with a tool you want to use. With Transformers, you can manage an agent's toolbox by adding or replacing a tool.
|
||||
|
||||
Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.
|
||||
|
||||
```python
|
||||
from transformers import CodeAgent
|
||||
|
||||
agent = CodeAgent(tools=[], llm_engine=llm_engine, add_base_tools=True)
|
||||
agent.toolbox.add_tool(model_download_tool)
|
||||
```
|
||||
Now we can leverage both the new tool and the previous text-to-speech tool:
|
||||
|
||||
```python
|
||||
agent.run(
|
||||
"Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub and return the audio?"
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
| **Audio** |
|
||||
|------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |
|
||||
|
||||
|
||||
> [!WARNING]
|
||||
> Beware when adding tools to an agent that already works well because it can bias selection towards your tool or select another tool other than the one already defined.
|
||||
|
||||
|
||||
Use the `agent.toolbox.update_tool()` method to replace an existing tool in the agent's toolbox.
|
||||
This is useful if your new tool is a one-to-one replacement of the existing tool because the agent already knows how to perform that specific task.
|
||||
Just make sure the new tool follows the same API as the replaced tool or adapt the system prompt template to ensure all examples using the replaced tool are updated.
|
||||
|
||||
|
||||
### Use a collection of tools
|
||||
|
||||
You can leverage tool collections by using the ToolCollection object, with the slug of the collection you want to use.
|
||||
Then pass them as a list to initialize you agent, and start using them!
|
||||
|
||||
```py
|
||||
from transformers import ToolCollection, ReactCodeAgent
|
||||
|
||||
image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
|
||||
agent = ReactCodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
|
||||
|
||||
agent.run("Please draw me a picture of rivers and lakes.")
|
||||
```
|
||||
|
||||
To speed up the start, tools are loaded only if called by the agent.
|
||||
|
||||
This gets you this image:
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png">
|
||||
|
||||
|
||||
### Use gradio-tools
|
||||
|
||||
[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
|
||||
Face Spaces as tools. It supports many existing Spaces as well as custom Spaces.
|
||||
|
||||
Transformers supports `gradio_tools` with the [`Tool.from_gradio`] method. For example, let's use the [`StableDiffusionPromptGeneratorTool`](https://github.com/freddyaboulton/gradio-tools/blob/main/gradio_tools/tools/prompt_generator.py) from `gradio-tools` toolkit for improving prompts to generate better images.
|
||||
|
||||
Import and instantiate the tool, then pass it to the `Tool.from_gradio` method:
|
||||
|
||||
```python
|
||||
from gradio_tools import StableDiffusionPromptGeneratorTool
|
||||
from transformers import Tool, load_tool, CodeAgent
|
||||
|
||||
gradio_prompt_generator_tool = StableDiffusionPromptGeneratorTool()
|
||||
prompt_generator_tool = Tool.from_gradio(gradio_prompt_generator_tool)
|
||||
```
|
||||
|
||||
Now you can use it just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit`.
|
||||
|
||||
```python
|
||||
image_generation_tool = load_tool('huggingface-tools/text-to-image')
|
||||
agent = CodeAgent(tools=[prompt_generator_tool, image_generation_tool], llm_engine=llm_engine)
|
||||
|
||||
agent.run(
|
||||
"Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
|
||||
)
|
||||
```
|
||||
|
||||
The model adequately leverages the tool:
|
||||
```text
|
||||
======== New task ========
|
||||
Improve this prompt, then generate an image of it.
|
||||
You have been provided with these initial arguments: {'prompt': 'A rabbit wearing a space suit'}.
|
||||
==== Agent is executing the code below:
|
||||
improved_prompt = StableDiffusionPromptGenerator(query=prompt)
|
||||
while improved_prompt == "QUEUE_FULL":
|
||||
improved_prompt = StableDiffusionPromptGenerator(query=prompt)
|
||||
print(f"The improved prompt is {improved_prompt}.")
|
||||
image = image_generator(prompt=improved_prompt)
|
||||
====
|
||||
```
|
||||
|
||||
Before finally generating the image:
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">
|
||||
|
||||
|
||||
> [!WARNING]
|
||||
> gradio-tools require *textual* inputs and outputs even when working with different modalities like image and audio objects. Image and audio inputs and outputs are currently incompatible.
|
||||
|
||||
### Use LangChain tools
|
||||
|
||||
We love Langchain and think it has a very compelling suite of tools.
|
||||
To import a tool from LangChain, use the `from_langchain()` method.
|
||||
|
||||
Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
|
||||
|
||||
```python
|
||||
from langchain.agents import load_tools
|
||||
from transformers import Tool, ReactCodeAgent
|
||||
|
||||
search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
|
||||
|
||||
agent = ReactCodeAgent(tools=[search_tool])
|
||||
|
||||
agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
|
||||
```
|
@ -1,798 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Custom Tools and Prompts
|
||||
|
||||
<Tip>
|
||||
|
||||
If you are not aware of what tools and agents are in the context of transformers, we recommend you read the
|
||||
[Transformers Agents](transformers_agents) page first.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Transformers Agents is an experimental API that is subject to change at any time. Results returned by the agents
|
||||
can vary as the APIs or underlying models are prone to change.
|
||||
|
||||
</Tip>
|
||||
|
||||
Creating and using custom tools and prompts is paramount to empowering the agent and having it perform new tasks.
|
||||
In this guide we'll take a look at:
|
||||
|
||||
- How to customize the prompt
|
||||
- How to use custom tools
|
||||
- How to create custom tools
|
||||
|
||||
## Customizing the prompt
|
||||
|
||||
As explained in [Transformers Agents](transformers_agents) agents can run in [`~Agent.run`] and [`~Agent.chat`] mode.
|
||||
Both the `run` and `chat` modes underlie the same logic. The language model powering the agent is conditioned on a long
|
||||
prompt and completes the prompt by generating the next tokens until the stop token is reached.
|
||||
The only difference between the two modes is that during the `chat` mode the prompt is extended with
|
||||
previous user inputs and model generations. This allows the agent to have access to past interactions,
|
||||
seemingly giving the agent some kind of memory.
|
||||
|
||||
### Structure of the prompt
|
||||
|
||||
Let's take a closer look at how the prompt is structured to understand how it can be best customized.
|
||||
The prompt is structured broadly into four parts.
|
||||
|
||||
1. Introduction: how the agent should behave, explanation of the concept of tools.
|
||||
2. Description of all the tools. This is defined by a `<<all_tools>>` token that is dynamically replaced at runtime with the tools defined/chosen by the user.
|
||||
3. A set of examples of tasks and their solution
|
||||
4. Current example, and request for solution.
|
||||
|
||||
To better understand each part, let's look at a shortened version of how the `run` prompt can look like:
|
||||
|
||||
````text
|
||||
I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
|
||||
[...]
|
||||
You can print intermediate results if it makes sense to do so.
|
||||
|
||||
Tools:
|
||||
- document_qa: This is a tool that answers a question about a document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
|
||||
- image_captioner: This is a tool that generates a description of an image. It takes an input named `image` which should be the image to the caption and returns a text that contains the description in English.
|
||||
[...]
|
||||
|
||||
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
|
||||
|
||||
I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
|
||||
|
||||
Answer:
|
||||
```py
|
||||
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
||||
print(f"The translated question is {translated_question}.")
|
||||
answer = image_qa(image=image, question=translated_question)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."
|
||||
|
||||
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||
|
||||
Answer:
|
||||
```py
|
||||
answer = document_qa(document, question="What is the oldest person?")
|
||||
print(f"The answer is {answer}.")
|
||||
image = image_generator("A banner showing " + answer)
|
||||
```
|
||||
|
||||
[...]
|
||||
|
||||
Task: "Draw me a picture of rivers and lakes"
|
||||
|
||||
I will use the following
|
||||
````
|
||||
|
||||
The introduction (the text before *"Tools:"*) explains precisely how the model shall behave and what it should do.
|
||||
This part most likely does not need to be customized as the agent shall always behave the same way.
|
||||
|
||||
The second part (the bullet points below *"Tools"*) is dynamically added upon calling `run` or `chat`. There are
|
||||
exactly as many bullet points as there are tools in `agent.toolbox` and each bullet point consists of the name
|
||||
and description of the tool:
|
||||
|
||||
```text
|
||||
- <tool.name>: <tool.description>
|
||||
```
|
||||
|
||||
Let's verify this quickly by loading the document_qa tool and printing out the name and description.
|
||||
|
||||
```py
|
||||
from transformers import load_tool
|
||||
|
||||
document_qa = load_tool("document-question-answering")
|
||||
print(f"- {document_qa.name}: {document_qa.description}")
|
||||
```
|
||||
|
||||
which gives:
|
||||
```text
|
||||
- document_qa: This is a tool that answers a question about a document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
|
||||
```
|
||||
|
||||
We can see that the tool name is short and precise. The description includes two parts, the first explaining
|
||||
what the tool does and the second states what input arguments and return values are expected.
|
||||
|
||||
A good tool name and tool description are very important for the agent to correctly use it. Note that the only
|
||||
information the agent has about the tool is its name and description, so one should make sure that both
|
||||
are precisely written and match the style of the existing tools in the toolbox. In particular make sure the description
|
||||
mentions all the arguments expected by name in code-style, along with the expected type and a description of what they
|
||||
are.
|
||||
|
||||
<Tip>
|
||||
|
||||
Check the naming and description of the curated Transformers tools to better understand what name and
|
||||
description a tool is expected to have. You can see all tools with the [`Agent.toolbox`] property.
|
||||
|
||||
</Tip>
|
||||
|
||||
The third part includes a set of curated examples that show the agent exactly what code it should produce
|
||||
for what kind of user request. The large language models empowering the agent are extremely good at
|
||||
recognizing patterns in a prompt and repeating the pattern with new data. Therefore, it is very important
|
||||
that the examples are written in a way that maximizes the likelihood of the agent to generating correct,
|
||||
executable code in practice.
|
||||
|
||||
Let's have a look at one example:
|
||||
|
||||
````text
|
||||
Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."
|
||||
|
||||
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||
|
||||
Answer:
|
||||
```py
|
||||
answer = document_qa(document, question="What is the oldest person?")
|
||||
print(f"The answer is {answer}.")
|
||||
image = image_generator("A banner showing " + answer)
|
||||
```
|
||||
|
||||
````
|
||||
|
||||
The pattern the model is prompted to repeat has three parts: The task statement, the agent's explanation of
|
||||
what it intends to do, and finally the generated code. Every example that is part of the prompt has this exact
|
||||
pattern, thus making sure that the agent will reproduce exactly the same pattern when generating new tokens.
|
||||
|
||||
The prompt examples are curated by the Transformers team and rigorously evaluated on a set of
|
||||
[problem statements](https://github.com/huggingface/transformers/blob/main/src/transformers/tools/evaluate_agent.py)
|
||||
to ensure that the agent's prompt is as good as possible to solve real use cases of the agent.
|
||||
|
||||
The final part of the prompt corresponds to:
|
||||
```text
|
||||
Task: "Draw me a picture of rivers and lakes"
|
||||
|
||||
I will use the following
|
||||
```
|
||||
|
||||
is a final and unfinished example that the agent is tasked to complete. The unfinished example
|
||||
is dynamically created based on the actual user input. For the above example, the user ran:
|
||||
|
||||
```py
|
||||
agent.run("Draw me a picture of rivers and lakes")
|
||||
```
|
||||
|
||||
The user input - *a.k.a* the task: *"Draw me a picture of rivers and lakes"* is cast into the
|
||||
prompt template: "Task: <task> \n\n I will use the following". This sentence makes up the final lines of the
|
||||
prompt the agent is conditioned on, therefore strongly influencing the agent to finish the example
|
||||
exactly in the same way it was previously done in the examples.
|
||||
|
||||
Without going into too much detail, the chat template has the same prompt structure with the
|
||||
examples having a slightly different style, *e.g.*:
|
||||
|
||||
````text
|
||||
[...]
|
||||
|
||||
=====
|
||||
|
||||
Human: Answer the question in the variable `question` about the image stored in the variable `image`.
|
||||
|
||||
Assistant: I will use the tool `image_qa` to answer the question on the input image.
|
||||
|
||||
```py
|
||||
answer = image_qa(text=question, image=image)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
Human: I tried this code, it worked but didn't give me a good result. The question is in French
|
||||
|
||||
Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.
|
||||
|
||||
```py
|
||||
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
||||
print(f"The translated question is {translated_question}.")
|
||||
answer = image_qa(text=translated_question, image=image)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
=====
|
||||
|
||||
[...]
|
||||
````
|
||||
|
||||
Contrary, to the examples of the `run` prompt, each `chat` prompt example has one or more exchanges between the
|
||||
*Human* and the *Assistant*. Every exchange is structured similarly to the example of the `run` prompt.
|
||||
The user's input is appended to behind *Human:* and the agent is prompted to first generate what needs to be done
|
||||
before generating code. An exchange can be based on previous exchanges, therefore allowing the user to refer
|
||||
to past exchanges as is done *e.g.* above by the user's input of "I tried **this** code" refers to the
|
||||
previously generated code of the agent.
|
||||
|
||||
Upon running `.chat`, the user's input or *task* is cast into an unfinished example of the form:
|
||||
```text
|
||||
Human: <user-input>\n\nAssistant:
|
||||
```
|
||||
which the agent completes. Contrary to the `run` command, the `chat` command then appends the completed example
|
||||
to the prompt, thus giving the agent more context for the next `chat` turn.
|
||||
|
||||
Great now that we know how the prompt is structured, let's see how we can customize it!
|
||||
|
||||
### Writing good user inputs
|
||||
|
||||
While large language models are getting better and better at understanding users' intentions, it helps
|
||||
enormously to be as precise as possible to help the agent pick the correct task. What does it mean to be
|
||||
as precise as possible?
|
||||
|
||||
The agent sees a list of tool names and their description in its prompt. The more tools are added the
|
||||
more difficult it becomes for the agent to choose the correct tool and it's even more difficult to choose
|
||||
the correct sequences of tools to run. Let's look at a common failure case, here we will only return
|
||||
the code to analyze it.
|
||||
|
||||
```py
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
|
||||
agent.run("Show me a tree", return_code=True)
|
||||
```
|
||||
|
||||
gives:
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_segmenter` to create a segmentation mask for the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
mask = image_segmenter(image, prompt="tree")
|
||||
```
|
||||
|
||||
which is probably not what we wanted. Instead, it is more likely that we want an image of a tree to be generated.
|
||||
To steer the agent more towards using a specific tool it can therefore be very helpful to use important keywords that
|
||||
are present in the tool's name and description. Let's have a look.
|
||||
```py
|
||||
agent.toolbox["image_generator"].description
|
||||
```
|
||||
|
||||
```text
|
||||
'This is a tool that creates an image according to a prompt, which is a text description. It takes an input named `prompt` which contains the image description and outputs an image.
|
||||
```
|
||||
|
||||
The name and description make use of the keywords "image", "prompt", "create" and "generate". Using these words will most likely work better here. Let's refine our prompt a bit.
|
||||
|
||||
```py
|
||||
agent.run("Create an image of a tree", return_code=True)
|
||||
```
|
||||
|
||||
gives:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool `image_generator` to generate an image of a tree.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_generator(prompt="tree")
|
||||
```
|
||||
|
||||
Much better! That looks more like what we want. In short, when you notice that the agent struggles to
|
||||
correctly map your task to the correct tools, try looking up the most pertinent keywords of the tool's name
|
||||
and description and try refining your task request with it.
|
||||
|
||||
### Customizing the tool descriptions
|
||||
|
||||
As we've seen before the agent has access to each of the tools' names and descriptions. The base tools
|
||||
should have very precise names and descriptions, however, you might find that it could help to change
|
||||
the description or name of a tool for your specific use case. This might become especially important
|
||||
when you've added multiple tools that are very similar or if you want to use your agent only for a certain
|
||||
domain, *e.g.* image generation and transformations.
|
||||
|
||||
A common problem is that the agent confuses image generation with image transformation/modification when
|
||||
used a lot for image generation tasks, *e.g.*
|
||||
```py
|
||||
agent.run("Make an image of a house and a car", return_code=True)
|
||||
```
|
||||
returns
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools `image_generator` to generate an image of a house and `image_transformer` to transform the image of a car into the image of a house.
|
||||
|
||||
==Code generated by the agent==
|
||||
house_image = image_generator(prompt="A house")
|
||||
car_image = image_generator(prompt="A car")
|
||||
house_car_image = image_transformer(image=car_image, prompt="A house")
|
||||
```
|
||||
|
||||
which is probably not exactly what we want here. It seems like the agent has a difficult time
|
||||
to understand the difference between `image_generator` and `image_transformer` and often uses the two together.
|
||||
|
||||
We can help the agent here by changing the tool name and description of `image_transformer`. Let's instead call it `modifier`
|
||||
to disassociate it a bit from "image" and "prompt":
|
||||
```py
|
||||
agent.toolbox["modifier"] = agent.toolbox.pop("image_transformer")
|
||||
agent.toolbox["modifier"].description = agent.toolbox["modifier"].description.replace(
|
||||
"transforms an image according to a prompt", "modifies an image"
|
||||
)
|
||||
```
|
||||
|
||||
Now "modify" is a strong cue to use the new image processor which should help with the above prompt. Let's run it again.
|
||||
|
||||
```py
|
||||
agent.run("Make an image of a house and a car", return_code=True)
|
||||
```
|
||||
|
||||
Now we're getting:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools: `image_generator` to generate an image of a house, then `image_generator` to generate an image of a car.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
house_image = image_generator(prompt="A house")
|
||||
car_image = image_generator(prompt="A car")
|
||||
```
|
||||
|
||||
which is definitely closer to what we had in mind! However, we want to have both the house and car in the same image. Steering the task more toward single image generation should help:
|
||||
|
||||
```py
|
||||
agent.run("Create image: 'A house and car'", return_code=True)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_generator` to generate an image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_generator(prompt="A house and car")
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Agents are still brittle for many use cases, especially when it comes to
|
||||
slightly more complex use cases like generating an image of multiple objects.
|
||||
Both the agent itself and the underlying prompt will be further improved in the coming
|
||||
months making sure that agents become more robust to a variety of user inputs.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Customizing the whole prompt
|
||||
|
||||
To give the user maximum flexibility, the whole prompt template as explained in [above](#structure-of-the-prompt)
|
||||
can be overwritten by the user. In this case make sure that your custom prompt includes an introduction section,
|
||||
a tool section, an example section, and an unfinished example section. If you want to overwrite the `run` prompt template,
|
||||
you can do as follows:
|
||||
|
||||
```py
|
||||
template = """ [...] """
|
||||
|
||||
agent = HfAgent(your_endpoint, run_prompt_template=template)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Please make sure to have the `<<all_tools>>` string and the `<<prompt>>` defined somewhere in the `template` so that the agent can be aware
|
||||
of the tools, it has available to it as well as correctly insert the user's prompt.
|
||||
|
||||
</Tip>
|
||||
|
||||
Similarly, one can overwrite the `chat` prompt template. Note that the `chat` mode always uses the following format for the exchanges:
|
||||
```text
|
||||
Human: <<task>>
|
||||
|
||||
Assistant:
|
||||
```
|
||||
|
||||
Therefore it is important that the examples of the custom `chat` prompt template also make use of this format.
|
||||
You can overwrite the `chat` template at instantiation as follows.
|
||||
|
||||
```python
|
||||
template = """ [...] """
|
||||
|
||||
agent = HfAgent(url_endpoint=your_endpoint, chat_prompt_template=template)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Please make sure to have the `<<all_tools>>` string defined somewhere in the `template` so that the agent can be aware
|
||||
of the tools, it has available to it.
|
||||
|
||||
</Tip>
|
||||
|
||||
In both cases, you can pass a repo ID instead of the prompt template if you would like to use a template hosted by someone in the community. The default prompts live in [this repo](https://huggingface.co/datasets/huggingface-tools/default-prompts) as an example.
|
||||
|
||||
To upload your custom prompt on a repo on the Hub and share it with the community just make sure:
|
||||
- to use a dataset repository
|
||||
- to put the prompt template for the `run` command in a file named `run_prompt_template.txt`
|
||||
- to put the prompt template for the `chat` command in a file named `chat_prompt_template.txt`
|
||||
|
||||
## Using custom tools
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using custom tools in your local runtime means that you'll download code to run on your machine.
|
||||
|
||||
ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
|
||||
installing a package using pip/npm/apt.
|
||||
|
||||
</Tip>
|
||||
|
||||
In this section, we'll be leveraging two existing custom tools that are specific to image generation:
|
||||
|
||||
- We replace [huggingface-tools/image-transformation](https://huggingface.co/spaces/huggingface-tools/image-transformation),
|
||||
with [diffusers/controlnet-canny-tool](https://huggingface.co/spaces/diffusers/controlnet-canny-tool)
|
||||
to allow for more image modifications.
|
||||
- We add a new tool for image upscaling to the default toolbox:
|
||||
[diffusers/latent-upscaler-tool](https://huggingface.co/spaces/diffusers/latent-upscaler-tool) replace the existing image-transformation tool.
|
||||
|
||||
We'll start by loading the custom tools with the convenient [`load_tool`] function:
|
||||
|
||||
```py
|
||||
from transformers import load_tool
|
||||
|
||||
controlnet_transformer = load_tool("diffusers/controlnet-canny-tool")
|
||||
upscaler = load_tool("diffusers/latent-upscaler-tool")
|
||||
```
|
||||
|
||||
Upon adding custom tools to an agent, the tools' descriptions and names are automatically
|
||||
included in the agents' prompts. Thus, it is imperative that custom tools have
|
||||
a well-written description and name in order for the agent to understand how to use them.
|
||||
Let's take a look at the description and name of `controlnet_transformer`:
|
||||
|
||||
```py
|
||||
print(f"Description: '{controlnet_transformer.description}'")
|
||||
print(f"Name: '{controlnet_transformer.name}'")
|
||||
```
|
||||
|
||||
gives
|
||||
```text
|
||||
Description: 'This is a tool that transforms an image with ControlNet according to a prompt.
|
||||
It takes two inputs: `image`, which should be the image to transform, and `prompt`, which should be the prompt to use to change it. It returns the modified image.'
|
||||
Name: 'image_transformer'
|
||||
```
|
||||
|
||||
The name and description are accurate and fit the style of the [curated set of tools](./transformers_agents#a-curated-set-of-tools).
|
||||
Next, let's instantiate an agent with `controlnet_transformer` and `upscaler`:
|
||||
|
||||
```py
|
||||
tools = [controlnet_transformer, upscaler]
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=tools)
|
||||
```
|
||||
|
||||
This command should give you the following info:
|
||||
|
||||
```text
|
||||
image_transformer has been replaced by <transformers_modules.diffusers.controlnet-canny-tool.bd76182c7777eba9612fc03c0
|
||||
8718a60c0aa6312.image_transformation.ControlNetTransformationTool object at 0x7f1d3bfa3a00> as provided in `additional_tools`
|
||||
```
|
||||
|
||||
The set of curated tools already has an `image_transformer` tool which is hereby replaced with our custom tool.
|
||||
|
||||
<Tip>
|
||||
|
||||
Overwriting existing tools can be beneficial if we want to use a custom tool exactly for the same task as an existing tool
|
||||
because the agent is well-versed in using the specific task. Beware that the custom tool should follow the exact same API
|
||||
as the overwritten tool in this case, or you should adapt the prompt template to make sure all examples using that
|
||||
tool are updated.
|
||||
|
||||
</Tip>
|
||||
|
||||
The upscaler tool was given the name `image_upscaler` which is not yet present in the default toolbox and is therefore simply added to the list of tools.
|
||||
You can always have a look at the toolbox that is currently available to the agent via the `agent.toolbox` attribute:
|
||||
|
||||
```py
|
||||
print("\n".join([f"- {a}" for a in agent.toolbox.keys()]))
|
||||
```
|
||||
|
||||
```text
|
||||
- document_qa
|
||||
- image_captioner
|
||||
- image_qa
|
||||
- image_segmenter
|
||||
- transcriber
|
||||
- summarizer
|
||||
- text_classifier
|
||||
- text_qa
|
||||
- text_reader
|
||||
- translator
|
||||
- image_transformer
|
||||
- text_downloader
|
||||
- image_generator
|
||||
- video_generator
|
||||
- image_upscaler
|
||||
```
|
||||
|
||||
Note how `image_upscaler` is now part of the agents' toolbox.
|
||||
|
||||
Let's now try out the new tools! We will re-use the image we generated in [Transformers Agents Quickstart](./transformers_agents#single-execution-run).
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png"
|
||||
)
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>
|
||||
|
||||
Let's transform the image into a beautiful winter landscape:
|
||||
|
||||
```py
|
||||
image = agent.run("Transform the image: 'A frozen lake and snowy forest'", image=image)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_transformer` to transform the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_transformer(image, prompt="A frozen lake and snowy forest")
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter.png" width=200>
|
||||
|
||||
The new image processing tool is based on ControlNet which can make very strong modifications to the image.
|
||||
By default the image processing tool returns an image of size 512x512 pixels. Let's see if we can upscale it.
|
||||
|
||||
```py
|
||||
image = agent.run("Upscale the image", image)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_upscaler` to upscale the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
upscaled_image = image_upscaler(image)
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter_upscale.png" width=400>
|
||||
|
||||
The agent automatically mapped our prompt "Upscale the image" to the just added upscaler tool purely based on the description and name of the upscaler tool
|
||||
and was able to correctly run it.
|
||||
|
||||
Next, let's have a look at how you can create a new custom tool.
|
||||
|
||||
### Adding new tools
|
||||
|
||||
In this section, we show how to create a new tool that can be added to the agent.
|
||||
|
||||
#### Creating a new tool
|
||||
|
||||
We'll first start by creating a tool. We'll add the not-so-useful yet fun task of fetching the model on the Hugging Face
|
||||
Hub with the most downloads for a given task.
|
||||
|
||||
We can do that with the following code:
|
||||
|
||||
```python
|
||||
from huggingface_hub import list_models
|
||||
|
||||
task = "text-classification"
|
||||
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
print(model.id)
|
||||
```
|
||||
|
||||
For the task `text-classification`, this returns `'facebook/bart-large-mnli'`, for `translation` it returns `'google-t5/t5-base`.
|
||||
|
||||
How do we convert this to a tool that the agent can leverage? All tools depend on the superclass `Tool` that holds the
|
||||
main attributes necessary. We'll create a class that inherits from it:
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
pass
|
||||
```
|
||||
|
||||
This class has a few needs:
|
||||
- An attribute `name`, which corresponds to the name of the tool itself. To be in tune with other tools which have a
|
||||
performative name, we'll name it `model_download_counter`.
|
||||
- An attribute `description`, which will be used to populate the prompt of the agent.
|
||||
- `inputs` and `outputs` attributes. Defining this will help the python interpreter make educated choices about types,
|
||||
and will allow for a gradio-demo to be spawned when we push our tool to the Hub. They're both a list of expected
|
||||
values, which can be `text`, `image`, or `audio`.
|
||||
- A `__call__` method which contains the inference code. This is the code we've played with above!
|
||||
|
||||
Here's what our class looks like now:
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
from huggingface_hub import list_models
|
||||
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
description = (
|
||||
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
|
||||
"It takes the name of the category (such as text-classification, depth-estimation, etc), and "
|
||||
"returns the name of the checkpoint."
|
||||
)
|
||||
|
||||
inputs = ["text"]
|
||||
outputs = ["text"]
|
||||
|
||||
def __call__(self, task: str):
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
return model.id
|
||||
```
|
||||
|
||||
We now have our tool handy. Save it in a file and import it from your main script. Let's name this file
|
||||
`model_downloads.py`, so the resulting import code looks like this:
|
||||
|
||||
```python
|
||||
from model_downloads import HFModelDownloadsTool
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
```
|
||||
|
||||
In order to let others benefit from it and for simpler initialization, we recommend pushing it to the Hub under your
|
||||
namespace. To do so, just call `push_to_hub` on the `tool` variable:
|
||||
|
||||
```python
|
||||
tool.push_to_hub("hf-model-downloads")
|
||||
```
|
||||
|
||||
You now have your code on the Hub! Let's take a look at the final step, which is to have the agent use it.
|
||||
|
||||
#### Having the agent use the tool
|
||||
|
||||
We now have our tool that lives on the Hub which can be instantiated as such (change the user name for your tool):
|
||||
|
||||
```python
|
||||
from transformers import load_tool
|
||||
|
||||
tool = load_tool("lysandre/hf-model-downloads")
|
||||
```
|
||||
|
||||
In order to use it in the agent, simply pass it in the `additional_tools` parameter of the agent initialization method:
|
||||
|
||||
```python
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
|
||||
|
||||
agent.run(
|
||||
"Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
|
||||
)
|
||||
```
|
||||
which outputs the following:
|
||||
```text
|
||||
==Code generated by the agent==
|
||||
model = model_download_counter(task="text-to-video")
|
||||
print(f"The model with the most downloads is {model}.")
|
||||
audio_model = text_reader(model)
|
||||
|
||||
|
||||
==Result==
|
||||
The model with the most downloads is damo-vilab/text-to-video-ms-1.7b.
|
||||
```
|
||||
|
||||
and generates the following audio.
|
||||
|
||||
| **Audio** |
|
||||
|------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
Depending on the LLM, some are quite brittle and require very exact prompts in order to work well. Having a well-defined
|
||||
name and description of the tool is paramount to having it be leveraged by the agent.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Replacing existing tools
|
||||
|
||||
Replacing existing tools can be done simply by assigning a new item to the agent's toolbox. Here's how one would do so:
|
||||
|
||||
```python
|
||||
from transformers import HfAgent, load_tool
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
agent.toolbox["image-transformation"] = load_tool("diffusers/controlnet-canny-tool")
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
Beware when replacing tools with others! This will also adjust the agent's prompt. This can be good if you have a better
|
||||
prompt suited for the task, but it can also result in your tool being selected way more than others or for other
|
||||
tools to be selected instead of the one you have defined.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Leveraging gradio-tools
|
||||
|
||||
[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
|
||||
Face Spaces as tools. It supports many existing Spaces as well as custom Spaces to be designed with it.
|
||||
|
||||
We offer support for `gradio_tools` by using the `Tool.from_gradio` method. For example, we want to take
|
||||
advantage of the `StableDiffusionPromptGeneratorTool` tool offered in the `gradio-tools` toolkit so as to
|
||||
improve our prompts and generate better images.
|
||||
|
||||
We first import the tool from `gradio_tools` and instantiate it:
|
||||
|
||||
```python
|
||||
from gradio_tools import StableDiffusionPromptGeneratorTool
|
||||
|
||||
gradio_tool = StableDiffusionPromptGeneratorTool()
|
||||
```
|
||||
|
||||
We pass that instance to the `Tool.from_gradio` method:
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
|
||||
tool = Tool.from_gradio(gradio_tool)
|
||||
```
|
||||
|
||||
Now we can manage it exactly as we would a usual custom tool. We leverage it to improve our prompt
|
||||
` a rabbit wearing a space suit`:
|
||||
|
||||
```python
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
|
||||
|
||||
agent.run("Generate an image of the `prompt` after improving it.", prompt="A rabbit wearing a space suit")
|
||||
```
|
||||
|
||||
The model adequately leverages the tool:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools: `StableDiffusionPromptGenerator` to improve the prompt, then `image_generator` to generate an image according to the improved prompt.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
improved_prompt = StableDiffusionPromptGenerator(prompt)
|
||||
print(f"The improved prompt is {improved_prompt}.")
|
||||
image = image_generator(improved_prompt)
|
||||
```
|
||||
|
||||
Before finally generating the image:
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
gradio-tools requires *textual* inputs and outputs, even when working with different modalities. This implementation
|
||||
works with image and audio objects. The two are currently incompatible, but will rapidly become compatible as we
|
||||
work to improve the support.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Future compatibility with Langchain
|
||||
|
||||
We love Langchain and think it has a very compelling suite of tools. In order to handle these tools,
|
||||
Langchain requires *textual* inputs and outputs, even when working with different modalities.
|
||||
This is often the serialized version (i.e., saved to disk) of the objects.
|
||||
|
||||
This difference means that multi-modality isn't handled between transformers-agents and langchain.
|
||||
We aim for this limitation to be resolved in future versions, and welcome any help from avid langchain
|
||||
users to help us achieve this compatibility.
|
||||
|
||||
We would love to have better support. If you would like to help, please
|
||||
[open an issue](https://github.com/huggingface/transformers/issues/new) and share what you have in mind.
|
@ -28,30 +28,27 @@ contains the API docs for the underlying classes.
|
||||
|
||||
## Agents
|
||||
|
||||
We provide three types of agents: [`HfAgent`] uses inference endpoints for opensource models, [`LocalAgent`] uses a model of your choice locally and [`OpenAiAgent`] uses OpenAI closed models.
|
||||
|
||||
### HfAgent
|
||||
|
||||
[[autodoc]] HfAgent
|
||||
|
||||
### LocalAgent
|
||||
|
||||
[[autodoc]] LocalAgent
|
||||
|
||||
### OpenAiAgent
|
||||
|
||||
[[autodoc]] OpenAiAgent
|
||||
|
||||
### AzureOpenAiAgent
|
||||
|
||||
[[autodoc]] AzureOpenAiAgent
|
||||
We provide two types of agents, based on the main [`Agent`] class:
|
||||
- [`CodeAgent`] acts in one shot, generating code to solve the task, then executes it at once.
|
||||
- [`ReactAgent`] acts step by step, each step consisting of one thought, then one tool call and execution. It has two classes:
|
||||
- [`ReactJsonAgent`] writes its tool calls in JSON.
|
||||
- [`ReactCodeAgent`] writes its tool calls in Python code.
|
||||
|
||||
### Agent
|
||||
|
||||
[[autodoc]] Agent
|
||||
- chat
|
||||
- run
|
||||
- prepare_for_new_chat
|
||||
|
||||
### CodeAgent
|
||||
|
||||
[[autodoc]] CodeAgent
|
||||
|
||||
### React agents
|
||||
|
||||
[[autodoc]] ReactAgent
|
||||
|
||||
[[autodoc]] ReactJsonAgent
|
||||
|
||||
[[autodoc]] ReactCodeAgent
|
||||
|
||||
## Tools
|
||||
|
||||
@ -63,18 +60,50 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens
|
||||
|
||||
[[autodoc]] Tool
|
||||
|
||||
### Toolbox
|
||||
|
||||
[[autodoc]] Toolbox
|
||||
|
||||
### PipelineTool
|
||||
|
||||
[[autodoc]] PipelineTool
|
||||
|
||||
### RemoteTool
|
||||
|
||||
[[autodoc]] RemoteTool
|
||||
|
||||
### launch_gradio_demo
|
||||
|
||||
[[autodoc]] launch_gradio_demo
|
||||
|
||||
### ToolCollection
|
||||
|
||||
[[autodoc]] ToolCollection
|
||||
|
||||
## Engines
|
||||
|
||||
You're free to create and use your own engines to be usable by the Agents framework.
|
||||
These engines have the following specification:
|
||||
1. Follow the [messages format](../chat_templating.md) for its input (`List[Dict[str, str]]`) and return a string.
|
||||
2. Stop generating outputs *before* the sequences passed in the argument `stop_sequences`
|
||||
|
||||
### HfEngine
|
||||
|
||||
For convenience, we have added a `HfEngine` that implements the points above and uses an inference endpoint for the execution of the LLM.
|
||||
|
||||
```python
|
||||
>>> from transformers import HfEngine
|
||||
|
||||
>>> messages = [
|
||||
... {"role": "user", "content": "Hello, how are you?"},
|
||||
... {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
... {"role": "user", "content": "No need to help, take it easy."},
|
||||
... ]
|
||||
|
||||
>>> HfEngine()(messages, stop_sequences=["conversation"])
|
||||
|
||||
"That's very kind of you to say! It's always nice to have a relaxed "
|
||||
```
|
||||
|
||||
[[autodoc]] HfEngine
|
||||
|
||||
|
||||
## Agent Types
|
||||
|
||||
Agents can handle any type of object in-between tools; tools, being completely multimodal, can accept and return
|
||||
@ -94,12 +123,12 @@ These types have three specific purposes:
|
||||
|
||||
### AgentText
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentText
|
||||
[[autodoc]] transformers.agents.agent_types.AgentText
|
||||
|
||||
### AgentImage
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentImage
|
||||
[[autodoc]] transformers.agents.agent_types.AgentImage
|
||||
|
||||
### AgentAudio
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentAudio
|
||||
[[autodoc]] transformers.agents.agent_types.AgentAudio
|
||||
|
@ -1,323 +0,0 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Transformers Agents
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Transformers Agents is an experimental API which is subject to change at any time. Results returned by the agents
|
||||
can vary as the APIs or underlying models are prone to change.
|
||||
|
||||
</Tip>
|
||||
|
||||
Transformers version v4.29.0, building on the concept of *tools* and *agents*. You can play with in
|
||||
[this colab](https://colab.research.google.com/drive/1c7MHD-T1forUPGcC_jlwsIptOzpG3hSj).
|
||||
|
||||
In short, it provides a natural language API on top of transformers: we define a set of curated tools and design an
|
||||
agent to interpret natural language and to use these tools. It is extensible by design; we curated some relevant tools,
|
||||
but we'll show you how the system can be extended easily to use any tool developed by the community.
|
||||
|
||||
Let's start with a few examples of what can be achieved with this new API. It is particularly powerful when it comes
|
||||
to multimodal tasks, so let's take it for a spin to generate images and read text out loud.
|
||||
|
||||
```py
|
||||
agent.run("Caption the following image", image=image)
|
||||
```
|
||||
|
||||
| **Input** | **Output** |
|
||||
|-----------------------------------------------------------------------------------------------------------------------------|-----------------------------------|
|
||||
| <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/beaver.png" width=200> | A beaver is swimming in the water |
|
||||
|
||||
---
|
||||
|
||||
```py
|
||||
agent.run("Read the following text out loud", text=text)
|
||||
```
|
||||
| **Input** | **Output** |
|
||||
|-------------------------------------------------------------------------------------------------------------------------|----------------------------------------------|
|
||||
| A beaver is swimming in the water | <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tts_example.wav" type="audio/wav"> your browser does not support the audio element. </audio>
|
||||
|
||||
---
|
||||
|
||||
```py
|
||||
agent.run(
|
||||
"In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
|
||||
document=document,
|
||||
)
|
||||
```
|
||||
| **Input** | **Output** |
|
||||
|-----------------------------------------------------------------------------------------------------------------------------|----------------|
|
||||
| <img src="https://datasets-server.huggingface.co/assets/hf-internal-testing/example-documents/--/hf-internal-testing--example-documents/test/0/image/image.jpg" width=200> | ballroom foyer |
|
||||
|
||||
## Quickstart
|
||||
|
||||
Before being able to use `agent.run`, you will need to instantiate an agent, which is a large language model (LLM).
|
||||
We provide support for openAI models as well as opensource alternatives from BigCode and OpenAssistant. The openAI
|
||||
models perform better (but require you to have an openAI API key, so cannot be used for free); Hugging Face is
|
||||
providing free access to endpoints for BigCode and OpenAssistant models.
|
||||
|
||||
To start with, please install the `agents` extras in order to install all default dependencies.
|
||||
```bash
|
||||
pip install transformers[agents]
|
||||
```
|
||||
|
||||
To use openAI models, you instantiate an [`OpenAiAgent`] after installing the `openai` dependency:
|
||||
|
||||
```bash
|
||||
pip install openai
|
||||
```
|
||||
|
||||
|
||||
```py
|
||||
from transformers import OpenAiAgent
|
||||
|
||||
agent = OpenAiAgent(model="text-davinci-003", api_key="<your_api_key>")
|
||||
```
|
||||
|
||||
To use BigCode or OpenAssistant, start by logging in to have access to the Inference API:
|
||||
|
||||
```py
|
||||
from huggingface_hub import login
|
||||
|
||||
login("<YOUR_TOKEN>")
|
||||
```
|
||||
|
||||
Then, instantiate the agent
|
||||
|
||||
```py
|
||||
from transformers import HfAgent
|
||||
|
||||
# Starcoder
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
# StarcoderBase
|
||||
# agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoderbase")
|
||||
# OpenAssistant
|
||||
# agent = HfAgent(url_endpoint="https://api-inference.huggingface.co/models/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5")
|
||||
```
|
||||
|
||||
This is using the inference API that Hugging Face provides for free at the moment. If you have your own inference
|
||||
endpoint for this model (or another one) you can replace the URL above with your URL endpoint.
|
||||
|
||||
<Tip>
|
||||
|
||||
StarCoder and OpenAssistant are free to use and perform admirably well on simple tasks. However, the checkpoints
|
||||
don't hold up when handling more complex prompts. If you're facing such an issue, we recommend trying out the OpenAI
|
||||
model which, while sadly not open-source, performs better at this given time.
|
||||
|
||||
</Tip>
|
||||
|
||||
You're now good to go! Let's dive into the two APIs that you now have at your disposal.
|
||||
|
||||
### Single execution (run)
|
||||
|
||||
The single execution method is when using the [`~Agent.run`] method of the agent:
|
||||
|
||||
```py
|
||||
agent.run("Draw me a picture of rivers and lakes.")
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>
|
||||
|
||||
It automatically selects the tool (or tools) appropriate for the task you want to perform and runs them appropriately. It
|
||||
can perform one or several tasks in the same instruction (though the more complex your instruction, the more likely
|
||||
the agent is to fail).
|
||||
|
||||
```py
|
||||
agent.run("Draw me a picture of the sea then transform the picture to add an island")
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sea_and_island.png" width=200>
|
||||
|
||||
<br/>
|
||||
|
||||
|
||||
Every [`~Agent.run`] operation is independent, so you can run it several times in a row with different tasks.
|
||||
|
||||
Note that your `agent` is just a large-language model, so small variations in your prompt might yield completely
|
||||
different results. It's important to explain as clearly as possible the task you want to perform. We go more in-depth
|
||||
on how to write good prompts [here](custom_tools#writing-good-user-inputs).
|
||||
|
||||
If you'd like to keep a state across executions or to pass non-text objects to the agent, you can do so by specifying
|
||||
variables that you would like the agent to use. For example, you could generate the first image of rivers and lakes,
|
||||
and ask the model to update that picture to add an island by doing the following:
|
||||
|
||||
```python
|
||||
picture = agent.run("Generate a picture of rivers and lakes.")
|
||||
updated_picture = agent.run("Transform the image in `picture` to add an island to it.", picture=picture)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
This can be helpful when the model is unable to understand your request and mixes tools. An example would be:
|
||||
|
||||
```py
|
||||
agent.run("Draw me the picture of a capybara swimming in the sea")
|
||||
```
|
||||
|
||||
Here, the model could interpret in two ways:
|
||||
- Have the `text-to-image` generate a capybara swimming in the sea
|
||||
- Or, have the `text-to-image` generate capybara, then use the `image-transformation` tool to have it swim in the sea
|
||||
|
||||
In case you would like to force the first scenario, you could do so by passing it the prompt as an argument:
|
||||
|
||||
```py
|
||||
agent.run("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
### Chat-based execution (chat)
|
||||
|
||||
The agent also has a chat-based approach, using the [`~Agent.chat`] method:
|
||||
|
||||
```py
|
||||
agent.chat("Generate a picture of rivers and lakes")
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>
|
||||
|
||||
```py
|
||||
agent.chat("Transform the picture so that there is a rock in there")
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_and_beaver.png" width=200>
|
||||
|
||||
<br/>
|
||||
|
||||
This is an interesting approach when you want to keep the state across instructions. It's better for experimentation,
|
||||
but will tend to be much better at single instructions rather than complex instructions (which the [`~Agent.run`]
|
||||
method is better at handling).
|
||||
|
||||
This method can also take arguments if you would like to pass non-text types or specific prompts.
|
||||
|
||||
### ⚠️ Remote execution
|
||||
|
||||
For demonstration purposes and so that it could be used with all setups, we had created remote executors for several
|
||||
of the default tools the agent has access for the release. These are created using
|
||||
[inference endpoints](https://huggingface.co/inference-endpoints).
|
||||
|
||||
We have turned these off for now, but in order to see how to set up remote executors tools yourself,
|
||||
we recommend reading the [custom tool guide](./custom_tools).
|
||||
|
||||
### What's happening here? What are tools, and what are agents?
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/diagram.png">
|
||||
|
||||
#### Agents
|
||||
|
||||
The "agent" here is a large language model, and we're prompting it so that it has access to a specific set of tools.
|
||||
|
||||
LLMs are pretty good at generating small samples of code, so this API takes advantage of that by prompting the
|
||||
LLM gives a small sample of code performing a task with a set of tools. This prompt is then completed by the
|
||||
task you give your agent and the description of the tools you give it. This way it gets access to the doc of the
|
||||
tools you are using, especially their expected inputs and outputs, and can generate the relevant code.
|
||||
|
||||
#### Tools
|
||||
|
||||
Tools are very simple: they're a single function, with a name, and a description. We then use these tools' descriptions
|
||||
to prompt the agent. Through the prompt, we show the agent how it would leverage tools to perform what was
|
||||
requested in the query.
|
||||
|
||||
This is using brand-new tools and not pipelines, because the agent writes better code with very atomic tools.
|
||||
Pipelines are more refactored and often combine several tasks in one. Tools are meant to be focused on
|
||||
one very simple task only.
|
||||
|
||||
#### Code-execution?!
|
||||
|
||||
This code is then executed with our small Python interpreter on the set of inputs passed along with your tools.
|
||||
We hear you screaming "Arbitrary code execution!" in the back, but let us explain why that is not the case.
|
||||
|
||||
The only functions that can be called are the tools you provided and the print function, so you're already
|
||||
limited in what can be executed. You should be safe if it's limited to Hugging Face tools.
|
||||
|
||||
Then, we don't allow any attribute lookup or imports (which shouldn't be needed anyway for passing along
|
||||
inputs/outputs to a small set of functions) so all the most obvious attacks (and you'd need to prompt the LLM
|
||||
to output them anyway) shouldn't be an issue. If you want to be on the super safe side, you can execute the
|
||||
run() method with the additional argument return_code=True, in which case the agent will just return the code
|
||||
to execute and you can decide whether to do it or not.
|
||||
|
||||
The execution will stop at any line trying to perform an illegal operation or if there is a regular Python error
|
||||
with the code generated by the agent.
|
||||
|
||||
### A curated set of tools
|
||||
|
||||
We identify a set of tools that can empower such agents. Here is an updated list of the tools we have integrated
|
||||
in `transformers`:
|
||||
|
||||
- **Document question answering**: given a document (such as a PDF) in image format, answer a question on this document ([Donut](./model_doc/donut))
|
||||
- **Text question answering**: given a long text and a question, answer the question in the text ([Flan-T5](./model_doc/flan-t5))
|
||||
- **Unconditional image captioning**: Caption the image! ([BLIP](./model_doc/blip))
|
||||
- **Image question answering**: given an image, answer a question on this image ([VILT](./model_doc/vilt))
|
||||
- **Image segmentation**: given an image and a prompt, output the segmentation mask of that prompt ([CLIPSeg](./model_doc/clipseg))
|
||||
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
|
||||
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
|
||||
- **Zero-shot text classification**: given a text and a list of labels, identify to which label the text corresponds the most ([BART](./model_doc/bart))
|
||||
- **Text summarization**: summarize a long text in one or a few sentences ([BART](./model_doc/bart))
|
||||
- **Translation**: translate the text into a given language ([NLLB](./model_doc/nllb))
|
||||
|
||||
These tools have an integration in transformers, and can be used manually as well, for example:
|
||||
|
||||
```py
|
||||
from transformers import load_tool
|
||||
|
||||
tool = load_tool("text-to-speech")
|
||||
audio = tool("This is a text to speech tool")
|
||||
```
|
||||
|
||||
### Custom tools
|
||||
|
||||
While we identify a curated set of tools, we strongly believe that the main value provided by this implementation is
|
||||
the ability to quickly create and share custom tools.
|
||||
|
||||
By pushing the code of a tool to a Hugging Face Space or a model repository, you're then able to leverage the tool
|
||||
directly with the agent. We've added a few
|
||||
**transformers-agnostic** tools to the [`huggingface-tools` organization](https://huggingface.co/huggingface-tools):
|
||||
|
||||
- **Text downloader**: to download a text from a web URL
|
||||
- **Text to image**: generate an image according to a prompt, leveraging stable diffusion
|
||||
- **Image transformation**: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
|
||||
- **Text to video**: generate a small video according to a prompt, leveraging damo-vilab
|
||||
|
||||
The text-to-image tool we have been using since the beginning is a remote tool that lives in
|
||||
[*huggingface-tools/text-to-image*](https://huggingface.co/spaces/huggingface-tools/text-to-image)! We will
|
||||
continue releasing such tools on this and other organizations, to further supercharge this implementation.
|
||||
|
||||
The agents have by default access to tools that reside on [`huggingface-tools`](https://huggingface.co/huggingface-tools).
|
||||
We explain how to you can write and share your tools as well as leverage any custom tool that resides on the Hub in [following guide](custom_tools).
|
||||
|
||||
### Code generation
|
||||
|
||||
So far we have shown how to use the agents to perform actions for you. However, the agent is only generating code
|
||||
that we then execute using a very restricted Python interpreter. In case you would like to use the code generated in
|
||||
a different setting, the agent can be prompted to return the code, along with tool definition and accurate imports.
|
||||
|
||||
For example, the following instruction
|
||||
```python
|
||||
agent.run("Draw me a picture of rivers and lakes", return_code=True)
|
||||
```
|
||||
|
||||
returns the following code
|
||||
|
||||
```python
|
||||
from transformers import load_tool
|
||||
|
||||
image_generator = load_tool("huggingface-tools/text-to-image")
|
||||
|
||||
image = image_generator(prompt="rivers and lakes")
|
||||
```
|
||||
|
||||
that you can then modify and execute yourself.
|
@ -16,749 +16,11 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Custom Tools and Prompts
|
||||
|
||||
<Tip>
|
||||
|
||||
トランスフォーマーのコンテキストでツールとエージェントが何であるかを知らない場合、
|
||||
まず[Transformers Agents](transformers_agents)ページをお読みいただくことをお勧めします。
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Transformers Agentsは実験的なAPIであり、いつでも変更される可能性があります。
|
||||
エージェントによって返される結果は、APIや基礎となるモデルが変更される可能性があるため、変化することがあります。
|
||||
The Agents framework has significantly changed in version v4.41.0.
|
||||
This document has been removed as it was referencing an older API.
|
||||
|
||||
We eagerly welcome new contributions for the updated API.
|
||||
|
||||
</Tip>
|
||||
|
||||
カスタムツールとプロンプトを作成し、使用することは、エージェントを強化し、新しいタスクを実行させるために非常に重要です。
|
||||
このガイドでは、以下の内容を説明します:
|
||||
|
||||
- プロンプトのカスタマイズ方法
|
||||
- カスタムツールの使用方法
|
||||
- カスタムツールの作成方法
|
||||
|
||||
## Customizing the prompt
|
||||
|
||||
[Transformers Agents](transformers_agents)で説明されているように、エージェントは[`~Agent.run`]および[`~Agent.chat`]モードで実行できます。
|
||||
`run`モードと`chat`モードの両方は同じロジックに基づいています。
|
||||
エージェントを駆動する言語モデルは、長いプロンプトに基づいて条件付けられ、
|
||||
次のトークンを生成して停止トークンに達するまでプロンプトを完了します。
|
||||
両者の唯一の違いは、`chat`モードの間にプロンプトが前のユーザーの入力とモデルの生成と共に拡張されることです。
|
||||
これにより、エージェントは過去の対話にアクセスでき、エージェントにあたかもメモリがあるかのように見えます。
|
||||
|
||||
### Structure of the prompt
|
||||
|
||||
プロンプトがどのように構築され、どのように最適化できるかを理解するために、プロンプトは大まかに4つの部分に分かれています。
|
||||
|
||||
1. イントロダクション:エージェントの振る舞い、ツールの概念の説明。
|
||||
2. すべてのツールの説明。これはユーザーによって定義/選択されたツールでランタイム時に動的に置換される`<<all_tools>>`トークンによって定義されます。
|
||||
3. タスクとその解決策の一連の例。
|
||||
4. 現在の例と解決策の要求。
|
||||
|
||||
各部分をよりよく理解するために、`run`プロンプトがどのように見えるかの簡略版を見てみましょう:
|
||||
|
||||
````text
|
||||
タスクを実行するために、Pythonのシンプルなコマンドのシリーズを考えてくることがあるでしょう。
|
||||
[...]
|
||||
意味がある場合は、中間結果を表示することができます。
|
||||
|
||||
ツール:
|
||||
- document_qa:これはドキュメント(pdf)に関する質問に答えるツールです。情報を含むドキュメントである `document` と、ドキュメントに関する質問である `question` を受け取り、質問に対する回答を含むテキストを返します。
|
||||
- image_captioner:これは画像の説明を生成するツールです。キャプションにする画像である `image` と、説明を含む英語のテキストを返すテキストを受け取ります。
|
||||
[...]
|
||||
|
||||
タスク: "変数 `question` に関する質問に答えるための画像について回答してください。質問はフランス語です。"
|
||||
|
||||
次のツールを使用します:質問を英語に翻訳するための `translator`、そして入力画像に関する質問に答えるための `image_qa`。
|
||||
|
||||
回答:
|
||||
```py
|
||||
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
||||
print(f"The translated question is {translated_question}.")
|
||||
answer = image_qa(image=image, question=translated_question)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
タスク:「`document`内で最年長の人物を特定し、その結果をバナーとして表示する。」
|
||||
|
||||
以下のツールを使用します:`document_qa`を使用してドキュメント内で最年長の人物を見つけ、その回答に従って`image_generator`を使用して画像を生成します。
|
||||
|
||||
回答:
|
||||
```py
|
||||
answer = document_qa(document, question="What is the oldest person?")
|
||||
print(f"The answer is {answer}.")
|
||||
image = image_generator("A banner showing " + answer)
|
||||
```
|
||||
|
||||
[...]
|
||||
タスク: "川と湖の絵を描いてください"
|
||||
|
||||
以下のものを使用します
|
||||
````
|
||||
|
||||
導入部分("Tools:"の前のテキスト)は、モデルの振る舞いと実行すべきタスクを正確に説明しています。
|
||||
この部分はおそらくエージェントが常に同じ方法で振る舞う必要があるため、カスタマイズする必要はありません。
|
||||
|
||||
2番目の部分("Tools"の下の箇条書き)は、`run`または`chat`を呼び出すたびに動的に追加されます。
|
||||
`agent.toolbox`内のツールの数と同じ数の箇条書きがあり、それぞれの箇条書きにはツールの名前と説明が含まれています。
|
||||
|
||||
```text
|
||||
- <tool.name>: <tool.description>
|
||||
```
|
||||
|
||||
もうすぐ確認しましょう。 `document_qa` ツールを読み込んで名前と説明を出力します。
|
||||
|
||||
```py
|
||||
from transformers import load_tool
|
||||
|
||||
document_qa = load_tool("document-question-answering")
|
||||
print(f"- {document_qa.name}: {document_qa.description}")
|
||||
```
|
||||
|
||||
which gives:
|
||||
```text
|
||||
- document_qa: This is a tool that answers a question about a document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
|
||||
```
|
||||
|
||||
ツール説明:
|
||||
このツールは、2つのパートから成り立っています。最初のパートでは、ツールが何を行うかを説明し、2番目のパートでは入力引数と戻り値がどのように期待されるかを述べています。
|
||||
|
||||
良いツール名とツールの説明は、エージェントが正しく使用するために非常に重要です。エージェントがツールについて持っている唯一の情報は、その名前と説明です。したがって、ツール名と説明の両方が正確に記述され、ツールボックス内の既存のツールのスタイルに合致することを確認する必要があります。特に、説明にはコードスタイルで名前で期待されるすべての引数が言及され、期待される型とそれらが何であるかの説明も含めるべきです。
|
||||
|
||||
<Tip>
|
||||
|
||||
キュレートされたTransformersツールの命名と説明を確認して、ツールがどのような名前と説明を持つべきかを理解するのに役立ちます。
|
||||
すべてのツールは[`Agent.toolbox`]プロパティで確認できます。
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
カスタマイズされた例:
|
||||
ツールの使い方をエージェントに正確に示す一連の例が含まれています。これらの例は、エージェントが実際に正確で実行可能なコードを生成する可能性を最大化するように書かれているため、非常に重要です。大規模な言語モデルは、プロンプト内のパターンを認識し、新しいデータを使用してそのパターンを繰り返すことに非常に優れています。したがって、実践で正しい実行可能なコードを生成するエージェントの可能性を最大化するように、これらの例は書かれている必要があります。
|
||||
|
||||
以下は、一つの例です:
|
||||
|
||||
````text
|
||||
Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."
|
||||
|
||||
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||
|
||||
Answer:
|
||||
```py
|
||||
answer = document_qa(document, question="What is the oldest person?")
|
||||
print(f"The answer is {answer}.")
|
||||
image = image_generator("A banner showing " + answer)
|
||||
```
|
||||
|
||||
````
|
||||
|
||||
パターン:モデルが繰り返しを行うように指示されるパターンには、3つの部分があります。
|
||||
タスクの声明、エージェントの意図した動作の説明、そして最後に生成されるコードです。
|
||||
プロンプトの一部であるすべての例には、この正確なパターンがあり、エージェントが新しいトークンを生成する際にも
|
||||
同じパターンを再現することを確認しています。
|
||||
|
||||
プロンプトの例はTransformersチームによって厳選され、一連の問題ステートメントで厳密に評価されます。
|
||||
これにより、エージェントのプロンプトがエージェントの実際の使用ケースを解決するためにできるだけ優れたものになります。
|
||||
|
||||
プロンプトの最後の部分に対応しています:
|
||||
|
||||
[こちら](https://github.com/huggingface/transformers/blob/main/src/transformers/tools/evaluate_agent.py)の問題ステートメントで厳密に評価される、エージェントのプロンプトができるだけ優れたものになるように
|
||||
慎重に選定されたプロンプト例を提供しています。
|
||||
|
||||
```text
|
||||
Task: "Draw me a picture of rivers and lakes"
|
||||
|
||||
I will use the following
|
||||
```
|
||||
|
||||
|
||||
これがエージェントに完成させるための最終的で未完成の例です。未完成の例は、実際のユーザー入力に基づいて動的に作成されます。上記の例では、ユーザーが次のように実行しました:
|
||||
|
||||
```py
|
||||
agent.run("Draw me a picture of rivers and lakes")
|
||||
```
|
||||
|
||||
ユーザーの入力 - つまり、タスク:"川と湖の絵を描いてください"は、以下のようなプロンプトテンプレートに変換されます:"タスク:<task> \n\n 次に私は以下を使用します"。
|
||||
この文は、エージェントが条件付けられたプロンプトの最終行を構成し、したがってエージェントに対して前の例とまったく同じ方法で例を終了するよう強く影響します。
|
||||
|
||||
詳細には立ち入りませんが、チャットテンプレートは同じプロンプト構造を持ち、例はわずかに異なるスタイルを持っています。例:
|
||||
|
||||
````text
|
||||
[...]
|
||||
|
||||
=====
|
||||
|
||||
Human: Answer the question in the variable `question` about the image stored in the variable `image`.
|
||||
|
||||
Assistant: I will use the tool `image_qa` to answer the question on the input image.
|
||||
|
||||
```py
|
||||
answer = image_qa(text=question, image=image)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
Human: I tried this code, it worked but didn't give me a good result. The question is in French
|
||||
|
||||
Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.
|
||||
|
||||
```py
|
||||
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
||||
print(f"The translated question is {translated_question}.")
|
||||
answer = image_qa(text=translated_question, image=image)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
=====
|
||||
|
||||
[...]
|
||||
````
|
||||
|
||||
*Human:* `run`プロンプトの例とは対照的に、各`chat`プロンプトの例には*Human*と*Assistant*の間で1つ以上のやりとりがあります。各やりとりは、`run`プロンプトの例と同様の構造になっています。ユーザーの入力は*Human:*の後ろに追加され、エージェントにはコードを生成する前に何を行う必要があるかを最初に生成するように指示されます。やりとりは以前のやりとりに基づいて行われることがあり、ユーザーが「I tried **this** code」と入力したように、以前に生成されたエージェントのコードを参照できます。
|
||||
|
||||
*Assistant:* `.chat`を実行すると、ユーザーの入力または*タスク*が未完了の形式に変換されます:
|
||||
|
||||
```text
|
||||
Human: <user-input>\n\nAssistant:
|
||||
```
|
||||
|
||||
以下のエージェントが完了するコマンドについて説明します。 `run` コマンドとは対照的に、`chat` コマンドは完了した例をプロンプトに追加します。そのため、次の `chat` ターンのためにエージェントにより多くの文脈を提供します。
|
||||
|
||||
さて、プロンプトの構造がわかったところで、どのようにカスタマイズできるかを見てみましょう!
|
||||
|
||||
### Writing good user inputs
|
||||
|
||||
大規模な言語モデルはユーザーの意図を理解する能力がますます向上していますが、エージェントが正しいタスクを選択するのを助けるために、できるだけ正確に記述することが非常に役立ちます。できるだけ正確であるとは何を意味するのでしょうか?
|
||||
|
||||
エージェントは、プロンプトでツール名とその説明のリストを見ています。ツールが追加されるほど、エージェントが正しいツールを選択するのが難しくなり、正しいツールの連続を選択するのはさらに難しくなります。共通の失敗例を見てみましょう。ここではコードのみを返すことにします。
|
||||
|
||||
|
||||
```py
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
|
||||
agent.run("Show me a tree", return_code=True)
|
||||
```
|
||||
|
||||
gives:
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_segmenter` to create a segmentation mask for the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
mask = image_segmenter(image, prompt="tree")
|
||||
```
|
||||
|
||||
これはおそらく私たちが望んでいたものではないでしょう。代わりに、木の画像が生成されることがより可能性が高いです。
|
||||
特定のツールを使用するようエージェントを誘導するために、ツールの名前や説明に含まれている重要なキーワードを使用することは非常に役立ちます。さて、詳しく見てみましょう。
|
||||
|
||||
```py
|
||||
agent.toolbox["image_generator"].description
|
||||
```
|
||||
|
||||
```text
|
||||
'This is a tool that creates an image according to a prompt, which is a text description. It takes an input named `prompt` which contains the image description and outputs an image.
|
||||
```
|
||||
|
||||
名前と説明文には、キーワード「画像」、「プロンプト」、「作成」、および「生成」が使用されています。これらの言葉を使用することで、ここでの動作がより効果的になる可能性が高いです。プロンプトを少し詳細に調整しましょう。
|
||||
|
||||
```py
|
||||
agent.run("Create an image of a tree", return_code=True)
|
||||
```
|
||||
|
||||
gives:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool `image_generator` to generate an image of a tree.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_generator(prompt="tree")
|
||||
```
|
||||
|
||||
簡単に言うと、エージェントがタスクを正確に適切なツールにマッピングできない場合は、ツールの名前や説明の最も関連性のあるキーワードを調べて、タスクリクエストをそれに合わせて洗練させてみてください。
|
||||
|
||||
|
||||
### Customizing the tool descriptions
|
||||
|
||||
以前にも見たように、エージェントは各ツールの名前と説明にアクセスできます。ベースのツールは非常に正確な名前と説明を持っているはずですが、特定のユースケースに合わせてツールの説明や名前を変更することが役立つかもしれません。これは、非常に類似した複数のツールを追加した場合や、特定のドメイン(たとえば、画像生成や変換など)でエージェントを使用する場合に特に重要になるかもしれません。
|
||||
|
||||
よくある問題は、エージェントが画像生成タスクに頻繁に使用される場合、画像生成と画像変換/修正を混同することです。
|
||||
|
||||
例:
|
||||
|
||||
```py
|
||||
agent.run("Make an image of a house and a car", return_code=True)
|
||||
```
|
||||
|
||||
returns
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools `image_generator` to generate an image of a house and `image_transformer` to transform the image of a car into the image of a house.
|
||||
|
||||
==Code generated by the agent==
|
||||
house_image = image_generator(prompt="A house")
|
||||
car_image = image_generator(prompt="A car")
|
||||
house_car_image = image_transformer(image=car_image, prompt="A house")
|
||||
```
|
||||
|
||||
これはおそらく私たちがここで望んでいる正確なものではないようです。エージェントは「image_generator」と「image_transformer」の違いを理解するのが難しいようで、しばしば両方を一緒に使用します。
|
||||
|
||||
ここでエージェントをサポートするために、"image_transformer"のツール名と説明を変更して、少し"image"や"prompt"から切り離してみましょう。代わりにそれを「modifier」と呼びましょう:
|
||||
|
||||
```py
|
||||
agent.toolbox["modifier"] = agent.toolbox.pop("image_transformer")
|
||||
agent.toolbox["modifier"].description = agent.toolbox["modifier"].description.replace(
|
||||
"transforms an image according to a prompt", "modifies an image"
|
||||
)
|
||||
```
|
||||
|
||||
「変更」は、上記のプロンプトに新しい画像プロセッサを使用する強力な手がかりです。それでは、もう一度実行してみましょう。
|
||||
|
||||
|
||||
```py
|
||||
agent.run("Make an image of a house and a car", return_code=True)
|
||||
```
|
||||
|
||||
Now we're getting:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools: `image_generator` to generate an image of a house, then `image_generator` to generate an image of a car.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
house_image = image_generator(prompt="A house")
|
||||
car_image = image_generator(prompt="A car")
|
||||
```
|
||||
|
||||
これは、私たちが考えていたものに確実に近づいています!ただし、家と車を同じ画像に含めたいと考えています。タスクを単一の画像生成に向けることで、より適切な方向に進めるはずです:
|
||||
|
||||
```py
|
||||
agent.run("Create image: 'A house and car'", return_code=True)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_generator` to generate an image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_generator(prompt="A house and car")
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
エージェントは、特に複数のオブジェクトの画像を生成するなど、やや複雑なユースケースに関しては、まだ多くのユースケースに対して脆弱です。
|
||||
エージェント自体とその基礎となるプロンプトは、今後数ヶ月でさらに改善され、さまざまなユーザーの入力に対してエージェントがより頑健になるようになります。
|
||||
|
||||
</Tip>
|
||||
|
||||
### Customizing the whole project
|
||||
|
||||
ユーザーに最大限の柔軟性を提供するために、[上記](#structure-of-the-prompt)で説明されたプロンプトテンプレート全体をユーザーが上書きできます。この場合、カスタムプロンプトには導入セクション、ツールセクション、例セクション、未完了の例セクションが含まれていることを確認してください。`run` プロンプトテンプレートを上書きしたい場合、以下のように行うことができます:
|
||||
|
||||
|
||||
```py
|
||||
template = """ [...] """
|
||||
|
||||
agent = HfAgent(your_endpoint, run_prompt_template=template)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`<<all_tools>>` 文字列と `<<prompt>>` は、エージェントが使用できるツールを認識し、ユーザーのプロンプトを正しく挿入できるように、`template` のどこかに定義されていることを確認してください。
|
||||
|
||||
</Tip>
|
||||
|
||||
同様に、`chat` プロンプトテンプレートを上書きすることもできます。なお、`chat` モードでは常に以下の形式で交換が行われます:
|
||||
|
||||
上記のテキストの上に日本語の翻訳を提供してください。Markdownコードとして書いてください。
|
||||
|
||||
|
||||
```text
|
||||
Human: <<task>>
|
||||
|
||||
Assistant:
|
||||
```
|
||||
|
||||
したがって、カスタム`chat`プロンプトテンプレートの例もこのフォーマットを使用することが重要です。以下のように、インスタンス化時に`chat`テンプレートを上書きできます。
|
||||
|
||||
```python
|
||||
template = """ [...] """
|
||||
|
||||
agent = HfAgent(url_endpoint=your_endpoint, chat_prompt_template=template)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`<<all_tools>>` という文字列が `template` 内で定義されていることを確認してください。これにより、エージェントは使用可能なツールを把握できます。
|
||||
|
||||
</Tip>
|
||||
|
||||
両方の場合、プロンプトテンプレートの代わりに、コミュニティの誰かがホストしたテンプレートを使用したい場合は、リポジトリIDを渡すことができます。デフォルトのプロンプトは、[このリポジトリ](https://huggingface.co/datasets/huggingface-tools/default-prompts) にありますので、参考になります。
|
||||
|
||||
カスタムプロンプトをHubのリポジトリにアップロードしてコミュニティと共有する場合は、次のことを確認してください:
|
||||
- データセットリポジトリを使用すること
|
||||
- `run` コマンド用のプロンプトテンプレートを `run_prompt_template.txt` という名前のファイルに配置すること
|
||||
- `chat` コマンド用のプロンプトテンプレートを `chat_prompt_template.txt` という名前のファイルに配置すること
|
||||
|
||||
## Using custom tools
|
||||
|
||||
このセクションでは、画像生成に特化した2つの既存のカスタムツールを利用します:
|
||||
|
||||
- [huggingface-tools/image-transformation](https://huggingface.co/spaces/huggingface-tools/image-transformation) をより多くの画像変更を可能にするために [diffusers/controlnet-canny-tool](https://huggingface.co/spaces/diffusers/controlnet-canny-tool) に置き換えます。
|
||||
- 画像のアップスケーリング用の新しいツールをデフォルトのツールボックスに追加します:[diffusers/latent-upscaler-tool](https://huggingface.co/spaces/diffusers/latent-upscaler-tool) は既存の画像変換ツールを置き換えます。
|
||||
|
||||
便利な [`load_tool`] 関数を使用してカスタムツールをロードします:
|
||||
|
||||
```py
|
||||
from transformers import load_tool
|
||||
|
||||
controlnet_transformer = load_tool("diffusers/controlnet-canny-tool")
|
||||
upscaler = load_tool("diffusers/latent-upscaler-tool")
|
||||
```
|
||||
|
||||
エージェントにカスタムツールを追加すると、ツールの説明と名前がエージェントのプロンプトに自動的に含まれます。したがって、エージェントがカスタムツールの使用方法を理解できるように、カスタムツールには適切に記述された説明と名前が必要です。
|
||||
|
||||
`controlnet_transformer`の説明と名前を見てみましょう。
|
||||
|
||||
最初に、便利な[`load_tool`]関数を使用してカスタムツールをロードします。
|
||||
|
||||
```py
|
||||
print(f"Description: '{controlnet_transformer.description}'")
|
||||
print(f"Name: '{controlnet_transformer.name}'")
|
||||
```
|
||||
|
||||
gives
|
||||
```text
|
||||
Description: 'This is a tool that transforms an image with ControlNet according to a prompt.
|
||||
It takes two inputs: `image`, which should be the image to transform, and `prompt`, which should be the prompt to use to change it. It returns the modified image.'
|
||||
Name: 'image_transformer'
|
||||
```
|
||||
|
||||
名前と説明は正確であり、[厳選されたツール](./transformers_agents#a-curated-set-of-tools)のスタイルに合っています。
|
||||
|
||||
次に、`controlnet_transformer`と`upscaler`を使ってエージェントをインスタンス化します。
|
||||
|
||||
```py
|
||||
tools = [controlnet_transformer, upscaler]
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=tools)
|
||||
|
||||
```
|
||||
|
||||
以下のコマンドは、以下の情報を提供します:
|
||||
|
||||
|
||||
```text
|
||||
image_transformer has been replaced by <transformers_modules.diffusers.controlnet-canny-tool.bd76182c7777eba9612fc03c0
|
||||
8718a60c0aa6312.image_transformation.ControlNetTransformationTool object at 0x7f1d3bfa3a00> as provided in `additional_tools`
|
||||
```
|
||||
|
||||
一連の厳選されたツールにはすでに `image_transformer` ツールがあり、これをカスタムツールで置き換えます。
|
||||
|
||||
<Tip>
|
||||
|
||||
既存のツールを上書きすることは、特定のタスクに既存のツールをまったく同じ目的で使用したい場合に有益であることがあります。
|
||||
なぜなら、エージェントはその特定のタスクの使用方法に精通しているからです。この場合、カスタムツールは既存のツールとまったく同じAPIに従うか、そのツールを使用するすべての例が更新されるようにプロンプトテンプレートを適応させる必要があります。
|
||||
|
||||
</Tip>
|
||||
|
||||
アップスケーラーツールには `image_upscaler` という名前が付けられ、これはデフォルトのツールボックスにはまだ存在しないため、単にツールのリストに追加されます。
|
||||
エージェントが現在使用可能なツールボックスを確認するには、`agent.toolbox` 属性を使用できます。
|
||||
|
||||
```py
|
||||
print("\n".join([f"- {a}" for a in agent.toolbox.keys()]))
|
||||
```
|
||||
|
||||
```text
|
||||
- document_qa
|
||||
- image_captioner
|
||||
- image_qa
|
||||
- image_segmenter
|
||||
- transcriber
|
||||
- summarizer
|
||||
- text_classifier
|
||||
- text_qa
|
||||
- text_reader
|
||||
- translator
|
||||
- image_transformer
|
||||
- text_downloader
|
||||
- image_generator
|
||||
- video_generator
|
||||
- image_upscaler
|
||||
```
|
||||
|
||||
注意: `image_upscaler` がエージェントのツールボックスの一部となったことに注目してください。
|
||||
|
||||
それでは、新しいツールを試してみましょう で生成した画像を再利用します。
|
||||
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png"
|
||||
)
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>
|
||||
|
||||
美しい冬の風景にこの画像を変身させましょう:
|
||||
|
||||
```py
|
||||
image = agent.run("Transform the image: 'A frozen lake and snowy forest'", image=image)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_transformer` to transform the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_transformer(image, prompt="A frozen lake and snowy forest")
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter.png" width=200>
|
||||
|
||||
新しい画像処理ツールは、非常に強力な画像の変更を行うことができるControlNetに基づいています。
|
||||
デフォルトでは、画像処理ツールはサイズが512x512ピクセルの画像を返します。それを拡大できるか見てみましょう。
|
||||
|
||||
|
||||
```py
|
||||
image = agent.run("Upscale the image", image)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_upscaler` to upscale the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
upscaled_image = image_upscaler(image)
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter_upscale.png" width=400>
|
||||
|
||||
|
||||
エージェントは、プロンプト「画像の拡大」を、その説明とツールの名前だけを基に、新たに追加されたアップスケーリングツールに自動的にマッピングし、正しく実行できました。
|
||||
|
||||
次に、新しいカスタムツールを作成する方法を見てみましょう。
|
||||
|
||||
### Adding new tools
|
||||
|
||||
このセクションでは、エージェントに追加できる新しいツールの作成方法を示します。
|
||||
|
||||
#### Creating a new tool
|
||||
|
||||
まず、ツールの作成から始めましょう。次のコードで、特定のタスクに関してHugging Face Hubで最もダウンロードされたモデルを取得する、あまり役立たないけれども楽しいタスクを追加します。
|
||||
|
||||
以下のコードでそれを行うことができます:
|
||||
|
||||
|
||||
```python
|
||||
from huggingface_hub import list_models
|
||||
|
||||
task = "text-classification"
|
||||
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
print(model.id)
|
||||
```
|
||||
|
||||
タスク `text-classification` の場合、これは `'facebook/bart-large-mnli'` を返します。`translation` の場合、`'google-t5/t5-base'` を返します。
|
||||
|
||||
これをエージェントが利用できるツールに変換する方法は何でしょうか?すべてのツールは、主要な属性を保持するスーパークラス `Tool` に依存しています。私たちは、それを継承したクラスを作成します:
|
||||
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
pass
|
||||
```
|
||||
|
||||
このクラスにはいくつかの必要な要素があります:
|
||||
- `name` 属性:これはツール自体の名前に対応し、他のツールと調和するために `model_download_counter` と名付けます。
|
||||
- `description` 属性:これはエージェントのプロンプトを埋めるために使用されます。
|
||||
- `inputs` と `outputs` 属性:これらを定義することで、Python インタープリターが型に関する賢明な選択を行うのに役立ち、ツールをHubにプッシュする際にgradio-demoを生成できるようになります。これらは、予想される値のリストであり、`text`、`image`、または`audio`になることがあります。
|
||||
- `__call__` メソッド:これには推論コードが含まれています。これは上記で試したコードです!
|
||||
|
||||
こちらが現在のクラスの外観です:
|
||||
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
from huggingface_hub import list_models
|
||||
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
description = (
|
||||
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
|
||||
"It takes the name of the category (such as text-classification, depth-estimation, etc), and "
|
||||
"returns the name of the checkpoint."
|
||||
)
|
||||
|
||||
inputs = ["text"]
|
||||
outputs = ["text"]
|
||||
|
||||
def __call__(self, task: str):
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
return model.id
|
||||
```
|
||||
|
||||
さて、今度はツールが使えるようになりました。このツールをファイルに保存し、メインスクリプトからインポートしましょう。このファイルを `model_downloads.py` という名前にし、結果のインポートコードは次のようになります:
|
||||
|
||||
以下は、現在のクラスの外観です:
|
||||
|
||||
|
||||
```python
|
||||
from model_downloads import HFModelDownloadsTool
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
```
|
||||
|
||||
他の人々に利益をもたらし、より簡単な初期化のために、それをHubにあなたの名前空間でプッシュすることをお勧めします。これを行うには、`tool` 変数で `push_to_hub` を呼び出すだけです:
|
||||
|
||||
```python
|
||||
tool.push_to_hub("hf-model-downloads")
|
||||
```
|
||||
|
||||
エージェントがツールを使用する方法について、最終ステップを見てみましょう。
|
||||
|
||||
#### Having the agent use the tool
|
||||
|
||||
Hubにあるツールがあります。これは次のようにインスタンス化できます(ユーザー名をツールに合わせて変更してください):
|
||||
|
||||
```python
|
||||
from transformers import load_tool
|
||||
|
||||
tool = load_tool("lysandre/hf-model-downloads")
|
||||
```
|
||||
|
||||
エージェントで使用するためには、エージェントの初期化メソッドの `additional_tools` パラメータにそれを渡すだけです:
|
||||
|
||||
|
||||
```python
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
|
||||
|
||||
agent.run(
|
||||
"Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
|
||||
)
|
||||
```
|
||||
which outputs the following:
|
||||
```text
|
||||
==Code generated by the agent==
|
||||
model = model_download_counter(task="text-to-video")
|
||||
print(f"The model with the most downloads is {model}.")
|
||||
audio_model = text_reader(model)
|
||||
|
||||
|
||||
==Result==
|
||||
The model with the most downloads is damo-vilab/text-to-video-ms-1.7b.
|
||||
```
|
||||
|
||||
以下のテキストは、次のオーディオを生成します。
|
||||
|
||||
|
||||
|
||||
**Audio** |
|
||||
|------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
特定のLLMに依存することがあり、うまく機能させるためには非常に正確なプロンプトが必要なものもあります。ツールの名前と説明を明確に定義することは、エージェントによって活用されるために非常に重要です。
|
||||
|
||||
</Tip>
|
||||
|
||||
### Replacing existing tools
|
||||
|
||||
既存のツールを置き換えるには、新しいアイテムをエージェントのツールボックスに割り当てるだけで行うことができます。以下はその方法です:
|
||||
|
||||
```python
|
||||
from transformers import HfAgent, load_tool
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
agent.toolbox["image-transformation"] = load_tool("diffusers/controlnet-canny-tool")
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
他のツールでツールを置き換える際には注意が必要です!これにより、エージェントのプロンプトも調整されます。これは、タスクに適したより良いプロンプトを持っている場合には良いことですが、他のツールが選択される確率が高くなり、定義したツールの代わりに他のツールが選択されることもあるかもしれません。
|
||||
|
||||
</Tip>
|
||||
|
||||
## Leveraging gradio-tools
|
||||
|
||||
[gradio-tools](https://github.com/freddyaboulton/gradio-tools)は、Hugging Face Spacesをツールとして使用することを可能にする強力なライブラリです。既存の多くのSpacesおよびカスタムSpacesを設計することもサポートしています。
|
||||
|
||||
我々は、`gradio_tools`を使用して`StableDiffusionPromptGeneratorTool`ツールを活用したいと考えています。このツールは`gradio-tools`ツールキットで提供されており、プロンプトを改善し、より良い画像を生成するために使用します。
|
||||
|
||||
まず、`gradio_tools`からツールをインポートし、それをインスタンス化します:
|
||||
|
||||
```python
|
||||
from gradio_tools import StableDiffusionPromptGeneratorTool
|
||||
|
||||
gradio_tool = StableDiffusionPromptGeneratorTool()
|
||||
```
|
||||
|
||||
そのインスタンスを `Tool.from_gradio` メソッドに渡します:
|
||||
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
|
||||
tool = Tool.from_gradio(gradio_tool)
|
||||
```
|
||||
|
||||
これからは、通常のカスタムツールと同じようにそれを管理できます。私たちはプロンプトを改善するためにそれを活用します。
|
||||
` a rabbit wearing a space suit`:
|
||||
|
||||
```python
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
|
||||
|
||||
agent.run("Generate an image of the `prompt` after improving it.", prompt="A rabbit wearing a space suit")
|
||||
```
|
||||
|
||||
The model adequately leverages the tool:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools: `StableDiffusionPromptGenerator` to improve the prompt, then `image_generator` to generate an image according to the improved prompt.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
improved_prompt = StableDiffusionPromptGenerator(prompt)
|
||||
print(f"The improved prompt is {improved_prompt}.")
|
||||
image = image_generator(improved_prompt)
|
||||
```
|
||||
|
||||
最終的に画像を生成する前に:
|
||||
|
||||

|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
gradio-toolsは、さまざまなモダリティを使用する場合でも、*テキスト*の入力と出力が必要です。この実装は画像と音声オブジェクトと連携します。現時点では、これら2つは互換性がありませんが、サポートを向上させるために取り組んでおり、迅速に互換性が向上するでしょう。
|
||||
|
||||
</Tip>
|
||||
|
||||
## Future compatibility with Langchain
|
||||
|
||||
私たちはLangchainを愛しており、非常に魅力的なツールのスイートを持っていると考えています。これらのツールを扱うために、Langchainはさまざまなモダリティで作業する場合でも、*テキスト*の入出力が必要です。これは、オブジェクトのシリアル化バージョン(つまり、ディスクに保存されたバージョン)であることが多いです。
|
||||
|
||||
この違いにより、transformers-agentsとlangchain間ではマルチモダリティが処理されていません。
|
||||
この制限は将来のバージョンで解決されることを目指しており、熱心なlangchainユーザーからの任意の支援を歓迎します。
|
||||
|
||||
私たちはより良いサポートを提供したいと考えています。お手伝いいただける場合は、ぜひ[問題を開いて](https://github.com/huggingface/transformers/issues/new)、お考えのことを共有してください。
|
||||
|
||||
|
||||
|
@ -18,88 +18,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Transformers Agents は実験的な API であり、いつでも変更される可能性があります。エージェントから返される結果
|
||||
API または基礎となるモデルは変更される傾向があるため、変更される可能性があります。
|
||||
The Agents framework has significantly changed in version v4.41.0.
|
||||
This document has been removed as it was referencing an older API.
|
||||
|
||||
</Tip>
|
||||
We eagerly welcome new contributions for the updated API.
|
||||
|
||||
エージェントとツールの詳細については、[入門ガイド](../transformers_agents) を必ずお読みください。このページ
|
||||
基礎となるクラスの API ドキュメントが含まれています。
|
||||
|
||||
## エージェント
|
||||
|
||||
私たちは 3 種類のエージェントを提供します。[`HfAgent`] はオープンソース モデルの推論エンドポイントを使用し、[`LocalAgent`] は選択したモデルをローカルで使用し、[`OpenAiAgent`] は OpenAI クローズド モデルを使用します。
|
||||
|
||||
### HfAgent
|
||||
|
||||
[[autodoc]] HfAgent
|
||||
|
||||
### LocalAgent
|
||||
|
||||
[[autodoc]] LocalAgent
|
||||
|
||||
### OpenAiAgent
|
||||
|
||||
[[autodoc]] OpenAiAgent
|
||||
|
||||
### AzureOpenAiAgent
|
||||
|
||||
[[autodoc]] AzureOpenAiAgent
|
||||
|
||||
### Agent
|
||||
|
||||
[[autodoc]] Agent
|
||||
- chat
|
||||
- run
|
||||
- prepare_for_new_chat
|
||||
|
||||
## Tools
|
||||
|
||||
### load_tool
|
||||
|
||||
[[autodoc]] load_tool
|
||||
|
||||
### Tool
|
||||
|
||||
[[autodoc]] Tool
|
||||
|
||||
### PipelineTool
|
||||
|
||||
[[autodoc]] PipelineTool
|
||||
|
||||
### RemoteTool
|
||||
|
||||
[[autodoc]] RemoteTool
|
||||
|
||||
### launch_gradio_demo
|
||||
|
||||
[[autodoc]] launch_gradio_demo
|
||||
|
||||
## エージェントの種類
|
||||
|
||||
エージェントはツール間であらゆる種類のオブジェクトを処理できます。ツールは完全にマルチモーダルであるため、受け取りと返品が可能です
|
||||
テキスト、画像、オーディオ、ビデオなどのタイプ。ツール間の互換性を高めるためだけでなく、
|
||||
これらの戻り値を ipython (jupyter、colab、ipython ノートブックなど) で正しくレンダリングするには、ラッパー クラスを実装します。
|
||||
このタイプの周り。
|
||||
|
||||
ラップされたオブジェクトは最初と同じように動作し続けるはずです。テキストオブジェクトは依然として文字列または画像として動作する必要があります
|
||||
オブジェクトは依然として `PIL.Image` として動作するはずです。
|
||||
|
||||
これらのタイプには、次の 3 つの特定の目的があります。
|
||||
|
||||
- 型に対して `to_raw` を呼び出すと、基になるオブジェクトが返されるはずです
|
||||
- 型に対して `to_string` を呼び出すと、オブジェクトを文字列として返す必要があります。`AgentText` の場合は文字列になる可能性があります。
|
||||
ただし、他のインスタンスのオブジェクトのシリアル化されたバージョンのパスになります。
|
||||
- ipython カーネルで表示すると、オブジェクトが正しく表示されるはずです
|
||||
|
||||
### AgentText
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentText
|
||||
|
||||
### AgentImage
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentImage
|
||||
|
||||
### AgentAudio
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentAudio
|
||||
</Tip>
|
@ -12,737 +12,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# 사용자 정의 도구와 프롬프트[[custom-tools-and-prompts]]
|
||||
|
||||
<Tip>
|
||||
|
||||
Transformers와 관련하여 어떤 도구와 에이전트가 있는지 잘 모르신다면 [Transformers Agents](transformers_agents) 페이지를 먼저 읽어보시기 바랍니다.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Transformers Agents는 실험 중인 API로 언제든지 변경될 수 있습니다.
|
||||
API 또는 기반 모델이 변경되기 쉽기 때문에 에이전트가 반환하는 결과도 달라질 수 있습니다.
|
||||
The Agents framework has significantly changed in version v4.41.0.
|
||||
This document has been removed as it was referencing an older API.
|
||||
|
||||
We eagerly welcome new contributions for the updated API.
|
||||
|
||||
</Tip>
|
||||
|
||||
에이전트에게 권한을 부여하고 새로운 작업을 수행하게 하려면 사용자 정의 도구와 프롬프트를 만들고 사용하는 것이 무엇보다 중요합니다.
|
||||
이 가이드에서는 다음과 같은 내용을 살펴보겠습니다:
|
||||
|
||||
- 프롬프트를 사용자 정의하는 방법
|
||||
- 사용자 정의 도구를 사용하는 방법
|
||||
- 사용자 정의 도구를 만드는 방법
|
||||
|
||||
## 프롬프트를 사용자 정의하기[[customizing-the-prompt]]
|
||||
|
||||
[Transformers Agents](transformers_agents)에서 설명한 것처럼 에이전트는 [`~Agent.run`] 및 [`~Agent.chat`] 모드에서 실행할 수 있습니다.
|
||||
`run`(실행) 모드와 `chat`(채팅) 모드 모두 동일한 로직을 기반으로 합니다.
|
||||
에이전트를 구동하는 언어 모델은 긴 프롬프트에 따라 조건이 지정되고, 중지 토큰에 도달할 때까지 다음 토큰을 생성하여 프롬프트를 완수합니다.
|
||||
`chat` 모드에서는 프롬프트가 이전 사용자 입력 및 모델 생성으로 연장된다는 점이 두 모드의 유일한 차이점입니다.
|
||||
이를 통해 에이전트가 과거 상호작용에 접근할 수 있게 되므로 에이전트에게 일종의 메모리를 제공하는 셈입니다.
|
||||
|
||||
### 프롬프트의 구조[[structure-of-the-prompt]]
|
||||
|
||||
어떻게 프롬프트 사용자 정의를 잘 할 수 있는지 이해하기 위해 프롬프트의 구조를 자세히 살펴봅시다.
|
||||
프롬프트는 크게 네 부분으로 구성되어 있습니다.
|
||||
|
||||
- 1. 도입: 에이전트가 어떻게 행동해야 하는지, 도구의 개념에 대한 설명.
|
||||
- 2. 모든 도구에 대한 설명. 이는 런타임에 사용자가 정의/선택한 도구로 동적으로 대체되는 `<<all_tools>>` 토큰으로 정의됩니다.
|
||||
- 3. 작업 예제 및 해당 솔루션 세트.
|
||||
- 4. 현재 예제 및 해결 요청.
|
||||
|
||||
각 부분을 더 잘 이해할 수 있도록 짧은 버전을 통해 `run` 프롬프트가 어떻게 보이는지 살펴보겠습니다:
|
||||
|
||||
````text
|
||||
I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
|
||||
[...]
|
||||
You can print intermediate results if it makes sense to do so.
|
||||
|
||||
Tools:
|
||||
- document_qa: This is a tool that answers a question about a document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
|
||||
- image_captioner: This is a tool that generates a description of an image. It takes an input named `image` which should be the image to the caption and returns a text that contains the description in English.
|
||||
[...]
|
||||
|
||||
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
|
||||
|
||||
I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
|
||||
|
||||
Answer:
|
||||
```py
|
||||
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
||||
print(f"The translated question is {translated_question}.")
|
||||
answer = image_qa(image=image, question=translated_question)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."
|
||||
|
||||
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||
|
||||
Answer:
|
||||
```py
|
||||
answer = document_qa(document, question="What is the oldest person?")
|
||||
print(f"The answer is {answer}.")
|
||||
image = image_generator("A banner showing " + answer)
|
||||
```
|
||||
|
||||
[...]
|
||||
|
||||
Task: "Draw me a picture of rivers and lakes"
|
||||
|
||||
I will use the following
|
||||
````
|
||||
|
||||
도입(*"도구:"* 앞의 텍스트)에서는 모델이 어떻게 작동하고 무엇을 해야 하는지 정확하게 설명합니다.
|
||||
에이전트는 항상 같은 방식으로 작동해야 하므로 이 부분은 사용자 정의할 필요가 없을 가능성이 높습니다.
|
||||
|
||||
두 번째 부분(*"도구"* 아래의 글머리 기호)은 `run` 또는 `chat`을 호출할 때 동적으로 추가됩니다.
|
||||
정확히 `agent.toolbox`에 있는 도구 수만큼 글머리 기호가 있고, 각 글머리 기호는 도구의 이름과 설명으로 구성됩니다:
|
||||
|
||||
```text
|
||||
- <tool.name>: <tool.description>
|
||||
```
|
||||
|
||||
문서 질의응답 도구를 가져오고 이름과 설명을 출력해서 빠르게 확인해 보겠습니다.
|
||||
|
||||
```py
|
||||
from transformers import load_tool
|
||||
|
||||
document_qa = load_tool("document-question-answering")
|
||||
print(f"- {document_qa.name}: {document_qa.description}")
|
||||
```
|
||||
|
||||
그러면 다음 결과가 출력됩니다:
|
||||
```text
|
||||
- document_qa: This is a tool that answers a question about a document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
|
||||
```
|
||||
|
||||
여기서 도구 이름이 짧고 정확하다는 것을 알 수 있습니다.
|
||||
설명은 두 부분으로 구성되어 있는데, 첫 번째 부분에서는 도구의 기능을 설명하고 두 번째 부분에서는 예상되는 입력 인수와 반환 값을 명시합니다.
|
||||
|
||||
에이전트가 도구를 올바르게 사용하려면 좋은 도구 이름과 도구 설명이 매우 중요합니다.
|
||||
에이전트가 도구에 대해 알 수 있는 유일한 정보는 이름과 설명뿐이므로, 이 두 가지를 정확하게 작성하고 도구 상자에 있는 기존 도구의 스타일과 일치하는지 확인해야 합니다.
|
||||
특히 이름에 따라 예상되는 모든 인수가 설명에 코드 스타일로 언급되어 있는지, 예상되는 유형과 그 유형이 무엇인지에 대한 설명이 포함되어 있는지 확인하세요.
|
||||
|
||||
<Tip>
|
||||
|
||||
도구에 어떤 이름과 설명이 있어야 하는지 이해하려면 엄선된 Transformers 도구의 이름과 설명을 확인하세요.
|
||||
[`Agent.toolbox`] 속성을 가진 모든 도구를 볼 수 있습니다.
|
||||
|
||||
</Tip>
|
||||
|
||||
세 번째 부분에는 에이전트가 어떤 종류의 사용자 요청에 대해 어떤 코드를 생성해야 하는지 정확하게 보여주는 엄선된 예제 세트가 포함되어 있습니다.
|
||||
에이전트를 지원하는 대규모 언어 모델은 프롬프트에서 패턴을 인식하고 새로운 데이터로 패턴을 반복하는 데 매우 능숙합니다.
|
||||
따라서 에이전트가 실제로 올바른 실행 가능한 코드를 생성할 가능성을 극대화하는 방식으로 예제를 작성하는 것이 매우 중요합니다.
|
||||
|
||||
한 가지 예를 살펴보겠습니다:
|
||||
|
||||
````text
|
||||
Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."
|
||||
|
||||
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||
|
||||
Answer:
|
||||
```py
|
||||
answer = document_qa(document, question="What is the oldest person?")
|
||||
print(f"The answer is {answer}.")
|
||||
image = image_generator("A banner showing " + answer)
|
||||
```
|
||||
|
||||
````
|
||||
작업 설명, 에이전트가 수행하려는 작업에 대한 설명, 마지막으로 생성된 코드, 이 세 부분으로 구성된 프롬프트는 모델에 반복하여 제공됩니다.
|
||||
프롬프트의 일부인 모든 예제는 이러한 정확한 패턴으로 되어 있으므로, 에이전트가 새 토큰을 생성할 때 정확히 동일한 패턴을 재현할 수 있습니다.
|
||||
|
||||
프롬프트 예제는 Transformers 팀이 선별하고 일련의 [problem statements](https://github.com/huggingface/transformers/blob/main/src/transformers/tools/evaluate_agent.py)에 따라 엄격하게 평가하여
|
||||
에이전트의 프롬프트가 에이전트의 실제 사용 사례를 최대한 잘 해결할 수 있도록 보장합니다.
|
||||
|
||||
프롬프트의 마지막 부분은 다음에 해당합니다:
|
||||
```text
|
||||
Task: "Draw me a picture of rivers and lakes"
|
||||
|
||||
I will use the following
|
||||
```
|
||||
|
||||
이는 에이전트가 완료해야 할 최종적인 미완성 예제입니다. 미완성 예제는 실제 사용자 입력에 따라 동적으로 만들어집니다.
|
||||
위 예시의 경우 사용자가 다음과 같이 실행했습니다:
|
||||
|
||||
```py
|
||||
agent.run("Draw me a picture of rivers and lakes")
|
||||
```
|
||||
|
||||
사용자 입력 - *즉* Task: *"Draw me a picture of rivers and lakes"*가 프롬프트 템플릿에 맞춰 "Task: <task> \n\n I will use the following"로 캐스팅됩니다.
|
||||
이 문장은 에이전트에게 조건이 적용되는 프롬프트의 마지막 줄을 구성하므로 에이전트가 이전 예제에서 수행한 것과 정확히 동일한 방식으로 예제를 완료하도록 강력하게 영향을 미칩니다.
|
||||
|
||||
너무 자세히 설명하지 않더라도 채팅 템플릿의 프롬프트 구조는 동일하지만 예제의 스타일이 약간 다릅니다. *예를 들면*:
|
||||
|
||||
````text
|
||||
[...]
|
||||
|
||||
=====
|
||||
|
||||
Human: Answer the question in the variable `question` about the image stored in the variable `image`.
|
||||
|
||||
Assistant: I will use the tool `image_qa` to answer the question on the input image.
|
||||
|
||||
```py
|
||||
answer = image_qa(text=question, image=image)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
Human: I tried this code, it worked but didn't give me a good result. The question is in French
|
||||
|
||||
Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.
|
||||
|
||||
```py
|
||||
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
||||
print(f"The translated question is {translated_question}.")
|
||||
answer = image_qa(text=translated_question, image=image)
|
||||
print(f"The answer is {answer}")
|
||||
```
|
||||
|
||||
=====
|
||||
|
||||
[...]
|
||||
````
|
||||
|
||||
`run` 프롬프트의 예와는 반대로, 각 `chat` 프롬프트의 예에는 *Human(사람)*과 *Assistant(어시스턴트)* 간에 하나 이상의 교환이 있습니다. 모든 교환은 `run` 프롬프트의 예와 유사한 구조로 되어 있습니다.
|
||||
사용자의 입력이 *Human:* 뒤에 추가되며, 에이전트에게 코드를 생성하기 전에 수행해야 할 작업을 먼저 생성하라는 메시지가 표시됩니다.
|
||||
교환은 이전 교환을 기반으로 할 수 있으므로 위와 같이 사용자가 "**이** 코드를 시도했습니다"라고 입력하면 이전에 생성된 에이전트의 코드를 참조하여 과거 교환을 참조할 수 있습니다.
|
||||
|
||||
`.chat`을 실행하면 사용자의 입력 또는 *작업*이 미완성된 양식의 예시로 캐스팅됩니다:
|
||||
```text
|
||||
Human: <user-input>\n\nAssistant:
|
||||
```
|
||||
그러면 에이전트가 이를 완성합니다. `run` 명령과 달리 `chat` 명령은 완료된 예제를 프롬프트에 추가하여 에이전트에게 다음 `chat` 차례에 대한 더 많은 문맥을 제공합니다.
|
||||
|
||||
이제 프롬프트가 어떻게 구성되어 있는지 알았으니 어떻게 사용자 정의할 수 있는지 살펴봅시다!
|
||||
|
||||
### 좋은 사용자 입력 작성하기[[writing-good-user-inputs]]
|
||||
|
||||
대규모 언어 모델이 사용자의 의도를 이해하는 능력이 점점 더 향상되고 있지만, 에이전트가 올바른 작업을 선택할 수 있도록 최대한 정확성을 유지하는 것은 큰 도움이 됩니다.
|
||||
최대한 정확하다는 것은 무엇을 의미할까요?
|
||||
|
||||
에이전트는 프롬프트에서 도구 이름 목록과 해당 설명을 볼 수 있습니다.
|
||||
더 많은 도구가 추가될수록 에이전트가 올바른 도구를 선택하기가 더 어려워지고 실행할 도구의 올바른 순서를 선택하는 것은 더욱 어려워집니다.
|
||||
일반적인 실패 사례를 살펴보겠습니다. 여기서는 분석할 코드만 반환하겠습니다.
|
||||
|
||||
```py
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
|
||||
agent.run("Show me a tree", return_code=True)
|
||||
```
|
||||
|
||||
그러면 다음 결과가 출력됩니다:
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_segmenter` to create a segmentation mask for the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
mask = image_segmenter(image, prompt="tree")
|
||||
```
|
||||
|
||||
우리가 원했던 결과가 아닐 수도 있습니다. 대신 나무 이미지가 생성되기를 원할 가능성이 더 높습니다.
|
||||
따라서 에이전트가 특정 도구를 사용하도록 유도하려면 도구의 이름과 설명에 있는 중요한 키워드를 사용하는 것이 매우 유용할 수 있습니다. 한번 살펴보겠습니다.
|
||||
```py
|
||||
agent.toolbox["image_generator"].description
|
||||
```
|
||||
|
||||
```text
|
||||
'This is a tool that creates an image according to a prompt, which is a text description. It takes an input named `prompt` which contains the image description and outputs an image.
|
||||
```
|
||||
|
||||
이름과 설명은 "image", "prompt", "create" 및 "generate" 키워드를 사용합니다. 이 단어들을 사용하면 더 잘 작동할 가능성이 높습니다. 프롬프트를 조금 더 구체화해 보겠습니다.
|
||||
|
||||
```py
|
||||
agent.run("Create an image of a tree", return_code=True)
|
||||
```
|
||||
|
||||
이 코드는 다음 프롬프트를 만들어냅니다:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool `image_generator` to generate an image of a tree.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_generator(prompt="tree")
|
||||
```
|
||||
|
||||
훨씬 낫네요! 저희가 원했던 것과 비슷해 보입니다.
|
||||
즉, 에이전트가 작업을 올바른 도구에 올바르게 매핑하는 데 어려움을 겪고 있다면 도구 이름과 설명에서 가장 관련성이 높은 키워드를 찾아보고 이를 통해 작업 요청을 구체화해 보세요.
|
||||
|
||||
### 도구 설명 사용자 정의하기[[customizing-the-tool-descriptions]]
|
||||
|
||||
앞서 살펴본 것처럼 에이전트는 각 도구의 이름과 설명에 액세스할 수 있습니다.
|
||||
기본 도구에는 매우 정확한 이름과 설명이 있어야 하지만 특정 사용 사례에 맞게 도구의 설명이나 이름을 변경하는 것이 도움이 될 수도 있습니다.
|
||||
이는 매우 유사한 여러 도구를 추가했거나 특정 도메인(*예*: 이미지 생성 및 변환)에만 에이전트를 사용하려는 경우에 특히 중요해질 수 있습니다.
|
||||
|
||||
일반적인 문제는 이미지 생성 작업에 많이 사용되는 경우 에이전트가 이미지 생성과 이미지 변환/수정을 혼동하는 것입니다. *예를 들어,*
|
||||
```py
|
||||
agent.run("Make an image of a house and a car", return_code=True)
|
||||
```
|
||||
그러면 다음 결과가 출력됩니다:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools `image_generator` to generate an image of a house and `image_transformer` to transform the image of a car into the image of a house.
|
||||
|
||||
==Code generated by the agent==
|
||||
house_image = image_generator(prompt="A house")
|
||||
car_image = image_generator(prompt="A car")
|
||||
house_car_image = image_transformer(image=car_image, prompt="A house")
|
||||
```
|
||||
|
||||
결과물이 우리가 여기서 원하는 것과 정확히 일치하지 않을 수 있습니다. 에이전트가 `image_generator`와 `image_transformer`의 차이점을 이해하기 어려워서 두 가지를 함께 사용하는 경우가 많은 것 같습니다.
|
||||
|
||||
여기서 `image_transformer`의 도구 이름과 설명을 변경하여 에이전트가 도울 수 있습니다.
|
||||
"image" 및 "prompt"와 약간 분리하기 위해 `modifier`라고 대신 부르겠습니다:
|
||||
```py
|
||||
agent.toolbox["modifier"] = agent.toolbox.pop("image_transformer")
|
||||
agent.toolbox["modifier"].description = agent.toolbox["modifier"].description.replace(
|
||||
"transforms an image according to a prompt", "modifies an image"
|
||||
)
|
||||
```
|
||||
|
||||
이제 "modify"은 새 이미지 프로세서를 사용하라는 강력한 신호이므로 위의 프롬프트에 도움이 될 것입니다. 다시 실행해 봅시다.
|
||||
|
||||
```py
|
||||
agent.run("Make an image of a house and a car", return_code=True)
|
||||
```
|
||||
|
||||
여기서 다음과 같은 결과를 얻게 됩니다:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools: `image_generator` to generate an image of a house, then `image_generator` to generate an image of a car.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
house_image = image_generator(prompt="A house")
|
||||
car_image = image_generator(prompt="A car")
|
||||
```
|
||||
|
||||
우리가 염두에 두었던 것과 확실히 더 가까워졌습니다! 하지만 집과 자동차가 모두 같은 이미지에 포함되면 좋겠습니다. 작업을 단일 이미지 생성에 더 집중하면 도움이 될 것입니다:
|
||||
|
||||
```py
|
||||
agent.run("Create image: 'A house and car'", return_code=True)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_generator` to generate an image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_generator(prompt="A house and car")
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
에이전트는 여전히 특히 여러 개체의 이미지를 생성하는 것과 같이 약간 더 복잡한 사용 사례에서 취약한 경우가 많습니다.
|
||||
앞으로 몇 달 안에 에이전트 자체와 기본 프롬프트가 더욱 개선되어 에이전트가 다양한 사용자 입력에 더욱 강력하게 대응할 수 있도록 할 예정입니다.
|
||||
|
||||
</Tip>
|
||||
|
||||
### 전체 프롬프트 사용자 정의하기[[customizing-the-whole-prompt]]
|
||||
|
||||
사용자에게 최대한의 유연성을 제공하기 위해 [위](#structure-of-the-prompt)에 설명된 전체 프롬프트 템플릿을 사용자가 덮어쓸 수 있습니다.
|
||||
이 경우 사용자 정의 프롬프트에 소개 섹션, 도구 섹션, 예제 섹션 및 미완성 예제 섹션이 포함되어 있는지 확인하세요.
|
||||
`run` 프롬프트 템플릿을 덮어쓰려면 다음과 같이 하면 됩니다:
|
||||
|
||||
```py
|
||||
template = """ [...] """
|
||||
|
||||
agent = HfAgent(your_endpoint, run_prompt_template=template)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
에이전트가 사용 가능한 도구를 인식하고 사용자의 프롬프트를 올바르게 삽입할 수 있도록 `<<all_tools>>` 문자열과 `<<prompt>>`를 `template` 어딘가에 정의해야 합니다.
|
||||
|
||||
</Tip>
|
||||
|
||||
마찬가지로 `chat` 프롬프트 템플릿을 덮어쓸 수 있습니다. `chat` 모드에서는 항상 다음과 같은 교환 형식을 사용한다는 점에 유의하세요:
|
||||
|
||||
```text
|
||||
Human: <<task>>
|
||||
|
||||
Assistant:
|
||||
```
|
||||
|
||||
따라서 사용자 정의 `chat` 프롬프트 템플릿의 예제에서도 이 형식을 사용하는 것이 중요합니다.
|
||||
다음과 같이 인스턴스화 할 때 `chat` 템플릿을 덮어쓸 수 있습니다.
|
||||
|
||||
```python
|
||||
template = """ [...] """
|
||||
|
||||
agent = HfAgent(url_endpoint=your_endpoint, chat_prompt_template=template)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
에이전트가 사용 가능한 도구를 인식할 수 있도록 `<<all_tools>>` 문자열을 `template` 어딘가에 정의해야 합니다.
|
||||
|
||||
</Tip>
|
||||
|
||||
두 경우 모두 커뮤니티의 누군가가 호스팅하는 템플릿을 사용하려는 경우 프롬프트 템플릿 대신 저장소 ID를 전달할 수 있습니다.
|
||||
기본 프롬프트는 [이 저장소](https://huggingface.co/datasets/huggingface-tools/default-prompts)를 예로 들 수 있습니다.
|
||||
|
||||
Hub의 저장소에 사용자 정의 프롬프트를 업로드하여 커뮤니티와 공유하려면 다음을 확인하세요:
|
||||
- 데이터 세트 저장소를 사용하세요.
|
||||
- `run` 명령에 대한 프롬프트 템플릿을 `run_prompt_template.txt`라는 파일에 넣으세요.
|
||||
- `chat` 명령에 대한 프롬프트 템플릿을 `chat_prompt_template.txt`라는 파일에 넣으세요.
|
||||
|
||||
## 사용자 정의 도구 사용하기[[using-custom-tools]]
|
||||
|
||||
이 섹션에서는 이미지 생성에 특화된 두 가지 기존 사용자 정의 도구를 활용하겠습니다:
|
||||
|
||||
- 더 많은 이미지 수정을 허용하기 위해 [huggingface-tools/image-transformation](https://huggingface.co/spaces/huggingface-tools/image-transformation)을
|
||||
[diffusers/controlnet-canny-tool](https://huggingface.co/spaces/diffusers/controlnet-canny-tool)로 대체합니다.
|
||||
- 기본 도구 상자에 이미지 업스케일링을 위한 새로운 도구가 추가되었습니다:
|
||||
[diffusers/latent-upscaler-tool](https://huggingface.co/spaces/diffusers/latent-upscaler-tool)가 기존 이미지 변환 도구를 대체합니다.
|
||||
|
||||
편리한 [`load_tool`] 함수를 사용하여 사용자 정의 도구를 가져오는 것으로 시작하겠습니다:
|
||||
|
||||
```py
|
||||
from transformers import load_tool
|
||||
|
||||
controlnet_transformer = load_tool("diffusers/controlnet-canny-tool")
|
||||
upscaler = load_tool("diffusers/latent-upscaler-tool")
|
||||
```
|
||||
|
||||
에이전트에게 사용자 정의 도구를 추가하면 도구의 설명과 이름이 에이전트의 프롬프트에 자동으로 포함됩니다.
|
||||
따라서 에이전트가 사용 방법을 이해할 수 있도록 사용자 정의 도구의 설명과 이름을 잘 작성해야 합니다.
|
||||
`controlnet_transformer`의 설명과 이름을 살펴보겠습니다:
|
||||
|
||||
```py
|
||||
print(f"Description: '{controlnet_transformer.description}'")
|
||||
print(f"Name: '{controlnet_transformer.name}'")
|
||||
```
|
||||
|
||||
그러면 다음 결과가 출력됩니다:
|
||||
```text
|
||||
Description: 'This is a tool that transforms an image with ControlNet according to a prompt.
|
||||
It takes two inputs: `image`, which should be the image to transform, and `prompt`, which should be the prompt to use to change it. It returns the modified image.'
|
||||
Name: 'image_transformer'
|
||||
```
|
||||
|
||||
이름과 설명이 정확하고 [큐레이팅 된 도구 세트(curated set of tools)](./transformers_agents#a-curated-set-of-tools)의 스타일에 맞습니다.
|
||||
다음으로, `controlnet_transformer`와 `upscaler`로 에이전트를 인스턴스화해 봅시다:
|
||||
```py
|
||||
tools = [controlnet_transformer, upscaler]
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=tools)
|
||||
```
|
||||
|
||||
이 명령을 실행하면 다음 정보가 표시됩니다:
|
||||
|
||||
```text
|
||||
image_transformer has been replaced by <transformers_modules.diffusers.controlnet-canny-tool.bd76182c7777eba9612fc03c0
|
||||
8718a60c0aa6312.image_transformation.ControlNetTransformationTool object at 0x7f1d3bfa3a00> as provided in `additional_tools`
|
||||
```
|
||||
|
||||
큐레이팅된 도구 세트에는 이미 'image_transformer' 도구가 있으며, 이 도구는 사용자 정의 도구로 대체됩니다.
|
||||
|
||||
<Tip>
|
||||
|
||||
기존 도구와 똑같은 작업에 사용자 정의 도구를 사용하려는 경우 기존 도구를 덮어쓰는 것이 유용할 수 있습니다.
|
||||
에이전트가 해당 작업에 능숙하기 때문입니다.
|
||||
이 경우 사용자 정의 도구가 덮어쓴 도구와 정확히 동일한 API를 따라야 하며, 그렇지 않으면 해당 도구를 사용하는 모든 예제가 업데이트되도록 프롬프트 템플릿을 조정해야 한다는 점에 유의하세요.
|
||||
|
||||
</Tip>
|
||||
|
||||
업스케일러 도구에 지정된 'image_upscaler'라는 이름 아직 기본 도구 상자에는 존재하지 않기 때문에, 도구 목록에 해당 이름이 간단히 추가되었습니다.
|
||||
에이전트가 현재 사용할 수 있는 도구 상자는 언제든지 `agent.toolbox` 속성을 통해 확인할 수 있습니다:
|
||||
|
||||
```py
|
||||
print("\n".join([f"- {a}" for a in agent.toolbox.keys()]))
|
||||
```
|
||||
|
||||
```text
|
||||
- document_qa
|
||||
- image_captioner
|
||||
- image_qa
|
||||
- image_segmenter
|
||||
- transcriber
|
||||
- summarizer
|
||||
- text_classifier
|
||||
- text_qa
|
||||
- text_reader
|
||||
- translator
|
||||
- image_transformer
|
||||
- text_downloader
|
||||
- image_generator
|
||||
- video_generator
|
||||
- image_upscaler
|
||||
```
|
||||
|
||||
에이전트의 도구 상자에 `image_upscaler`가 추가된 점을 주목하세요.
|
||||
|
||||
이제 새로운 도구를 사용해봅시다! [Transformers Agents Quickstart](./transformers_agents#single-execution-run)에서 생성한 이미지를 다시 사용하겠습니다.
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png"
|
||||
)
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>
|
||||
|
||||
이미지를 아름다운 겨울 풍경으로 바꿔 봅시다:
|
||||
|
||||
```py
|
||||
image = agent.run("Transform the image: 'A frozen lake and snowy forest'", image=image)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_transformer` to transform the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
image = image_transformer(image, prompt="A frozen lake and snowy forest")
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter.png" width=200>
|
||||
|
||||
새로운 이미지 처리 도구는 이미지를 매우 강력하게 수정할 수 있는 ControlNet을 기반으로 합니다.
|
||||
기본적으로 이미지 처리 도구는 512x512 픽셀 크기의 이미지를 반환합니다. 이를 업스케일링할 수 있는지 살펴봅시다.
|
||||
|
||||
```py
|
||||
image = agent.run("Upscale the image", image)
|
||||
```
|
||||
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tool: `image_upscaler` to upscale the image.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
upscaled_image = image_upscaler(image)
|
||||
```
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter_upscale.png" width=400>
|
||||
|
||||
에이전트는 업스케일러 도구의 설명과 이름만 보고 방금 추가한 업스케일러 도구에 "이미지 업스케일링"이라는 프롬프트를 자동으로 매핑하여 올바르게 실행했습니다.
|
||||
|
||||
다음으로 새 사용자 정의 도구를 만드는 방법을 살펴보겠습니다.
|
||||
|
||||
### 새 도구 추가하기[[adding-new-tools]]
|
||||
|
||||
이 섹션에서는 에이전트에게 추가할 수 있는 새 도구를 만드는 방법을 보여 드립니다.
|
||||
|
||||
#### 새 도구 만들기[[creating-a-new-tool]]
|
||||
|
||||
먼저 도구를 만드는 것부터 시작하겠습니다.
|
||||
특정 작업에 대해 가장 많은 다운로드를 받은 Hugging Face Hub의 모델을 가져오는, 그다지 유용하지는 않지만 재미있는 작업을 추가하겠습니다.
|
||||
|
||||
다음 코드를 사용하면 됩니다:
|
||||
|
||||
```python
|
||||
from huggingface_hub import list_models
|
||||
|
||||
task = "text-classification"
|
||||
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
print(model.id)
|
||||
```
|
||||
`text-classification`(텍스트 분류) 작업의 경우 `'facebook/bart-large-mnli'`를 반환하고, `translation`(번역) 작업의 경우 `'google-t5/t5-base'`를 반환합니다.
|
||||
|
||||
이를 에이전트가 활용할 수 있는 도구로 변환하려면 어떻게 해야 할까요?
|
||||
모든 도구는 필요한 주요 속성을 보유하는 슈퍼클래스 `Tool`에 의존합니다. 이를 상속하는 클래스를 만들어 보겠습니다:
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
pass
|
||||
```
|
||||
|
||||
이 클래스에는 몇 가지 요구사항이 있습니다:
|
||||
- 도구 자체의 이름에 해당하는 `name` 속성. 수행명이 있는 다른 도구와 호환되도록 `model_download_counter`로 이름을 지정하겠습니다.
|
||||
- 에이전트의 프롬프트를 채우는 데 사용되는 속성 `description`.
|
||||
- `inputs` 및 `outputs` 속성. 이를 정의하면 Python 인터프리터가 유형에 대한 정보에 입각한 선택을 하는 데 도움이 되며,
|
||||
도구를 허브에 푸시할 때 gradio 데모를 생성할 수 있습니다.
|
||||
두 속성 모두 값은 '텍스트', '이미지' 또는 '오디오'가 될 수 있는 예상 값의 리스트입니다.
|
||||
- 추론 코드가 포함된 `__call__` 메소드. 이것이 우리가 위에서 다루었던 코드입니다!
|
||||
|
||||
이제 클래스의 모습은 다음과 같습니다:
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
from huggingface_hub import list_models
|
||||
|
||||
|
||||
class HFModelDownloadsTool(Tool):
|
||||
name = "model_download_counter"
|
||||
description = (
|
||||
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
|
||||
"It takes the name of the category (such as text-classification, depth-estimation, etc), and "
|
||||
"returns the name of the checkpoint."
|
||||
)
|
||||
|
||||
inputs = ["text"]
|
||||
outputs = ["text"]
|
||||
|
||||
def __call__(self, task: str):
|
||||
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
|
||||
return model.id
|
||||
```
|
||||
|
||||
이제 도구를 손쉽게 사용할 수 있게 되었습니다.
|
||||
도구를 파일에 저장하고 메인 스크립트에서 가져옵니다. 이 파일의 이름을 `model_downloads.py`로 지정하면 결과적으로 가져오기 코드는 다음과 같습니다:
|
||||
|
||||
```python
|
||||
from model_downloads import HFModelDownloadsTool
|
||||
|
||||
tool = HFModelDownloadsTool()
|
||||
```
|
||||
|
||||
다른 사람들이 이 기능을 활용할 수 있도록 하고 초기화를 더 간단하게 하려면 네임스페이스 아래의 Hub로 푸시하는 것이 좋습니다.
|
||||
그렇게 하려면 `tool` 변수에서 `push_to_hub`를 호출하면 됩니다:
|
||||
|
||||
```python
|
||||
tool.push_to_hub("hf-model-downloads")
|
||||
```
|
||||
|
||||
이제 허브에 코드가 생겼습니다! 마지막 단계인 에이전트가 코드를 사용하도록 하는 단계를 살펴보겠습니다.
|
||||
|
||||
#### 에이전트가 도구를 사용하게 하기[[Having-the-agent-use-the-tool]]
|
||||
|
||||
이제 이런 식으로 허브에 존재하는 도구를 인스턴스화할 수 있습니다(도구의 사용자 이름은 변경하세요):
|
||||
We now have our tool that lives on the Hub which can be instantiated as such (change the user name for your tool):
|
||||
|
||||
```python
|
||||
from transformers import load_tool
|
||||
|
||||
tool = load_tool("lysandre/hf-model-downloads")
|
||||
```
|
||||
|
||||
이 도구를 에이전트에서 사용하려면 에이전트 초기화 메소드의 `additional_tools` 매개변수에 전달하기만 하면 됩니다:
|
||||
|
||||
```python
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
|
||||
|
||||
agent.run(
|
||||
"Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
|
||||
)
|
||||
```
|
||||
그러면 다음과 같은 결과가 출력됩니다:
|
||||
```text
|
||||
==Code generated by the agent==
|
||||
model = model_download_counter(task="text-to-video")
|
||||
print(f"The model with the most downloads is {model}.")
|
||||
audio_model = text_reader(model)
|
||||
|
||||
|
||||
==Result==
|
||||
The model with the most downloads is damo-vilab/text-to-video-ms-1.7b.
|
||||
```
|
||||
|
||||
and generates the following audio.
|
||||
|
||||
| **Audio** |
|
||||
|------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
LLM에 따라 일부는 매우 취약하기 때문에 제대로 작동하려면 매우 정확한 프롬프트가 필요합니다.
|
||||
에이전트가 도구를 잘 활용하기 위해서는 도구의 이름과 설명을 잘 정의하는 것이 무엇보다 중요합니다.
|
||||
|
||||
</Tip>
|
||||
|
||||
### 기존 도구 대체하기[[replacing-existing-tools]]
|
||||
|
||||
에이전트의 도구 상자에 새 항목을 배정하기만 하면 기존 도구를 대체할 수 있습니다. 방법은 다음과 같습니다:
|
||||
|
||||
```python
|
||||
from transformers import HfAgent, load_tool
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
agent.toolbox["image-transformation"] = load_tool("diffusers/controlnet-canny-tool")
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
다른 도구로 교체할 때는 주의하세요! 이 작업으로 에이전트의 프롬프트도 조정됩니다.
|
||||
작업에 더 적합한 프롬프트가 있으면 좋을 수 있지만,
|
||||
다른 도구보다 더 많이 선택되거나 정의한 도구 대신 다른 도구가 선택될 수도 있습니다.
|
||||
|
||||
</Tip>
|
||||
|
||||
## gradio-tools 사용하기[[leveraging-gradio-tools]]
|
||||
|
||||
[gradio-tools](https://github.com/freddyaboulton/gradio-tools)는 Hugging Face Spaces를 도구로 사용할 수 있는 강력한 라이브러리입니다.
|
||||
기존의 많은 Spaces뿐만 아니라 사용자 정의 Spaces를 사용하여 디자인할 수 있도록 지원합니다.
|
||||
|
||||
우리는 `Tool.from_gradio` 메소드를 사용하여 `gradio_tools`에 대한 지원을 제공합니다.
|
||||
예를 들어, 프롬프트를 개선하고 더 나은 이미지를 생성하기 위해 `gradio-tools` 툴킷에서 제공되는 `StableDiffusionPromptGeneratorTool` 도구를 활용하고자 합니다.
|
||||
|
||||
먼저 `gradio_tools`에서 도구를 가져와서 인스턴스화합니다:
|
||||
|
||||
```python
|
||||
from gradio_tools import StableDiffusionPromptGeneratorTool
|
||||
|
||||
gradio_tool = StableDiffusionPromptGeneratorTool()
|
||||
```
|
||||
|
||||
해당 인스턴스를 `Tool.from_gradio` 메소드에 전달합니다:
|
||||
|
||||
```python
|
||||
from transformers import Tool
|
||||
|
||||
tool = Tool.from_gradio(gradio_tool)
|
||||
```
|
||||
|
||||
이제 일반적인 사용자 정의 도구와 똑같이 관리할 수 있습니다.
|
||||
이를 활용하여 `a rabbit wearing a space suit'(우주복을 입은 토끼)라는 프롬프트를 개선했습니다:
|
||||
|
||||
```python
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
|
||||
|
||||
agent.run("Generate an image of the `prompt` after improving it.", prompt="A rabbit wearing a space suit")
|
||||
```
|
||||
|
||||
모델이 도구를 적절히 활용합니다:
|
||||
```text
|
||||
==Explanation from the agent==
|
||||
I will use the following tools: `StableDiffusionPromptGenerator` to improve the prompt, then `image_generator` to generate an image according to the improved prompt.
|
||||
|
||||
|
||||
==Code generated by the agent==
|
||||
improved_prompt = StableDiffusionPromptGenerator(prompt)
|
||||
print(f"The improved prompt is {improved_prompt}.")
|
||||
image = image_generator(improved_prompt)
|
||||
```
|
||||
|
||||
마지막으로 이미지를 생성하기 전에:
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
gradio-tools는 다른 모달리티로 작업할 때에도 *텍스트* 입력 및 출력을 필요로 합니다.
|
||||
이 구현은 이미지 및 오디오 객체에서 작동합니다.
|
||||
현재는 이 두 가지가 호환되지 않지만 지원 개선을 위해 노력하면서 빠르게 호환될 것입니다.
|
||||
|
||||
</Tip>
|
||||
|
||||
## 향후 Langchain과의 호환성[[future-compatibility-with-langchain]]
|
||||
|
||||
저희는 Langchain을 좋아하며 매우 매력적인 도구 모음을 가지고 있다고 생각합니다.
|
||||
이러한 도구를 처리하기 위해 Langchain은 다른 모달리티와 작업할 때에도 *텍스트* 입력과 출력을 필요로 합니다.
|
||||
이는 종종 객체의 직렬화된(즉, 디스크에 저장된) 버전입니다.
|
||||
|
||||
이 차이로 인해 transformers-agents와 Langchain 간에는 멀티 모달리티가 처리되지 않습니다.
|
||||
향후 버전에서 이 제한이 해결되기를 바라며, 이 호환성을 달성할 수 있도록 열렬한 Langchain 사용자의 도움을 환영합니다.
|
||||
|
||||
저희는 더 나은 지원을 제공하고자 합니다. 도움을 주고 싶으시다면, [이슈를 열어](https://github.com/huggingface/transformers/issues/new) 의견을 공유해 주세요.
|
||||
|
@ -18,84 +18,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Transformers Agents是一个实验性的API,它随时可能发生变化。由于API或底层模型容易发生变化,因此由agents返回的结果可能会有所不同。
|
||||
The Agents framework has significantly changed in version v4.41.0.
|
||||
This document has been removed as it was referencing an older API.
|
||||
|
||||
We eagerly welcome new contributions for the updated API.
|
||||
|
||||
</Tip>
|
||||
|
||||
要了解更多关于agents和工具的信息,请确保阅读[介绍指南](../transformers_agents)。此页面包含底层类的API文档。
|
||||
|
||||
|
||||
## Agents
|
||||
|
||||
我们提供三种类型的agents:[`HfAgent`]使用开源模型的推理端点,[`LocalAgent`]使用您在本地选择的模型,[`OpenAiAgent`]使用OpenAI封闭模型。
|
||||
|
||||
|
||||
### HfAgent
|
||||
|
||||
[[autodoc]] HfAgent
|
||||
|
||||
### LocalAgent
|
||||
|
||||
[[autodoc]] LocalAgent
|
||||
|
||||
### OpenAiAgent
|
||||
|
||||
[[autodoc]] OpenAiAgent
|
||||
|
||||
### AzureOpenAiAgent
|
||||
|
||||
[[autodoc]] AzureOpenAiAgent
|
||||
|
||||
### Agent
|
||||
|
||||
[[autodoc]] Agent
|
||||
- chat
|
||||
- run
|
||||
- prepare_for_new_chat
|
||||
|
||||
## 工具
|
||||
|
||||
### load_tool
|
||||
|
||||
[[autodoc]] load_tool
|
||||
|
||||
### Tool
|
||||
|
||||
[[autodoc]] Tool
|
||||
|
||||
### PipelineTool
|
||||
|
||||
[[autodoc]] PipelineTool
|
||||
|
||||
### RemoteTool
|
||||
|
||||
[[autodoc]] RemoteTool
|
||||
|
||||
### launch_gradio_demo
|
||||
|
||||
[[autodoc]] launch_gradio_demo
|
||||
|
||||
## Agent类型
|
||||
|
||||
Agents可以处理工具之间任何类型的对象;工具是多模态的,可以接受和返回文本、图像、音频、视频等类型。为了增加工具之间的兼容性,以及正确地在ipython(jupyter、colab、ipython notebooks等)中呈现这些返回值,我们实现了这些类型的包装类。
|
||||
|
||||
被包装的对象应该继续按照最初的行为方式运作;文本对象应该仍然像字符串一样运作,图像对象应该仍然像`PIL.Image`一样运作。
|
||||
|
||||
这些类型有三个特定目的:
|
||||
|
||||
- 对类型调用 `to_raw` 应该返回底层对象
|
||||
- 对类型调用 `to_string` 应该将对象作为字符串返回:在`AgentText`的情况下可能是字符串,但在其他情况下可能是对象序列化版本的路径
|
||||
- 在ipython内核中显示它应该正确显示对象
|
||||
|
||||
### AgentText
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentText
|
||||
|
||||
### AgentImage
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentImage
|
||||
|
||||
### AgentAudio
|
||||
|
||||
[[autodoc]] transformers.tools.agent_types.AgentAudio
|
||||
|
@ -54,6 +54,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Base objects, independent of any specific backend
|
||||
_import_structure = {
|
||||
"agents": [
|
||||
"Agent",
|
||||
"CodeAgent",
|
||||
"HfEngine",
|
||||
"PipelineTool",
|
||||
"ReactAgent",
|
||||
"ReactCodeAgent",
|
||||
"ReactJsonAgent",
|
||||
"Tool",
|
||||
"Toolbox",
|
||||
"ToolCollection",
|
||||
"launch_gradio_demo",
|
||||
"load_tool",
|
||||
],
|
||||
"audio_utils": [],
|
||||
"benchmark": [],
|
||||
"commands": [],
|
||||
@ -129,8 +143,8 @@ _import_structure = {
|
||||
"load_tf2_model_in_pytorch_model",
|
||||
"load_tf2_weights_in_pytorch_model",
|
||||
],
|
||||
"models": [],
|
||||
# Models
|
||||
"models": [],
|
||||
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
|
||||
"models.align": [
|
||||
"ALIGN_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
@ -1050,18 +1064,6 @@ _import_structure = {
|
||||
"SpecialTokensMixin",
|
||||
"TokenSpan",
|
||||
],
|
||||
"tools": [
|
||||
"Agent",
|
||||
"AzureOpenAiAgent",
|
||||
"HfAgent",
|
||||
"LocalAgent",
|
||||
"OpenAiAgent",
|
||||
"PipelineTool",
|
||||
"RemoteTool",
|
||||
"Tool",
|
||||
"launch_gradio_demo",
|
||||
"load_tool",
|
||||
],
|
||||
"trainer_callback": [
|
||||
"DefaultFlowCallback",
|
||||
"EarlyStoppingCallback",
|
||||
@ -5039,6 +5041,21 @@ else:
|
||||
# Direct imports for type-checking
|
||||
if TYPE_CHECKING:
|
||||
# Configuration
|
||||
# Agents
|
||||
from .agents import (
|
||||
Agent,
|
||||
CodeAgent,
|
||||
HfEngine,
|
||||
PipelineTool,
|
||||
ReactAgent,
|
||||
ReactCodeAgent,
|
||||
ReactJsonAgent,
|
||||
Tool,
|
||||
Toolbox,
|
||||
ToolCollection,
|
||||
launch_gradio_demo,
|
||||
load_tool,
|
||||
)
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
# Data
|
||||
@ -6010,20 +6027,6 @@ if TYPE_CHECKING:
|
||||
TokenSpan,
|
||||
)
|
||||
|
||||
# Tools
|
||||
from .tools import (
|
||||
Agent,
|
||||
AzureOpenAiAgent,
|
||||
HfAgent,
|
||||
LocalAgent,
|
||||
OpenAiAgent,
|
||||
PipelineTool,
|
||||
RemoteTool,
|
||||
Tool,
|
||||
launch_gradio_demo,
|
||||
load_tool,
|
||||
)
|
||||
|
||||
# Trainer
|
||||
from .trainer_callback import (
|
||||
DefaultFlowCallback,
|
||||
|
@ -24,8 +24,9 @@ from ..utils import (
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"agents": ["Agent", "AzureOpenAiAgent", "HfAgent", "LocalAgent", "OpenAiAgent"],
|
||||
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
|
||||
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
|
||||
"llm_engine": ["HfEngine"],
|
||||
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
|
||||
}
|
||||
|
||||
try:
|
||||
@ -34,20 +35,17 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["default_tools"] = ["FinalAnswerTool", "PythonInterpreterTool"]
|
||||
_import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
|
||||
_import_structure["image_captioning"] = ["ImageCaptioningTool"]
|
||||
_import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
|
||||
_import_structure["image_segmentation"] = ["ImageSegmentationTool"]
|
||||
_import_structure["speech_to_text"] = ["SpeechToTextTool"]
|
||||
_import_structure["text_classification"] = ["TextClassificationTool"]
|
||||
_import_structure["text_question_answering"] = ["TextQuestionAnsweringTool"]
|
||||
_import_structure["text_summarization"] = ["TextSummarizationTool"]
|
||||
_import_structure["text_to_speech"] = ["TextToSpeechTool"]
|
||||
_import_structure["translation"] = ["TranslationTool"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agents import Agent, AzureOpenAiAgent, HfAgent, LocalAgent, OpenAiAgent
|
||||
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
|
||||
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||
from .llm_engine import HfEngine
|
||||
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
@ -55,14 +53,10 @@ if TYPE_CHECKING:
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .default_tools import FinalAnswerTool, PythonInterpreterTool
|
||||
from .document_question_answering import DocumentQuestionAnsweringTool
|
||||
from .image_captioning import ImageCaptioningTool
|
||||
from .image_question_answering import ImageQuestionAnsweringTool
|
||||
from .image_segmentation import ImageSegmentationTool
|
||||
from .speech_to_text import SpeechToTextTool
|
||||
from .text_classification import TextClassificationTool
|
||||
from .text_question_answering import TextQuestionAnsweringTool
|
||||
from .text_summarization import TextSummarizationTool
|
||||
from .text_to_speech import TextToSpeechTool
|
||||
from .translation import TranslationTool
|
||||
else:
|
@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -25,7 +25,6 @@ from ..utils import is_soundfile_availble, is_torch_available, is_vision_availab
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_vision_available():
|
||||
import PIL.Image
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as ImageType
|
||||
else:
|
||||
@ -33,6 +32,9 @@ else:
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import Tensor
|
||||
else:
|
||||
Tensor = object
|
||||
|
||||
if is_soundfile_availble():
|
||||
import soundfile as sf
|
||||
@ -77,7 +79,7 @@ class AgentText(AgentType, str):
|
||||
return self._value
|
||||
|
||||
def to_string(self):
|
||||
return self._value
|
||||
return str(self._value)
|
||||
|
||||
|
||||
class AgentImage(AgentType, ImageType):
|
||||
@ -211,10 +213,7 @@ class AgentAudio(AgentType):
|
||||
|
||||
|
||||
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
|
||||
INSTANCE_TYPE_MAPPING = {str: AgentText}
|
||||
|
||||
if is_vision_available():
|
||||
INSTANCE_TYPE_MAPPING[PIL.Image] = AgentImage
|
||||
INSTANCE_TYPE_MAPPING = {str: AgentText, float: AgentText, int: AgentText, Tensor: AgentAudio, ImageType: AgentImage}
|
||||
|
||||
|
||||
def handle_agent_inputs(*args, **kwargs):
|
||||
@ -223,55 +222,14 @@ def handle_agent_inputs(*args, **kwargs):
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def handle_agent_outputs(outputs, output_types=None):
|
||||
if isinstance(outputs, dict):
|
||||
decoded_outputs = {}
|
||||
for i, (k, v) in enumerate(outputs.items()):
|
||||
if output_types is not None:
|
||||
# If the class has defined outputs, we can map directly according to the class definition
|
||||
if output_types[i] in AGENT_TYPE_MAPPING:
|
||||
decoded_outputs[k] = AGENT_TYPE_MAPPING[output_types[i]](v)
|
||||
else:
|
||||
decoded_outputs[k] = AgentType(v)
|
||||
|
||||
else:
|
||||
# If the class does not have defined output, then we map according to the type
|
||||
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||
if isinstance(v, _k):
|
||||
decoded_outputs[k] = _v(v)
|
||||
if k not in decoded_outputs:
|
||||
decoded_outputs[k] = AgentType[v]
|
||||
|
||||
elif isinstance(outputs, (list, tuple)):
|
||||
decoded_outputs = type(outputs)()
|
||||
for i, v in enumerate(outputs):
|
||||
if output_types is not None:
|
||||
# If the class has defined outputs, we can map directly according to the class definition
|
||||
if output_types[i] in AGENT_TYPE_MAPPING:
|
||||
decoded_outputs.append(AGENT_TYPE_MAPPING[output_types[i]](v))
|
||||
else:
|
||||
decoded_outputs.append(AgentType(v))
|
||||
else:
|
||||
# If the class does not have defined output, then we map according to the type
|
||||
found = False
|
||||
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||
if isinstance(v, _k):
|
||||
decoded_outputs.append(_v(v))
|
||||
found = True
|
||||
|
||||
if not found:
|
||||
decoded_outputs.append(AgentType(v))
|
||||
|
||||
def handle_agent_outputs(output, output_type=None):
|
||||
if output_type in AGENT_TYPE_MAPPING:
|
||||
# If the class has defined outputs, we can map directly according to the class definition
|
||||
decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
|
||||
return decoded_outputs
|
||||
else:
|
||||
if output_types[0] in AGENT_TYPE_MAPPING:
|
||||
# If the class has defined outputs, we can map directly according to the class definition
|
||||
decoded_outputs = AGENT_TYPE_MAPPING[output_types[0]](outputs)
|
||||
|
||||
else:
|
||||
# If the class does not have defined output, then we map according to the type
|
||||
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||
if isinstance(outputs, _k):
|
||||
return _v(outputs)
|
||||
return AgentType(outputs)
|
||||
|
||||
return decoded_outputs
|
||||
# If the class does not have defined output, then we map according to the type
|
||||
for _k, _v in INSTANCE_TYPE_MAPPING.items():
|
||||
if isinstance(output, _k):
|
||||
return _v(output)
|
||||
return AgentType(output)
|
838
src/transformers/agents/agents.py
Normal file
838
src/transformers/agents/agents.py
Normal file
@ -0,0 +1,838 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
|
||||
from .. import is_torch_available
|
||||
from ..utils import logging as transformers_logging
|
||||
from ..utils.import_utils import is_pygments_available
|
||||
from .agent_types import AgentAudio, AgentImage, AgentText
|
||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||
from .llm_engine import HfEngine, MessageRole
|
||||
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
|
||||
from .python_interpreter import evaluate_python_code
|
||||
from .tools import (
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
Tool,
|
||||
get_tool_description_with_args,
|
||||
load_tool,
|
||||
)
|
||||
|
||||
|
||||
if is_pygments_available():
|
||||
from pygments import highlight
|
||||
from pygments.formatters import Terminal256Formatter
|
||||
from pygments.lexers import PythonLexer
|
||||
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
grey = "\x1b[38;20m"
|
||||
bold_yellow = "\x1b[33;1m"
|
||||
red = "\x1b[31;20m"
|
||||
green = "\x1b[32;20m"
|
||||
bold_red = "\x1b[31;1m"
|
||||
bold_white = "\x1b[37;1m"
|
||||
reset = "\x1b[0m"
|
||||
format = "%(message)s"
|
||||
|
||||
FORMATS = {
|
||||
logging.DEBUG: grey + format + reset,
|
||||
logging.INFO: format,
|
||||
logging.WARNING: bold_yellow + format + reset,
|
||||
31: reset + format + reset,
|
||||
32: green + format + reset,
|
||||
33: bold_white + format + reset,
|
||||
logging.ERROR: red + format + reset,
|
||||
logging.CRITICAL: bold_red + format + reset,
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
log_fmt = self.FORMATS.get(record.levelno)
|
||||
formatter = logging.Formatter(log_fmt)
|
||||
return formatter.format(record)
|
||||
|
||||
|
||||
logger = transformers_logging.get_logger(__name__)
|
||||
logger.propagate = False
|
||||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(CustomFormatter())
|
||||
logger.addHandler(ch)
|
||||
|
||||
|
||||
def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||
try:
|
||||
first_accolade_index = json_blob.find("{")
|
||||
last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
|
||||
json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
|
||||
json_data = json.loads(json_blob, strict=False)
|
||||
return json_data
|
||||
except json.JSONDecodeError as e:
|
||||
place = e.pos
|
||||
raise ValueError(
|
||||
f"The JSON blob you used is invalid: due to the following error: {e}. JSON blob was: {json_blob}, decoding failed at '{json_blob[place-4:place+5]}'."
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in parsing the JSON blob: {e}")
|
||||
|
||||
|
||||
def parse_code_blob(code_blob: str) -> str:
|
||||
try:
|
||||
pattern = r"```(?:py|python)?\n(.*?)```"
|
||||
match = re.search(pattern, code_blob, re.DOTALL)
|
||||
return match.group(1).strip()
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"The code blob you used is invalid: due to the following error: {e}. This means that the regex pattern {pattern} was not respected. Make sure to correct its formatting. Code blob was: {code_blob}"
|
||||
)
|
||||
|
||||
|
||||
def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
|
||||
json_blob = json_blob.replace("```json", "").replace("```", "")
|
||||
tool_call = parse_json_blob(json_blob)
|
||||
if "action" in tool_call and "action_input" in tool_call:
|
||||
return tool_call["action"], tool_call["action_input"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
|
||||
)
|
||||
|
||||
|
||||
def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
|
||||
"""
|
||||
Expects a text in the format: 'Action:', 'Action input:', 'Observation:'. 'Action input:' contains a json string with input arguments.
|
||||
"""
|
||||
try:
|
||||
if "Observation:" in text:
|
||||
text = text.split("Observation:")[0]
|
||||
if "Action:" in text:
|
||||
text = text.split("Action:")[1]
|
||||
tool_name, tool_input = text.split("Action input:")
|
||||
if "{" in tool_input:
|
||||
tool_input = parse_json_blob(tool_input)
|
||||
else:
|
||||
tool_input = tool_input.strip().replace('"', "")
|
||||
return tool_name.strip().replace('"', "").replace("\\", ""), tool_input
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error in parsing the text tool call: {e}. Be sure to provide the correct format. DO NOT repeat your previous incorrect tool call."
|
||||
)
|
||||
|
||||
|
||||
def to_text(input: Union[List[Dict[str, str]], Dict[str, str], str]) -> str:
|
||||
if isinstance(input, list):
|
||||
return "\n".join([m["content"] for m in input])
|
||||
elif isinstance(input, dict):
|
||||
return input["content"]
|
||||
else:
|
||||
return input
|
||||
|
||||
|
||||
HUGGINGFACE_DEFAULT_TOOLS = {}
|
||||
_tools_are_initialized = False
|
||||
|
||||
|
||||
class Toolbox:
|
||||
"""
|
||||
The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
|
||||
manage them.
|
||||
|
||||
Args:
|
||||
tools (`List[Tool]`):
|
||||
The list of tools to instantiate the toolbox with
|
||||
add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to add the tools available within `transformers` to the toolbox.
|
||||
"""
|
||||
|
||||
def __init__(self, tools: List[Tool], add_base_tools: bool = False):
|
||||
self._tools = {tool.name: tool for tool in tools}
|
||||
if add_base_tools:
|
||||
self.add_base_tools()
|
||||
self._load_tools_if_needed()
|
||||
|
||||
def add_base_tools(self, add_python_interpreter: bool = False):
|
||||
global _tools_are_initialized
|
||||
global HUGGINGFACE_DEFAULT_TOOLS
|
||||
if not _tools_are_initialized:
|
||||
HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger)
|
||||
_tools_are_initialized = True
|
||||
for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
|
||||
if tool.name != "python_interpreter" or add_python_interpreter:
|
||||
self.add_tool(tool)
|
||||
self._load_tools_if_needed()
|
||||
|
||||
@property
|
||||
def tools(self) -> Dict[str, Tool]:
|
||||
"""Get all tools currently in the toolbox"""
|
||||
return self._tools
|
||||
|
||||
def show_tool_descriptions(self, tool_description_template: str = None) -> str:
|
||||
"""
|
||||
Returns the description of all tools in the toolbox
|
||||
|
||||
Args:
|
||||
tool_description_template (`str`, *optional*):
|
||||
The template to use to describe the tools. If not provided, the default template will be used.
|
||||
"""
|
||||
return "\n".join(
|
||||
[get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()]
|
||||
)
|
||||
|
||||
def add_tool(self, tool: Tool):
|
||||
"""
|
||||
Adds a tool to the toolbox
|
||||
|
||||
Args:
|
||||
tool (`Tool`):
|
||||
The tool to add to the toolbox.
|
||||
"""
|
||||
if tool.name in self._tools:
|
||||
raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.")
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def remove_tool(self, tool_name: str):
|
||||
"""
|
||||
Removes a tool from the toolbox
|
||||
|
||||
Args:
|
||||
tool_name (`str`):
|
||||
The tool to remove from the toolbox.
|
||||
"""
|
||||
if tool_name not in self._tools:
|
||||
raise KeyError(
|
||||
f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
|
||||
)
|
||||
del self._tools[tool_name]
|
||||
|
||||
def update_tool(self, tool: Tool):
|
||||
"""
|
||||
Updates a tool in the toolbox according to its name.
|
||||
|
||||
Args:
|
||||
tool (`Tool`):
|
||||
The tool to update to the toolbox.
|
||||
"""
|
||||
if tool.name not in self._tools:
|
||||
raise KeyError(
|
||||
f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
|
||||
)
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def clear_toolbox(self):
|
||||
"""Clears the toolbox"""
|
||||
self._tools = {}
|
||||
|
||||
def _load_tools_if_needed(self):
|
||||
for name, tool in self._tools.items():
|
||||
if not isinstance(tool, Tool):
|
||||
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
|
||||
self._tools[name] = load_tool(task_or_repo_id)
|
||||
|
||||
def __repr__(self):
|
||||
toolbox_description = "Toolbox contents:\n"
|
||||
for tool in self._tools.values():
|
||||
toolbox_description += f"\t{tool.name}: {tool.description}\n"
|
||||
return toolbox_description
|
||||
|
||||
|
||||
def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
|
||||
tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
|
||||
prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
|
||||
if "<<tool_names>>" in prompt:
|
||||
tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
|
||||
prompt = prompt.replace("<<tool_names>>", ", ".join(tool_names))
|
||||
return prompt
|
||||
|
||||
|
||||
class AgentError(Exception):
|
||||
"""Base class for other agent-related exceptions"""
|
||||
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class AgentParsingError(AgentError):
|
||||
"""Exception raised for errors in parsing in the agent"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentExecutionError(AgentError):
|
||||
"""Exception raised for errors in execution in the agent"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentMaxIterationsError(AgentError):
|
||||
"""Exception raised for errors in execution in the agent"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentGenerationError(AgentError):
|
||||
"""Exception raised for errors in generation in the agent"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(
|
||||
self,
|
||||
tools: Union[List[Tool], Toolbox],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt=DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||
tool_description_template=None,
|
||||
additional_args={},
|
||||
max_iterations: int = 6,
|
||||
tool_parser=parse_json_tool_call,
|
||||
add_base_tools: bool = False,
|
||||
verbose: int = 0,
|
||||
memory_verbose: bool = False,
|
||||
):
|
||||
self.agent_name = self.__class__.__name__
|
||||
self.llm_engine = llm_engine
|
||||
self.system_prompt_template = system_prompt
|
||||
self.tool_description_template = (
|
||||
tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
|
||||
)
|
||||
self.additional_args = additional_args
|
||||
self.max_iterations = max_iterations
|
||||
self.logger = logger
|
||||
self.tool_parser = tool_parser
|
||||
|
||||
if isinstance(tools, Toolbox):
|
||||
self._toolbox = tools
|
||||
if add_base_tools:
|
||||
if not is_torch_available():
|
||||
raise ImportError("Using the base tools requires torch to be installed.")
|
||||
|
||||
self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == ReactJsonAgent))
|
||||
else:
|
||||
self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)
|
||||
|
||||
self.system_prompt = format_prompt_with_tools(
|
||||
self._toolbox, self.system_prompt_template, self.tool_description_template
|
||||
)
|
||||
self.prompt = None
|
||||
self.logs = []
|
||||
self.task = None
|
||||
self.memory_verbose = memory_verbose
|
||||
|
||||
if verbose == 0:
|
||||
logger.setLevel(logging.WARNING)
|
||||
elif verbose == 1:
|
||||
logger.setLevel(logging.INFO)
|
||||
elif verbose == 2:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
@property
|
||||
def toolbox(self) -> Toolbox:
|
||||
"""Get the toolbox currently available to the agent"""
|
||||
return self._toolbox
|
||||
|
||||
def initialize_for_run(self, task: str, **kwargs):
|
||||
self.task = task
|
||||
if len(kwargs) > 0:
|
||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||
self.state = kwargs.copy()
|
||||
self.system_prompt = format_prompt_with_tools(
|
||||
self._toolbox, self.system_prompt_template, self.tool_description_template
|
||||
)
|
||||
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
|
||||
self.logger.warn("======== New task ========")
|
||||
self.logger.log(33, self.task)
|
||||
self.logger.debug("System prompt is as follows:")
|
||||
self.logger.debug(self.system_prompt)
|
||||
|
||||
def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
|
||||
that can be used as input to the LLM.
|
||||
"""
|
||||
prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0]["system_prompt"]}
|
||||
task_message = {
|
||||
"role": MessageRole.USER,
|
||||
"content": "Task: " + self.logs[0]["task"],
|
||||
}
|
||||
memory = [prompt_message, task_message]
|
||||
for i, step_log in enumerate(self.logs[1:]):
|
||||
if "llm_output" in step_log:
|
||||
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"}
|
||||
memory.append(thought_message)
|
||||
|
||||
if "error" in step_log:
|
||||
message_content = (
|
||||
"Error: "
|
||||
+ str(step_log["error"])
|
||||
+ "\nNow let's retry: take care not to repeat previous errors! Try to adopt different approaches.\n"
|
||||
)
|
||||
elif "observation" in step_log:
|
||||
message_content = f"Observation: {step_log['observation']}"
|
||||
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
|
||||
memory.append(tool_response_message)
|
||||
|
||||
if len(memory) % 3 == 0:
|
||||
reminder_content = (
|
||||
"Reminder: you are working towards solving the following task: " + self.logs[0]["task"]
|
||||
)
|
||||
reminder_content += "\nHere is a summary of your past tool calls and their results:"
|
||||
for j in range(i + 1):
|
||||
reminder_content += "\nStep " + str(j + 1)
|
||||
if "tool_call" in self.logs[j]:
|
||||
reminder_content += "\nTool call:" + str(self.logs[j]["tool_call"])
|
||||
if self.memory_verbose:
|
||||
if "observation" in self.logs[j]:
|
||||
reminder_content += "\nObservation:" + str(self.logs[j]["observation"])
|
||||
if "error" in self.logs[j]:
|
||||
reminder_content += "\nError:" + str(self.logs[j]["error"])
|
||||
memory.append(
|
||||
{
|
||||
"role": MessageRole.USER,
|
||||
"content": reminder_content,
|
||||
}
|
||||
)
|
||||
return memory
|
||||
|
||||
def extract_action(self, llm_output: str, split_token: str) -> str:
|
||||
"""
|
||||
Parse action from the LLM output
|
||||
|
||||
Args:
|
||||
llm_output (`str`): Output of the LLM
|
||||
split_token (`str`): Separator for the action. Should match the example in the system prompt.
|
||||
"""
|
||||
try:
|
||||
split = llm_output.split(split_token)
|
||||
rationale, action = (
|
||||
split[-2],
|
||||
split[-1],
|
||||
) # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
|
||||
except Exception as e:
|
||||
self.logger.error(e, exc_info=1)
|
||||
raise AgentParsingError(
|
||||
f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
|
||||
)
|
||||
return rationale, action
|
||||
|
||||
def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
|
||||
"""
|
||||
Execute tool with the provided input and returns the result.
|
||||
This method replaces arguments with the actual values from the state if they refer to state variables.
|
||||
|
||||
Args:
|
||||
tool_name (`str`): Name of the Tool to execute (shoulde be one from self.toolbox).
|
||||
arguments (Dict[str, str]): Arguments passed to the Tool.
|
||||
"""
|
||||
if tool_name not in self.toolbox.tools:
|
||||
error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(self.toolbox.tools.keys())}."
|
||||
self.logger.error(error_msg, exc_info=1)
|
||||
raise AgentExecutionError(error_msg)
|
||||
|
||||
try:
|
||||
if isinstance(arguments, str):
|
||||
observation = self.toolbox.tools[tool_name](arguments)
|
||||
else:
|
||||
for key, value in arguments.items():
|
||||
# if the value is the name of a state variable like "image.png", replace it with the actual value
|
||||
if isinstance(value, str) and value in self.state:
|
||||
arguments[key] = self.state[value]
|
||||
observation = self.toolbox.tools[tool_name](**arguments)
|
||||
return observation
|
||||
except Exception as e:
|
||||
raise AgentExecutionError(
|
||||
f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
|
||||
f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(self.toolbox.tools[tool_name])}"
|
||||
)
|
||||
|
||||
def log_code_action(self, code_action: str) -> None:
|
||||
self.logger.warning("==== Agent is executing the code below:")
|
||||
if is_pygments_available():
|
||||
self.logger.log(
|
||||
31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
|
||||
)
|
||||
else:
|
||||
self.logger.log(31, code_action)
|
||||
self.logger.warning("====")
|
||||
|
||||
def run(self, **kwargs):
|
||||
"""To be implemented in the child class"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CodeAgent(Agent):
|
||||
"""
|
||||
A class for an agent that solves the given task using a single block of code. It plans all its actions, then executes all in one shot.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
llm_engine=llm_engine,
|
||||
system_prompt=system_prompt,
|
||||
tool_description_template=tool_description_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not is_pygments_available():
|
||||
transformers_logging.warning_once(
|
||||
logger,
|
||||
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
|
||||
"CodeAgent.",
|
||||
)
|
||||
|
||||
self.python_evaluator = evaluate_python_code
|
||||
|
||||
def parse_code_blob(self, result: str) -> str:
|
||||
"""
|
||||
Override this method if you want to change the way the code is
|
||||
cleaned in the `run` method.
|
||||
"""
|
||||
return parse_code_blob(result)
|
||||
|
||||
def run(self, task: str, return_generated_code: bool = False, **kwargs):
|
||||
"""
|
||||
Runs the agent for the given task.
|
||||
|
||||
Args:
|
||||
task (`str`): The task to perform
|
||||
return_generated_code (`bool`, *optional*, defaults to `False`): Whether to return the generated code instead of running it
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Any keyword argument to send to the agent when evaluating the code.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers.agents import CodeAgent, PythonInterpreterTool
|
||||
|
||||
python_interpreter = PythonInterpreterTool()
|
||||
agent = CodeAgent(tools=[python_interpreter])
|
||||
agent.run("What is the result of 2 power 3.7384?")
|
||||
```
|
||||
"""
|
||||
self.initialize_for_run(task, **kwargs)
|
||||
|
||||
# Run LLM
|
||||
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
|
||||
task_message = {
|
||||
"role": MessageRole.USER,
|
||||
"content": "Task: " + self.task,
|
||||
}
|
||||
|
||||
self.prompt = [prompt_message, task_message]
|
||||
self.logger.info("====Executing with this prompt====")
|
||||
self.logger.info(self.prompt)
|
||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_code>"])
|
||||
|
||||
if return_generated_code:
|
||||
return llm_output
|
||||
|
||||
# Parse
|
||||
_, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
||||
|
||||
try:
|
||||
code_action = self.parse_code_blob(code_action)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
|
||||
self.logger.error(error_msg, exc_info=1)
|
||||
return error_msg
|
||||
|
||||
# Execute
|
||||
self.log_code_action(code_action)
|
||||
try:
|
||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||
output = self.python_evaluator(code_action, available_tools, state=self.state)
|
||||
self.logger.info(self.state["print_outputs"])
|
||||
return output
|
||||
except Exception as e:
|
||||
error_msg = f"Error in execution: {e}. Be sure to provide correct code."
|
||||
self.logger.error(error_msg, exc_info=1)
|
||||
return error_msg
|
||||
|
||||
|
||||
class ReactAgent(Agent):
|
||||
"""
|
||||
This agent that solves the given task step by step, using the ReAct framework:
|
||||
While the objective is not reached, the agent will perform a cycle of thinking and acting.
|
||||
The action will be parsed from the LLM output: it consists in calls to tools from the toolbox, with arguments chosen by the LLM engine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
llm_engine=llm_engine,
|
||||
system_prompt=system_prompt,
|
||||
tool_description_template=tool_description_template,
|
||||
**kwargs,
|
||||
)
|
||||
if "final_answer" not in self._toolbox.tools:
|
||||
self._toolbox.add_tool(FinalAnswerTool())
|
||||
|
||||
def run(self, task: str, **kwargs):
|
||||
"""
|
||||
Runs the agent for the given task.
|
||||
|
||||
Args:
|
||||
task (`str`): The task to perform
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers.agents import ReactJsonAgent, PythonInterpreterTool
|
||||
|
||||
python_interpreter = PythonInterpreterTool()
|
||||
agent = ReactJsonAgent(tools=[python_interpreter])
|
||||
agent.run("What is the result of 2 power 3.7384?")
|
||||
```
|
||||
"""
|
||||
self.initialize_for_run(task, **kwargs)
|
||||
|
||||
final_answer = None
|
||||
iteration = 0
|
||||
while final_answer is None and iteration < self.max_iterations:
|
||||
try:
|
||||
final_answer = self.step()
|
||||
except AgentError as e:
|
||||
self.logger.error(e, exc_info=1)
|
||||
self.logs[-1]["error"] = e
|
||||
finally:
|
||||
iteration += 1
|
||||
|
||||
if final_answer is None and iteration == self.max_iterations:
|
||||
error_message = "Reached max iterations."
|
||||
self.logs.append({"error": AgentMaxIterationsError(error_message)})
|
||||
self.logger.error(error_message, exc_info=1)
|
||||
|
||||
self.prompt = [
|
||||
{
|
||||
"role": MessageRole.SYSTEM,
|
||||
"content": "An agent tried to answer a user query but it failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
|
||||
}
|
||||
]
|
||||
self.prompt += self.write_inner_memory_from_logs()[1:]
|
||||
self.prompt += [
|
||||
{
|
||||
"role": MessageRole.USER,
|
||||
"content": f"Based on the above, please provide an answer to the following user request:\n{task}",
|
||||
}
|
||||
]
|
||||
try:
|
||||
final_answer = self.llm_engine(self.prompt, stop_sequences=["Observation:"])
|
||||
except Exception as e:
|
||||
final_answer = f"Error in generating final llm output: {e}."
|
||||
|
||||
return final_answer
|
||||
|
||||
|
||||
class ReactJsonAgent(ReactAgent):
|
||||
"""
|
||||
This agent that solves the given task step by step, using the ReAct framework:
|
||||
While the objective is not reached, the agent will perform a cycle of thinking and acting.
|
||||
The tool calls will be formulated by the LLM in JSON format, then parsed and executed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
llm_engine=llm_engine,
|
||||
system_prompt=system_prompt,
|
||||
tool_description_template=tool_description_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
||||
The errors are raised here, they are caught and logged in the run() method.
|
||||
"""
|
||||
agent_memory = self.write_inner_memory_from_logs()
|
||||
|
||||
self.logs[-1]["agent_memory"] = agent_memory.copy()
|
||||
self.prompt = agent_memory
|
||||
self.logger.debug("===== New step =====")
|
||||
|
||||
# Add new step in logs
|
||||
self.logs.append({})
|
||||
self.logger.info("===== Calling LLM with this last message: =====")
|
||||
self.logger.info(self.prompt[-1])
|
||||
|
||||
try:
|
||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["Observation:"])
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||
self.logger.debug("===== Output message of the LLM: =====")
|
||||
self.logger.debug(llm_output)
|
||||
self.logs[-1]["llm_output"] = llm_output
|
||||
|
||||
# Parse
|
||||
self.logger.debug("===== Extracting action =====")
|
||||
rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")
|
||||
|
||||
try:
|
||||
tool_name, arguments = self.tool_parser(action)
|
||||
except Exception as e:
|
||||
raise AgentParsingError(f"Could not parse the given action: {e}.")
|
||||
|
||||
self.logs[-1]["rationale"] = rationale
|
||||
self.logs[-1]["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
||||
|
||||
# Execute
|
||||
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
|
||||
if tool_name == "final_answer":
|
||||
if isinstance(arguments, dict):
|
||||
answer = arguments["answer"]
|
||||
else:
|
||||
answer = arguments
|
||||
if answer in self.state: # if the answer is a state variable, return the value
|
||||
answer = self.state[answer]
|
||||
return answer
|
||||
else:
|
||||
observation = self.execute_tool_call(tool_name, arguments)
|
||||
observation_type = type(observation)
|
||||
if observation_type == AgentText:
|
||||
updated_information = str(observation).strip()
|
||||
else:
|
||||
# TODO: observation naming could allow for different names of same type
|
||||
if observation_type == AgentImage:
|
||||
observation_name = "image.png"
|
||||
elif observation_type == AgentAudio:
|
||||
observation_name = "audio.mp3"
|
||||
else:
|
||||
observation_name = "object.object"
|
||||
|
||||
self.state[observation_name] = observation
|
||||
updated_information = f"Stored '{observation_name}' in memory."
|
||||
|
||||
self.logger.info(updated_information)
|
||||
self.logs[-1]["observation"] = updated_information
|
||||
return None
|
||||
|
||||
|
||||
class ReactCodeAgent(ReactAgent):
|
||||
"""
|
||||
This agent that solves the given task step by step, using the ReAct framework:
|
||||
While the objective is not reached, the agent will perform a cycle of thinking and acting.
|
||||
The tool calls will be formulated by the LLM in code format, then parsed and executed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tools: List[Tool],
|
||||
llm_engine: Callable = HfEngine(),
|
||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tools=tools,
|
||||
llm_engine=llm_engine,
|
||||
system_prompt=system_prompt,
|
||||
tool_description_template=tool_description_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not is_pygments_available():
|
||||
transformers_logging.warning_once(
|
||||
logger,
|
||||
"pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
|
||||
"ReactCodeAgent.",
|
||||
)
|
||||
|
||||
self.python_evaluator = evaluate_python_code
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
|
||||
The errors are raised here, they are caught and logged in the run() method.
|
||||
"""
|
||||
agent_memory = self.write_inner_memory_from_logs()
|
||||
self.logs[-1]["agent_memory"] = agent_memory.copy()
|
||||
|
||||
self.prompt = agent_memory.copy()
|
||||
|
||||
self.logger.debug("===== New step =====")
|
||||
|
||||
# Add new step in logs
|
||||
self.logs.append({})
|
||||
|
||||
self.logger.info("===== Calling LLM with these last messages: =====")
|
||||
self.logger.info(self.prompt[-2:])
|
||||
|
||||
try:
|
||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_code>", "Observation:"])
|
||||
except Exception as e:
|
||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||
|
||||
self.logger.debug("===== Output message of the LLM: =====")
|
||||
self.logger.debug(llm_output)
|
||||
self.logs[-1]["llm_output"] = llm_output
|
||||
|
||||
# Parse
|
||||
self.logger.debug("===== Extracting action =====")
|
||||
rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")
|
||||
|
||||
try:
|
||||
code_action = parse_code_blob(raw_code_action)
|
||||
except Exception as e:
|
||||
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
|
||||
raise AgentParsingError(error_msg)
|
||||
|
||||
self.logs[-1]["rationale"] = rationale
|
||||
self.logs[-1]["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
||||
|
||||
# Execute
|
||||
self.log_code_action(code_action)
|
||||
try:
|
||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||
result = self.python_evaluator(code_action, available_tools, state=self.state)
|
||||
information = self.state["print_outputs"]
|
||||
self.logger.warning("Print outputs:")
|
||||
self.logger.log(32, information)
|
||||
self.logs[-1]["observation"] = information
|
||||
except Exception as e:
|
||||
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
|
||||
if "'dict' object has no attribute 'read'" in str(e):
|
||||
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
|
||||
raise AgentExecutionError(error_msg)
|
||||
for line in code_action.split("\n"):
|
||||
if line[: len("final_answer")] == "final_answer":
|
||||
self.logger.warning(">>> Final answer:")
|
||||
self.logger.log(32, result)
|
||||
return result
|
||||
return None
|
168
src/transformers/agents/default_tools.py
Normal file
168
src/transformers/agents/default_tools.py
Normal file
@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib.util
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from math import sqrt
|
||||
from typing import Dict
|
||||
|
||||
from huggingface_hub import hf_hub_download, list_spaces
|
||||
|
||||
from ..utils import is_offline_mode
|
||||
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||
from .tools import TASK_MAPPING, TOOL_CONFIG_FILE, Tool
|
||||
|
||||
|
||||
def custom_print(*args):
|
||||
return " ".join(map(str, args))
|
||||
|
||||
|
||||
BASE_PYTHON_TOOLS = {
|
||||
"print": custom_print,
|
||||
"range": range,
|
||||
"float": float,
|
||||
"int": int,
|
||||
"bool": bool,
|
||||
"str": str,
|
||||
"round": round,
|
||||
"ceil": math.ceil,
|
||||
"floor": math.floor,
|
||||
"log": math.log,
|
||||
"exp": math.exp,
|
||||
"sin": math.sin,
|
||||
"cos": math.cos,
|
||||
"tan": math.tan,
|
||||
"asin": math.asin,
|
||||
"acos": math.acos,
|
||||
"atan": math.atan,
|
||||
"atan2": math.atan2,
|
||||
"degrees": math.degrees,
|
||||
"radians": math.radians,
|
||||
"pow": math.pow,
|
||||
"sqrt": sqrt,
|
||||
"len": len,
|
||||
"sum": sum,
|
||||
"max": max,
|
||||
"min": min,
|
||||
"abs": abs,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
"tuple": tuple,
|
||||
"set": set,
|
||||
"enumerate": enumerate,
|
||||
"zip": zip,
|
||||
"reversed": reversed,
|
||||
"sorted": sorted,
|
||||
"all": all,
|
||||
"any": any,
|
||||
"map": map,
|
||||
"filter": filter,
|
||||
"ord": ord,
|
||||
"chr": chr,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTool:
|
||||
name: str
|
||||
inputs: Dict[str, str]
|
||||
output_type: type
|
||||
task: str
|
||||
description: str
|
||||
repo_id: str
|
||||
|
||||
|
||||
HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
|
||||
"image-transformation",
|
||||
"text-to-image",
|
||||
]
|
||||
|
||||
|
||||
def get_remote_tools(logger, organization="huggingface-tools"):
|
||||
if is_offline_mode():
|
||||
logger.info("You are in offline mode, so remote tools are not available.")
|
||||
return {}
|
||||
|
||||
spaces = list_spaces(author=organization)
|
||||
tools = {}
|
||||
for space_info in spaces:
|
||||
repo_id = space_info.id
|
||||
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
config = json.load(reader)
|
||||
task = repo_id.split("/")[-1]
|
||||
tools[config["name"]] = PreTool(
|
||||
task=task,
|
||||
description=config["description"],
|
||||
repo_id=repo_id,
|
||||
name=task,
|
||||
inputs=config["inputs"],
|
||||
output_type=config["output_type"],
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
def setup_default_tools(logger):
|
||||
default_tools = {}
|
||||
main_module = importlib.import_module("transformers")
|
||||
tools_module = main_module.agents
|
||||
|
||||
for task_name, tool_class_name in TASK_MAPPING.items():
|
||||
tool_class = getattr(tools_module, tool_class_name)
|
||||
default_tools[tool_class.name] = PreTool(
|
||||
name=tool_class.name,
|
||||
inputs=tool_class.inputs,
|
||||
output_type=tool_class.output_type,
|
||||
task=task_name,
|
||||
description=tool_class.description,
|
||||
repo_id=None,
|
||||
)
|
||||
|
||||
return default_tools
|
||||
|
||||
|
||||
class PythonInterpreterTool(Tool):
|
||||
name = "python_interpreter"
|
||||
description = "This is a tool that evaluates python code. It can be used to perform calculations."
|
||||
|
||||
inputs = {
|
||||
"code": {
|
||||
"type": "text",
|
||||
"description": (
|
||||
"The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
|
||||
f"else you will get an error. This code can only import the following python libraries: {LIST_SAFE_MODULES}."
|
||||
),
|
||||
}
|
||||
}
|
||||
output_type = "text"
|
||||
available_tools = BASE_PYTHON_TOOLS.copy()
|
||||
|
||||
def forward(self, code):
|
||||
output = str(evaluate_python_code(code, tools=self.available_tools))
|
||||
return output
|
||||
|
||||
|
||||
class FinalAnswerTool(Tool):
|
||||
name = "final_answer"
|
||||
description = "Provides a final answer to the given problem"
|
||||
inputs = {"answer": {"type": "text", "description": "The final answer to the problem"}}
|
||||
output_type = "any"
|
||||
|
||||
def forward(self, answer):
|
||||
return answer
|
@ -16,10 +16,13 @@
|
||||
# limitations under the License.
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..models.auto import AutoProcessor
|
||||
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
|
||||
from ..utils import is_vision_available
|
||||
from .base import PipelineTool
|
||||
from .tools import PipelineTool
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
@ -28,17 +31,19 @@ if is_vision_available():
|
||||
|
||||
class DocumentQuestionAnsweringTool(PipelineTool):
|
||||
default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
|
||||
description = (
|
||||
"This is a tool that answers a question about an document (pdf). It takes an input named `document` which "
|
||||
"should be the document containing the information, as well as a `question` that is the question about the "
|
||||
"document. It returns a text that contains the answer to the question."
|
||||
)
|
||||
description = "This is a tool that answers a question about an document (pdf). It returns a text that contains the answer to the question."
|
||||
name = "document_qa"
|
||||
pre_processor_class = AutoProcessor
|
||||
model_class = VisionEncoderDecoderModel
|
||||
|
||||
inputs = ["image", "text"]
|
||||
outputs = ["text"]
|
||||
inputs = {
|
||||
"document": {
|
||||
"type": "image",
|
||||
"description": "The image containing the information. Can be a PIL Image or a string path to the image.",
|
||||
},
|
||||
"question": {"type": "text", "description": "The question in English"},
|
||||
}
|
||||
output_type = "text"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if not is_vision_available():
|
||||
@ -52,6 +57,10 @@ class DocumentQuestionAnsweringTool(PipelineTool):
|
||||
decoder_input_ids = self.pre_processor.tokenizer(
|
||||
prompt, add_special_tokens=False, return_tensors="pt"
|
||||
).input_ids
|
||||
if isinstance(document, str):
|
||||
img = Image.open(document).convert("RGB")
|
||||
img_array = np.array(img).transpose(2, 0, 1)
|
||||
document = torch.tensor(img_array)
|
||||
pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
|
||||
|
||||
return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
|
@ -14,7 +14,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat
|
||||
from .agents import BASE_PYTHON_TOOLS
|
||||
from .python_interpreter import InterpretorError, evaluate
|
||||
|
||||
|
||||
@ -221,182 +221,6 @@ EVALUATION_TASKS = [
|
||||
]
|
||||
|
||||
|
||||
EVALUATION_CHATS = [
|
||||
[
|
||||
Problem(
|
||||
task=[
|
||||
"Translate the following `text` from Spanish to English.",
|
||||
"Translate the following `text` from Spanish to English.",
|
||||
],
|
||||
inputs=["text"],
|
||||
answer="translated_text=translator(text, src_lang='Spanish', tgt_lang='English')",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Is it positive or negative?",
|
||||
"Tell me if its positive or negative.",
|
||||
],
|
||||
inputs=[],
|
||||
answer="text_classifier(translated_text, labels=['positive', 'negative'])",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=[
|
||||
"What does this `image` contain?",
|
||||
"Describe the following `image`.",
|
||||
"Find what is in the picture stored in `image`",
|
||||
],
|
||||
inputs=["image"],
|
||||
answer=[
|
||||
"description=image_captioner(image)",
|
||||
"description=image_qa(image, question='What is in the image?')",
|
||||
],
|
||||
),
|
||||
Problem(
|
||||
task=["Now, read the description out loud.", "Great! Can you read it out loud?", "Read it out loud."],
|
||||
inputs=[],
|
||||
answer=["audio=text_reader(description)", "audio=text_reader(description)"],
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=[
|
||||
"Generate an image from the text given in `text_input`.",
|
||||
"Use the following `text_input` to generate an image",
|
||||
],
|
||||
inputs=["text_input"],
|
||||
answer="image = image_generator(text_input)",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Transform it according to the text in `prompt`.",
|
||||
"Transform it by using the text in `prompt`.",
|
||||
],
|
||||
inputs=["prompt"],
|
||||
answer="image_transformer(image, prompt)",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=[
|
||||
"Download the content of `url` and summarize it.",
|
||||
"Summarize the content of the web page at `url`.",
|
||||
],
|
||||
inputs=["url"],
|
||||
answer="summary = summarizer(text_downloader(url))",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Generate an image from its content.",
|
||||
"Use the previous result to generate an image.",
|
||||
],
|
||||
inputs=[],
|
||||
answer="image_generator(summary)",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=[
|
||||
"Translate this Spanish `text` in English.",
|
||||
"Translate the `text` from Spanish to English.",
|
||||
],
|
||||
inputs=["text"],
|
||||
answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
|
||||
),
|
||||
Problem(
|
||||
task=[
|
||||
"Transform the following `image` using the translated `text`.",
|
||||
"Use the previous result to transform the following `image`.",
|
||||
],
|
||||
inputs=["image"],
|
||||
answer="image_transformer(image, translated_text)",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=["Download the content of `url`.", "Get me the text on the weg page `url`."],
|
||||
inputs=["url"],
|
||||
answer="text = text_downloader(url)",
|
||||
),
|
||||
Problem(
|
||||
task=["Summarize this text.", "Summarize this text."],
|
||||
inputs=[],
|
||||
answer="summary = summarizer(text)",
|
||||
),
|
||||
Problem(
|
||||
task=["Read it out loud to me.", "Read me the previous result."],
|
||||
inputs=[],
|
||||
answer="text_reader(summary)",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=[
|
||||
"Generate an image from the text given in `text_input`.",
|
||||
],
|
||||
inputs=["text_input"],
|
||||
answer="image_generator(text_input)",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=[
|
||||
"Replace the beaver in the `image` by the `prompt`.",
|
||||
"Transform the `image` so that it contains the `prompt`.",
|
||||
"Use `prompt` to transform this `image`.",
|
||||
],
|
||||
inputs=["image", "prompt"],
|
||||
answer="image_transformer(image, prompt)",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=["Provide me the summary of the `text`.", "Summarize `text`."],
|
||||
inputs=["text"],
|
||||
answer="summary = summarizer(text)",
|
||||
),
|
||||
Problem(
|
||||
task=["Read this summary to me.", "Read it out loud."],
|
||||
inputs=[],
|
||||
answer="audio = text_reader(summarizer(text))",
|
||||
),
|
||||
Problem(
|
||||
task=["Transcribing the previous result back in text.", "Transcribe the audio."],
|
||||
inputs=[],
|
||||
answer="text = transcriber(audio)",
|
||||
),
|
||||
Problem(
|
||||
task=["Translating the last result in French.", "Translate this in French."],
|
||||
inputs=[],
|
||||
answer="translator(text, src_lang='English', tgt_lang='French')",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
|
||||
inputs={"prompt": "A lobster swimming"},
|
||||
answer="video_generator('A lobster swimming')",
|
||||
),
|
||||
],
|
||||
[
|
||||
Problem(
|
||||
task=[
|
||||
"Download the content of `url` and summarize it.",
|
||||
"Summarize the content of the web page at `url`.",
|
||||
],
|
||||
inputs=["url"],
|
||||
answer="summary = summarizer(text_downloader(url))",
|
||||
),
|
||||
Problem(
|
||||
task=["generate a video from it.", "Create an animation from the last result."],
|
||||
inputs=[],
|
||||
answer="video_generator(summary)",
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
|
||||
if not isinstance(theoretical_answer, list):
|
||||
return {name for name in TEST_TOOLS if name in code_answer}
|
||||
@ -459,19 +283,19 @@ def score_code(agent_answer, theoretical_answer, verbose: bool = False):
|
||||
return 0.3
|
||||
|
||||
|
||||
def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, answer, verbose=False):
|
||||
tools_in_explanation = {name for name in TEST_TOOLS if f"`{name}`" in explanation}
|
||||
def evaluate_one_result(code, agent_answer, theoretical_answer, answer, verbose=False):
|
||||
tools_in_code = {name for name in TEST_TOOLS if f"`{name}`" in code}
|
||||
theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
|
||||
if tools_in_explanation == theoretical_tools:
|
||||
if tools_in_code == theoretical_tools:
|
||||
tool_selection_score = 1.0
|
||||
tool_selection_errors = None
|
||||
else:
|
||||
missing_tools = len(theoretical_tools - tools_in_explanation)
|
||||
unexpected_tools = len(tools_in_explanation - theoretical_tools)
|
||||
missing_tools = len(theoretical_tools - tools_in_code)
|
||||
unexpected_tools = len(tools_in_code - theoretical_tools)
|
||||
tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
|
||||
|
||||
tool_selection_errors = {
|
||||
"selected_tools": tools_in_explanation,
|
||||
"selected_tools": tools_in_code,
|
||||
"theoretical_tools": theoretical_tools,
|
||||
}
|
||||
|
||||
@ -485,7 +309,7 @@ def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, ans
|
||||
tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
|
||||
|
||||
tool_used_errors = {
|
||||
"selected_tools": tools_in_explanation,
|
||||
"selected_tools": tools_in_code,
|
||||
"theoretical_tools": theoretical_tools,
|
||||
}
|
||||
|
||||
@ -547,14 +371,13 @@ def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
|
||||
end_idx = min(start_idx + batch_size, len(eval_tasks))
|
||||
batch_tasks = eval_tasks[start_idx:end_idx]
|
||||
|
||||
prompts = [agent.format_prompt(task) for task in batch_tasks]
|
||||
results = agent.generate_many(prompts, stop=["Task:"])
|
||||
results = [agent.run(task, return_generated_code=True) for task in batch_tasks]
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
|
||||
if verbose:
|
||||
print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
|
||||
explanation, code = agent.clean_code_for_run(result)
|
||||
code = agent.extract_action(result, split_token="Answer:")
|
||||
|
||||
# Evaluate agent answer and code answer
|
||||
agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
|
||||
@ -564,7 +387,7 @@ def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
|
||||
theoretical_answer = evaluate_code(problem.answer, problem.inputs)
|
||||
|
||||
scores, errors = evaluate_one_result(
|
||||
explanation, code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
|
||||
code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
|
||||
)
|
||||
|
||||
tool_selection_score += scores[0]
|
||||
@ -589,104 +412,3 @@ def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
|
||||
return scores, tool_selection_errors, tool_used_errors, code_errors
|
||||
else:
|
||||
return scores
|
||||
|
||||
|
||||
def evaluate_chat_agent(agent, verbose=False, return_errors=False):
|
||||
"""
|
||||
Evaluates a new agent on all `EVALUATION_CHATS`.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
agent = NewOpenAiAgent(model="text-davinci-003", api_key=your_api_key)
|
||||
bads = new_evaluate_agent(agent)
|
||||
for bad in bads:
|
||||
print(bad)
|
||||
```
|
||||
"""
|
||||
# Sanity check
|
||||
agent_tools = set(agent.toolbox.keys())
|
||||
if agent_tools != set(TEST_TOOLS):
|
||||
missing_tools = set(TEST_TOOLS) - agent_tools
|
||||
unexpected_tools = agent_tools - set(TEST_TOOLS)
|
||||
raise ValueError(
|
||||
f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
|
||||
)
|
||||
|
||||
tool_selection_score = 0
|
||||
tool_used_score = 0
|
||||
code_score = 0
|
||||
total_steps = 0
|
||||
|
||||
if return_errors:
|
||||
tool_selection_errors = {}
|
||||
tool_used_errors = {}
|
||||
code_errors = {}
|
||||
|
||||
for chat_problem in EVALUATION_CHATS:
|
||||
if isinstance(chat_problem[0].task, str):
|
||||
resolved_problems = [chat_problem]
|
||||
else:
|
||||
resolved_problems = [
|
||||
[Problem(task=pb.task[i], inputs=pb.inputs, answer=pb.answer) for pb in chat_problem]
|
||||
for i in range(len(chat_problem[0].task))
|
||||
]
|
||||
for problem in resolved_problems:
|
||||
agent.prepare_for_new_chat()
|
||||
agent_state = {}
|
||||
theoretical_state = (
|
||||
[{} for _ in range(len(problem[0].answer))] if isinstance(problem[0].answer, list) else {}
|
||||
)
|
||||
|
||||
for step, step_problem in enumerate(problem):
|
||||
if verbose:
|
||||
print(step_problem.task)
|
||||
total_steps += 1
|
||||
prompt = agent.format_prompt(step_problem.task, chat_mode=True)
|
||||
result = agent.generate_one(prompt, stop=["Human:", "====="])
|
||||
agent.chat_history = prompt + result + "\n"
|
||||
|
||||
explanation, code = clean_code_for_chat(result)
|
||||
|
||||
if verbose:
|
||||
print(f"==Explanation from the agent==\n{explanation}")
|
||||
print(f"\n==Code generated by the agent==\n{code}")
|
||||
|
||||
# Evaluate agent answer and code answer
|
||||
agent_answer = evaluate_code(code, step_problem.inputs, state=agent_state, verbose=verbose)
|
||||
|
||||
answer = step_problem.answer
|
||||
if isinstance(answer, list):
|
||||
theoretical_answer = [
|
||||
evaluate_code(a, step_problem.inputs, state=state)
|
||||
for a, state in zip(answer, theoretical_state)
|
||||
]
|
||||
else:
|
||||
theoretical_answer = evaluate_code(answer, step_problem.inputs, state=theoretical_state)
|
||||
|
||||
scores, errors = evaluate_one_result(
|
||||
explanation, code, agent_answer, theoretical_answer, answer, verbose=verbose
|
||||
)
|
||||
|
||||
tool_selection_score += scores[0]
|
||||
tool_used_score += scores[1]
|
||||
code_score += scores[2]
|
||||
|
||||
if return_errors:
|
||||
if errors[0] is not None:
|
||||
tool_selection_errors[step_problem.task] = errors[0]
|
||||
if errors[1] is not None:
|
||||
tool_used_errors[step_problem.task] = errors[1]
|
||||
if errors[2] is not None:
|
||||
code_errors[step_problem.task] = errors[2]
|
||||
|
||||
scores = {
|
||||
"tool selection score": 100 * (tool_selection_score / total_steps),
|
||||
"tool used score": 100 * (tool_used_score / total_steps),
|
||||
"code score": 100 * (code_score / total_steps),
|
||||
}
|
||||
|
||||
if return_errors:
|
||||
return scores, tool_selection_errors, tool_used_errors, code_errors
|
||||
else:
|
||||
return scores
|
@ -14,32 +14,33 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
|
||||
from ..utils import requires_backends
|
||||
from .base import PipelineTool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL import Image
|
||||
from .tools import PipelineTool
|
||||
|
||||
|
||||
class ImageQuestionAnsweringTool(PipelineTool):
|
||||
default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
|
||||
description = (
|
||||
"This is a tool that answers a question about an image. It takes an input named `image` which should be the "
|
||||
"image containing the information, as well as a `question` which should be the question in English. It "
|
||||
"This is a tool that answers a question about an image. It "
|
||||
"returns a text that is the answer to the question."
|
||||
)
|
||||
name = "image_qa"
|
||||
pre_processor_class = AutoProcessor
|
||||
model_class = AutoModelForVisualQuestionAnswering
|
||||
|
||||
inputs = ["image", "text"]
|
||||
outputs = ["text"]
|
||||
inputs = {
|
||||
"image": {
|
||||
"type": "image",
|
||||
"description": "The image containing the information. Can be a PIL Image or a string path to the image.",
|
||||
},
|
||||
"question": {"type": "text", "description": "The question in English"},
|
||||
}
|
||||
output_type = "text"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
92
src/transformers/agents/llm_engine.py
Normal file
92
src/transformers/agents/llm_engine.py
Normal file
@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Dict, List
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
TOOL_CALL = "tool-call"
|
||||
TOOL_RESPONSE = "tool-response"
|
||||
|
||||
@classmethod
|
||||
def roles(cls):
|
||||
return [r.value for r in cls]
|
||||
|
||||
|
||||
def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
|
||||
"""
|
||||
Subsequent messages with the same role will be concatenated to a single message.
|
||||
|
||||
Args:
|
||||
message_list (`List[Dict[str, str]]`): List of chat messages.
|
||||
"""
|
||||
final_message_list = []
|
||||
message_list = deepcopy(message_list) # Avoid modifying the original list
|
||||
for message in message_list:
|
||||
if not set(message.keys()) == {"role", "content"}:
|
||||
raise ValueError("Message should contain only 'role' and 'content' keys!")
|
||||
|
||||
role = message["role"]
|
||||
if role not in MessageRole.roles():
|
||||
raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
|
||||
|
||||
if role in role_conversions:
|
||||
message["role"] = role_conversions[role]
|
||||
|
||||
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
|
||||
final_message_list[-1]["content"] += "\n===\n" + message["content"]
|
||||
else:
|
||||
final_message_list.append(message)
|
||||
return final_message_list
|
||||
|
||||
|
||||
llama_role_conversions = {
|
||||
MessageRole.SYSTEM: MessageRole.USER,
|
||||
MessageRole.TOOL_RESPONSE: MessageRole.USER,
|
||||
}
|
||||
|
||||
|
||||
class HfEngine:
|
||||
def __init__(self, model: str = "meta-llama/Meta-Llama-3-8B-Instruct"):
|
||||
self.model = model
|
||||
self.client = InferenceClient(model=self.model, timeout=120)
|
||||
|
||||
def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
|
||||
if "Meta-Llama-3" in self.model:
|
||||
if "<|eot_id|>" not in stop_sequences:
|
||||
stop_sequences.append("<|eot_id|>")
|
||||
if "!!!!!" not in stop_sequences:
|
||||
stop_sequences.append("!!!!!")
|
||||
|
||||
# Get clean message list
|
||||
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
||||
|
||||
# Get answer
|
||||
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
|
||||
response = response.choices[0].message.content
|
||||
|
||||
# Remove stop sequences from the answer
|
||||
for stop_seq in stop_sequences:
|
||||
if response[-len(stop_seq) :] == stop_seq:
|
||||
response = response[: -len(stop_seq)]
|
||||
return response
|
364
src/transformers/agents/prompts.py
Normal file
364
src/transformers/agents/prompts.py
Normal file
@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
|
||||
from ..utils import cached_file
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
CHAT_MESSAGE_PROMPT = """
|
||||
Human: <<task>>
|
||||
|
||||
Assistant: """
|
||||
|
||||
|
||||
DEFAULT_PROMPTS_REPO = "huggingface-tools/default-prompts"
|
||||
PROMPT_FILES = {"chat": "chat_prompt_template.txt", "run": "run_prompt_template.txt"}
|
||||
|
||||
|
||||
def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
|
||||
"""
|
||||
Downloads and caches the prompt from a repo and returns it contents (if necessary).
|
||||
"""
|
||||
if prompt_or_repo_id is None:
|
||||
prompt_or_repo_id = DEFAULT_PROMPTS_REPO
|
||||
|
||||
# prompt is considered a repo ID when it does not contain any kind of space
|
||||
if re.search("\\s", prompt_or_repo_id) is not None:
|
||||
return prompt_or_repo_id
|
||||
|
||||
prompt_file = cached_file(
|
||||
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
|
||||
)
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
DEFAULT_CODE_SYSTEM_PROMPT = """You will be given a task to solve, your job is to come up with a series of simple commands in Python that will perform the task.
|
||||
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
|
||||
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
|
||||
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
|
||||
Be sure to provide a 'Code:' token, else the system will be stuck in a loop.
|
||||
|
||||
Tools:
|
||||
<<tool_descriptions>>
|
||||
|
||||
Examples:
|
||||
---
|
||||
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
|
||||
|
||||
I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
|
||||
Code:
|
||||
```py
|
||||
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
|
||||
print(f"The translated question is {translated_question}.")
|
||||
answer = image_qa(image=image, question=translated_question)
|
||||
print(f"The answer is {answer}")
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Task: "Identify the oldest person in the `document` and create an image showcasing the result."
|
||||
|
||||
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||
Code:
|
||||
```py
|
||||
answer = document_qa(document, question="What is the oldest person?")
|
||||
print(f"The answer is {answer}.")
|
||||
image = image_generator(answer)
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Task: "Generate an image using the text given in the variable `caption`."
|
||||
|
||||
I will use the following tool: `image_generator` to generate an image.
|
||||
Code:
|
||||
```py
|
||||
image = image_generator(prompt=caption)
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Task: "Summarize the text given in the variable `text` and read it out loud."
|
||||
|
||||
I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
|
||||
Code:
|
||||
```py
|
||||
summarized_text = summarizer(text)
|
||||
print(f"Summary: {summarized_text}")
|
||||
audio_summary = text_reader(summarized_text)
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
|
||||
|
||||
I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
|
||||
Code:
|
||||
```py
|
||||
answer = text_qa(text=text, question=question)
|
||||
print(f"The answer is {answer}.")
|
||||
image = image_generator(answer)
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Task: "Caption the following `image`."
|
||||
|
||||
I will use the following tool: `image_captioner` to generate a caption for the image.
|
||||
Code:
|
||||
```py
|
||||
caption = image_captioner(image)
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Above example were using tools that might not exist for you. You only have acces to those Tools:
|
||||
<<tool_names>>
|
||||
|
||||
Remember to make sure that variables you use are all defined.
|
||||
Be sure to provide a 'Code:\n```' sequence before the code and '```<end_code>' after, else you will get an error.
|
||||
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
||||
|
||||
Now Begin!
|
||||
"""
|
||||
|
||||
|
||||
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You will be given a task to solve as best you can. You have access to the following tools:
|
||||
<<tool_descriptions>>
|
||||
|
||||
The way you use the tools is by specifying a json blob.
|
||||
Specifically, this json should have a `action` key (name of the tool to use) and a `action_input` key (input to the tool).
|
||||
|
||||
The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB:
|
||||
Action:
|
||||
{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}
|
||||
|
||||
Make sure to have the $INPUT as a dictionnary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
|
||||
|
||||
You will be given:
|
||||
|
||||
Task: the task you are given.
|
||||
|
||||
You should ALWAYS use the following format:
|
||||
|
||||
Thought: you should always think about one action to take. Then use the action as follows:
|
||||
Action:
|
||||
$ACTION_JSON_BLOB
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The $ACTION_JSON_BLOB must only use a SINGLE action at a time.)
|
||||
|
||||
You can use the result of the previous action as input for the next action.
|
||||
The observation will always be a string: it can represent a file, like "image_1.jpg".
|
||||
Then you can use it as input for the next action. You can do it for instance as follows:
|
||||
|
||||
Observation: "image_1.jpg"
|
||||
|
||||
Thought: I need to transform the image that I received in the previous observation to make it green.
|
||||
Action:
|
||||
{
|
||||
"action": "image_transformer",
|
||||
"action_input": {"image": "image_1.jpg"}
|
||||
}
|
||||
|
||||
To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "insert your final answer here"}
|
||||
}
|
||||
|
||||
|
||||
Here are a few examples using notional tools:
|
||||
---
|
||||
Task: "Generate an image of the oldest person in this document."
|
||||
|
||||
Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||
Action:
|
||||
{
|
||||
"action": "document_qa",
|
||||
"action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
|
||||
}
|
||||
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
||||
|
||||
|
||||
Thought: I will now generate an image showcasing the oldest person.
|
||||
Action:
|
||||
{
|
||||
"action": "image_generator",
|
||||
"action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""}
|
||||
}
|
||||
Observation: "image.png"
|
||||
|
||||
Thought: I will now return the generated image.
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": "image.png"
|
||||
}
|
||||
|
||||
---
|
||||
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||
|
||||
Thought: I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool
|
||||
Action:
|
||||
{
|
||||
"action": "python_interpreter",
|
||||
"action_input": {"code": "5 + 3 + 1294.678"}
|
||||
}
|
||||
Observation: 1302.678
|
||||
|
||||
Thought: Now that I know the result, I will now return it.
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": "1302.678"
|
||||
}
|
||||
|
||||
---
|
||||
Task: "Which city has the highest population , Guangzhou or Shanghai?"
|
||||
|
||||
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
|
||||
Action:
|
||||
{
|
||||
"action": "search",
|
||||
"action_input": "Population Guangzhou"
|
||||
}
|
||||
Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
||||
|
||||
|
||||
Thought: Now let's get the population of Shanghai using the tool 'search'.
|
||||
Action:
|
||||
{
|
||||
"action": "search",
|
||||
"action_input": "Population Shanghai"
|
||||
}
|
||||
Observation: '26 million (2019)'
|
||||
|
||||
Thought: Now I know that Shanghai has a larger population. Let's return the result.
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": "Shanghai"
|
||||
}
|
||||
|
||||
|
||||
Above example were using notional tools that might not exist for you. You only have acces to those tools:
|
||||
<<tool_names>>
|
||||
ALWAYS provide a 'Thought:' and an 'Action:' sequence. You MUST provide at least the 'Action:' sequence to move forward.
|
||||
|
||||
Now begin!
|
||||
"""
|
||||
|
||||
|
||||
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You will be given a task to solve as best you can.
|
||||
You have access to the following tools:
|
||||
<<tool_descriptions>>
|
||||
|
||||
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
|
||||
|
||||
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task, then the tools that you want to use.
|
||||
Then in the 'Code:' sequence, you shold write the code in simple Python. The code sequence must end with '/End code' sequence.
|
||||
During each intermediate step, you can use 'print()' to save whatever important information you will then need.
|
||||
These print outputs will then be available in the 'Observation:' field, for using this information as input for the next step.
|
||||
|
||||
In the end you have to return a final answer using the `final_answer` tool.
|
||||
|
||||
Here are a few examples using notional tools:
|
||||
---
|
||||
Task: "Generate an image of the oldest person in this document."
|
||||
|
||||
Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
||||
Code:
|
||||
```py
|
||||
answer = document_qa(document=document, question="Who is the oldest person mentioned?")
|
||||
print(answer)
|
||||
```<end_code>
|
||||
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
||||
|
||||
Thought: I will now generate an image showcasing the oldest person.
|
||||
|
||||
Code:
|
||||
```py
|
||||
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
|
||||
final_answer(image)
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||
|
||||
Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool
|
||||
|
||||
Code:
|
||||
```py
|
||||
result = 5 + 3 + 1294.678
|
||||
final_answer(result)
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Task: "Which city has the highest population , Guangzhou or Shanghai?"
|
||||
|
||||
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
|
||||
Code:
|
||||
```py
|
||||
population_guangzhou = search("Guangzhou population")
|
||||
print("Population Guangzhou:", population_guangzhou)
|
||||
population_shanghai = search("Shanghai population")
|
||||
print("Population Shanghai:", population_shanghai)
|
||||
```<end_code>
|
||||
Observation:
|
||||
Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
||||
Population Shanghai: '26 million (2019)'
|
||||
|
||||
Thought: Now I know that Shanghai has the highest population.
|
||||
Code:
|
||||
```py
|
||||
final_answer("Shanghai")
|
||||
```<end_code>
|
||||
|
||||
---
|
||||
Task: "What is the current age of the pope, raised to the power 0.36?"
|
||||
|
||||
Thought: I will use the tool `search` to get the age of the pope, then raise it to the power 0.36.
|
||||
Code:
|
||||
```py
|
||||
pope_age = search(query="current pope age")
|
||||
print("Pope age:", pope_age)
|
||||
```<end_code>
|
||||
Observation:
|
||||
Pope age: "The pope Francis is currently 85 years old."
|
||||
|
||||
Thought: I know that the pope is 85 years old. Let's compute the result using python code.
|
||||
Code:
|
||||
```py
|
||||
pope_current_age = 85 ** 0.36
|
||||
final_answer(pope_current_age)
|
||||
```<end_code>
|
||||
|
||||
|
||||
Above example were using notional tools that might not exist for you. You only have acces to those tools:
|
||||
<<tool_names>>
|
||||
You also can perform computations in the python code you generate.
|
||||
|
||||
Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```<end_code>' sequence. You MUST provide at least the 'Code:' sequence to move forward.
|
||||
|
||||
Remember to not perform too many operations in a single code block! You should split the task into intermediate code blocks.
|
||||
Print results at the end of each step to save the intermediate results. Then use final_answer() to return the final result.
|
||||
|
||||
Remember to make sure that variables you use are all defined.
|
||||
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
||||
|
||||
Now Begin!
|
||||
"""
|
520
src/transformers/agents/python_interpreter.py
Normal file
520
src/transformers/agents/python_interpreter.py
Normal file
@ -0,0 +1,520 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import ast
|
||||
import difflib
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
|
||||
class InterpretorError(ValueError):
|
||||
"""
|
||||
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
|
||||
operations.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
LIST_SAFE_MODULES = ["random", "math", "time", "queue", "itertools", "re", "stat", "statistics", "unicodedata"]
|
||||
|
||||
|
||||
class BreakException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ContinueException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_iterable(obj):
|
||||
if isinstance(obj, list):
|
||||
return obj
|
||||
elif hasattr(obj, "__iter__"):
|
||||
return list(obj)
|
||||
else:
|
||||
raise InterpretorError("Object is not iterable")
|
||||
|
||||
|
||||
def evaluate_unaryop(expression, state, tools):
|
||||
operand = evaluate_ast(expression.operand, state, tools)
|
||||
if isinstance(expression.op, ast.USub):
|
||||
return -operand
|
||||
elif isinstance(expression.op, ast.UAdd):
|
||||
return operand
|
||||
elif isinstance(expression.op, ast.Not):
|
||||
return not operand
|
||||
elif isinstance(expression.op, ast.Invert):
|
||||
return ~operand
|
||||
else:
|
||||
raise InterpretorError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
|
||||
|
||||
|
||||
def evaluate_lambda(lambda_expression, state, tools):
|
||||
args = [arg.arg for arg in lambda_expression.args.args]
|
||||
|
||||
def lambda_func(*values):
|
||||
new_state = state.copy()
|
||||
for arg, value in zip(args, values):
|
||||
new_state[arg] = value
|
||||
return evaluate_ast(lambda_expression.body, new_state, tools)
|
||||
|
||||
return lambda_func
|
||||
|
||||
|
||||
def evaluate_while(while_loop, state, tools):
|
||||
max_iterations = 1000
|
||||
iterations = 0
|
||||
while evaluate_ast(while_loop.test, state, tools):
|
||||
for node in while_loop.body:
|
||||
evaluate_ast(node, state, tools)
|
||||
iterations += 1
|
||||
if iterations > max_iterations:
|
||||
raise InterpretorError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
|
||||
return None
|
||||
|
||||
|
||||
def evaluate_function_def(function_def, state, tools):
|
||||
def create_function(func_def, state, tools):
|
||||
def new_func(*args):
|
||||
new_state = state.copy()
|
||||
for arg, val in zip(func_def.args.args, args):
|
||||
new_state[arg.arg] = val
|
||||
result = None
|
||||
for node in func_def.body:
|
||||
result = evaluate_ast(node, new_state, tools)
|
||||
return result
|
||||
|
||||
return new_func
|
||||
|
||||
tools[function_def.name] = create_function(function_def, state, tools)
|
||||
return None
|
||||
|
||||
|
||||
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
|
||||
# Extract the target variable name and the operation
|
||||
if isinstance(expression.target, ast.Name):
|
||||
var_name = expression.target.id
|
||||
current_value = state.get(var_name, 0) # Assuming default of 0 if not in state
|
||||
value_to_add = evaluate_ast(expression.value, state, tools)
|
||||
|
||||
# Determine the operation and apply it
|
||||
if isinstance(expression.op, ast.Add):
|
||||
updated_value = current_value + value_to_add
|
||||
elif isinstance(expression.op, ast.Sub):
|
||||
updated_value = current_value - value_to_add
|
||||
elif isinstance(expression.op, ast.Mult):
|
||||
updated_value = current_value * value_to_add
|
||||
elif isinstance(expression.op, ast.Div):
|
||||
updated_value = current_value / value_to_add
|
||||
# Add other operations as needed
|
||||
|
||||
# Update the state
|
||||
state[var_name] = updated_value
|
||||
return updated_value
|
||||
else:
|
||||
raise InterpretorError("AugAssign not supported for non-simple variable targets.")
|
||||
|
||||
|
||||
def evaluate_boolop(boolop, state, tools):
|
||||
values = [evaluate_ast(val, state, tools) for val in boolop.values]
|
||||
op = boolop.op
|
||||
if isinstance(op, ast.And):
|
||||
return all(values)
|
||||
elif isinstance(op, ast.Or):
|
||||
return any(values)
|
||||
|
||||
|
||||
def evaluate_binop(binop, state, tools):
|
||||
# Recursively evaluate the left and right operands
|
||||
left_val = evaluate_ast(binop.left, state, tools)
|
||||
right_val = evaluate_ast(binop.right, state, tools)
|
||||
|
||||
# Determine the operation based on the type of the operator in the BinOp
|
||||
if isinstance(binop.op, ast.Add):
|
||||
return left_val + right_val
|
||||
elif isinstance(binop.op, ast.Sub):
|
||||
return left_val - right_val
|
||||
elif isinstance(binop.op, ast.Mult):
|
||||
return left_val * right_val
|
||||
elif isinstance(binop.op, ast.Div):
|
||||
return left_val / right_val
|
||||
elif isinstance(binop.op, ast.Mod):
|
||||
return left_val % right_val
|
||||
elif isinstance(binop.op, ast.Pow):
|
||||
return left_val**right_val
|
||||
elif isinstance(binop.op, ast.FloorDiv):
|
||||
return left_val // right_val
|
||||
elif isinstance(binop.op, ast.BitAnd):
|
||||
return left_val & right_val
|
||||
elif isinstance(binop.op, ast.BitOr):
|
||||
return left_val | right_val
|
||||
elif isinstance(binop.op, ast.BitXor):
|
||||
return left_val ^ right_val
|
||||
elif isinstance(binop.op, ast.LShift):
|
||||
return left_val << right_val
|
||||
elif isinstance(binop.op, ast.RShift):
|
||||
return left_val >> right_val
|
||||
else:
|
||||
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
|
||||
|
||||
|
||||
def evaluate_assign(assign, state, tools):
|
||||
var_names = assign.targets
|
||||
result = evaluate_ast(assign.value, state, tools)
|
||||
if len(var_names) == 1:
|
||||
if isinstance(var_names[0], ast.Tuple):
|
||||
for i, elem in enumerate(var_names[0].elts):
|
||||
state[elem.id] = result[i]
|
||||
else:
|
||||
state[var_names[0].id] = result
|
||||
else:
|
||||
if len(result) != len(var_names):
|
||||
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
|
||||
for var_name, r in zip(var_names, result):
|
||||
state[var_name.id] = r
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_call(call, state, tools):
|
||||
if isinstance(call.func, ast.Attribute):
|
||||
obj = evaluate_ast(call.func.value, state, tools)
|
||||
func_name = call.func.attr
|
||||
if not hasattr(obj, func_name):
|
||||
raise InterpretorError(f"Object {obj} has no attribute {func_name}")
|
||||
func = getattr(obj, func_name)
|
||||
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
||||
return func(*args, **kwargs)
|
||||
|
||||
elif isinstance(call.func, ast.Name):
|
||||
func_name = call.func.id
|
||||
|
||||
if func_name in state:
|
||||
func = state[func_name]
|
||||
elif func_name in tools:
|
||||
func = tools[func_name]
|
||||
else:
|
||||
raise InterpretorError(
|
||||
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
|
||||
)
|
||||
# Todo deal with args
|
||||
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
||||
output = func(*args, **kwargs)
|
||||
|
||||
# store logs of print statements
|
||||
if func_name == "print":
|
||||
state["print_outputs"] += output + "\n"
|
||||
|
||||
return output
|
||||
else:
|
||||
raise InterpretorError(
|
||||
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
|
||||
)
|
||||
|
||||
|
||||
def evaluate_subscript(subscript, state, tools):
|
||||
index = evaluate_ast(subscript.slice, state, tools)
|
||||
value = evaluate_ast(subscript.value, state, tools)
|
||||
if isinstance(index, slice):
|
||||
return value[index]
|
||||
elif isinstance(value, (list, tuple)):
|
||||
return value[int(index)]
|
||||
elif isinstance(value, str):
|
||||
return value[index]
|
||||
elif index in value:
|
||||
return value[index]
|
||||
elif isinstance(index, str) and isinstance(value, Mapping):
|
||||
close_matches = difflib.get_close_matches(index, list(value.keys()))
|
||||
if len(close_matches) > 0:
|
||||
return value[close_matches[0]]
|
||||
raise InterpretorError(f"Could not index {value} with '{index}'.")
|
||||
|
||||
|
||||
def evaluate_name(name, state, tools):
|
||||
if name.id in state:
|
||||
return state[name.id]
|
||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||
if len(close_matches) > 0:
|
||||
return state[close_matches[0]]
|
||||
raise InterpretorError(f"The variable `{name.id}` is not defined.")
|
||||
|
||||
|
||||
def evaluate_condition(condition, state, tools):
|
||||
left = evaluate_ast(condition.left, state, tools)
|
||||
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
|
||||
ops = [type(op) for op in condition.ops]
|
||||
|
||||
result = left
|
||||
for op, comparator in zip(ops, comparators):
|
||||
if op == ast.Eq:
|
||||
result = result == comparator
|
||||
elif op == ast.NotEq:
|
||||
result = result != comparator
|
||||
elif op == ast.Lt:
|
||||
result = result < comparator
|
||||
elif op == ast.LtE:
|
||||
result = result <= comparator
|
||||
elif op == ast.Gt:
|
||||
result = result > comparator
|
||||
elif op == ast.GtE:
|
||||
result = result >= comparator
|
||||
elif op == ast.Is:
|
||||
result = result is comparator
|
||||
elif op == ast.IsNot:
|
||||
result = result is not comparator
|
||||
elif op == ast.In:
|
||||
result = result in comparator
|
||||
elif op == ast.NotIn:
|
||||
result = result not in comparator
|
||||
else:
|
||||
raise InterpretorError(f"Operator not supported: {op}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_if(if_statement, state, tools):
|
||||
result = None
|
||||
test_result = evaluate_ast(if_statement.test, state, tools)
|
||||
if test_result:
|
||||
for line in if_statement.body:
|
||||
line_result = evaluate_ast(line, state, tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
else:
|
||||
for line in if_statement.orelse:
|
||||
line_result = evaluate_ast(line, state, tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_for(for_loop, state, tools):
|
||||
result = None
|
||||
iterator = evaluate_ast(for_loop.iter, state, tools)
|
||||
for counter in iterator:
|
||||
state[for_loop.target.id] = counter
|
||||
for node in for_loop.body:
|
||||
try:
|
||||
line_result = evaluate_ast(node, state, tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
except BreakException:
|
||||
break
|
||||
except ContinueException:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_listcomp(listcomp, state, tools):
|
||||
result = []
|
||||
vars = {}
|
||||
for generator in listcomp.generators:
|
||||
var_name = generator.target.id
|
||||
iter_value = evaluate_ast(generator.iter, state, tools)
|
||||
for value in iter_value:
|
||||
vars[var_name] = value
|
||||
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
|
||||
elem = evaluate_ast(listcomp.elt, {**state, **vars}, tools)
|
||||
result.append(elem)
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
|
||||
"""
|
||||
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
|
||||
set of functions.
|
||||
|
||||
This function will recurse trough the nodes of the tree provided.
|
||||
|
||||
Args:
|
||||
expression (`ast.AST`):
|
||||
The code to evaluate, as an abastract syntax tree.
|
||||
state (`Dict[str, Any]`):
|
||||
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
|
||||
encounters assignements.
|
||||
tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||
`InterpretorError`.
|
||||
"""
|
||||
if isinstance(expression, ast.Assign):
|
||||
# Assignement -> we evaluate the assignement which should update the state
|
||||
# We return the variable assigned as it may be used to determine the final result.
|
||||
return evaluate_assign(expression, state, tools)
|
||||
elif isinstance(expression, ast.AugAssign):
|
||||
return evaluate_augassign(expression, state, tools)
|
||||
elif isinstance(expression, ast.Call):
|
||||
# Function call -> we return the value of the function call
|
||||
return evaluate_call(expression, state, tools)
|
||||
elif isinstance(expression, ast.Constant):
|
||||
# Constant -> just return the value
|
||||
return expression.value
|
||||
elif isinstance(expression, ast.Tuple):
|
||||
return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts)
|
||||
elif isinstance(expression, ast.ListComp):
|
||||
return evaluate_listcomp(expression, state, tools)
|
||||
elif isinstance(expression, ast.UnaryOp):
|
||||
return evaluate_unaryop(expression, state, tools)
|
||||
elif isinstance(expression, ast.BoolOp):
|
||||
# Boolean operation -> evaluate the operation
|
||||
return evaluate_boolop(expression, state, tools)
|
||||
elif isinstance(expression, ast.Break):
|
||||
raise BreakException()
|
||||
elif isinstance(expression, ast.Continue):
|
||||
raise ContinueException()
|
||||
elif isinstance(expression, ast.BinOp):
|
||||
# Binary operation -> execute operation
|
||||
return evaluate_binop(expression, state, tools)
|
||||
elif isinstance(expression, ast.Compare):
|
||||
# Comparison -> evaluate the comparison
|
||||
return evaluate_condition(expression, state, tools)
|
||||
elif isinstance(expression, ast.Return):
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
elif isinstance(expression, ast.Lambda):
|
||||
return evaluate_lambda(expression, state, tools)
|
||||
elif isinstance(expression, ast.FunctionDef):
|
||||
return evaluate_function_def(expression, state, tools)
|
||||
elif isinstance(expression, ast.Dict):
|
||||
# Dict -> evaluate all keys and values
|
||||
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
|
||||
values = [evaluate_ast(v, state, tools) for v in expression.values]
|
||||
return dict(zip(keys, values))
|
||||
elif isinstance(expression, ast.Expr):
|
||||
# Expression -> evaluate the content
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
elif isinstance(expression, ast.For):
|
||||
# For loop -> execute the loop
|
||||
return evaluate_for(expression, state, tools)
|
||||
elif isinstance(expression, ast.FormattedValue):
|
||||
# Formatted value (part of f-string) -> evaluate the content and return
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
elif isinstance(expression, ast.If):
|
||||
# If -> execute the right branch
|
||||
return evaluate_if(expression, state, tools)
|
||||
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
elif isinstance(expression, ast.JoinedStr):
|
||||
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
|
||||
elif isinstance(expression, ast.List):
|
||||
# List -> evaluate all elements
|
||||
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
|
||||
elif isinstance(expression, ast.Name):
|
||||
# Name -> pick up the value in the state
|
||||
return evaluate_name(expression, state, tools)
|
||||
elif isinstance(expression, ast.Subscript):
|
||||
# Subscript -> return the value of the indexing
|
||||
return evaluate_subscript(expression, state, tools)
|
||||
elif isinstance(expression, ast.IfExp):
|
||||
test_val = evaluate_ast(expression.test, state, tools)
|
||||
if test_val:
|
||||
return evaluate_ast(expression.body, state, tools)
|
||||
else:
|
||||
return evaluate_ast(expression.orelse, state, tools)
|
||||
elif isinstance(expression, ast.Attribute):
|
||||
obj = evaluate_ast(expression.value, state, tools)
|
||||
return getattr(obj, expression.attr)
|
||||
elif isinstance(expression, ast.Slice):
|
||||
return slice(
|
||||
evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None,
|
||||
evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None,
|
||||
evaluate_ast(expression.step, state, tools) if expression.step is not None else None,
|
||||
)
|
||||
elif isinstance(expression, ast.ListComp) or isinstance(expression, ast.GeneratorExp):
|
||||
result = []
|
||||
vars = {}
|
||||
for generator in expression.generators:
|
||||
var_name = generator.target.id
|
||||
iter_value = evaluate_ast(generator.iter, state, tools)
|
||||
for value in iter_value:
|
||||
vars[var_name] = value
|
||||
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
|
||||
elem = evaluate_ast(expression.elt, {**state, **vars}, tools)
|
||||
result.append(elem)
|
||||
return result
|
||||
elif isinstance(expression, ast.DictComp):
|
||||
result = {}
|
||||
for gen in expression.generators:
|
||||
for container in get_iterable(evaluate_ast(gen.iter, state, tools)):
|
||||
state[gen.target.id] = container
|
||||
key = evaluate_ast(expression.key, state, tools)
|
||||
value = evaluate_ast(expression.value, state, tools)
|
||||
result[key] = value
|
||||
return result
|
||||
elif isinstance(expression, ast.Import):
|
||||
for alias in expression.names:
|
||||
if alias.name in LIST_SAFE_MODULES:
|
||||
module = __import__(alias.name)
|
||||
state[alias.asname or alias.name] = module
|
||||
else:
|
||||
raise InterpretorError(f"Import of {alias.name} is not allowed.")
|
||||
return None
|
||||
elif isinstance(expression, ast.While):
|
||||
return evaluate_while(expression, state, tools)
|
||||
elif isinstance(expression, ast.ImportFrom):
|
||||
if expression.module in LIST_SAFE_MODULES:
|
||||
module = __import__(expression.module)
|
||||
for alias in expression.names:
|
||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||
else:
|
||||
raise InterpretorError(f"Import from {expression.module} is not allowed.")
|
||||
return None
|
||||
else:
|
||||
# For now we refuse anything else. Let's add things as we need them.
|
||||
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
|
||||
|
||||
|
||||
def evaluate_python_code(code: str, tools: Optional[Dict[str, Callable]] = {}, state=None):
|
||||
"""
|
||||
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
||||
of functions.
|
||||
|
||||
This function will recurse through the nodes of the tree provided.
|
||||
|
||||
Args:
|
||||
code (`str`):
|
||||
The code to evaluate.
|
||||
tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||
`InterpretorError`.
|
||||
state (`Dict[str, Any]`):
|
||||
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
||||
updated by this function to contain all variables as they are evaluated.
|
||||
The print outputs will be stored in the state under the key 'print_outputs'.
|
||||
"""
|
||||
try:
|
||||
expression = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
|
||||
if state is None:
|
||||
state = {}
|
||||
result = None
|
||||
state["print_outputs"] = ""
|
||||
for idx, node in enumerate(expression.body):
|
||||
try:
|
||||
line_result = evaluate_ast(node, state, tools)
|
||||
except InterpretorError as e:
|
||||
msg = f"You tried to execute the following code:\n{code}\n"
|
||||
msg += f"You got these outputs:\n{state['print_outputs']}\n"
|
||||
msg += f"Evaluation stopped at line '{node}' because of the following error:\n{e}"
|
||||
raise InterpretorError(msg)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
|
||||
return result
|
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -14,28 +14,26 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
|
||||
from .base import PipelineTool
|
||||
from .tools import PipelineTool
|
||||
|
||||
|
||||
class SpeechToTextTool(PipelineTool):
|
||||
default_checkpoint = "openai/whisper-base"
|
||||
description = (
|
||||
"This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
|
||||
"transcribed text."
|
||||
)
|
||||
default_checkpoint = "distil-whisper/distil-large-v3"
|
||||
description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
|
||||
name = "transcriber"
|
||||
pre_processor_class = WhisperProcessor
|
||||
model_class = WhisperForConditionalGeneration
|
||||
|
||||
inputs = ["audio"]
|
||||
outputs = ["text"]
|
||||
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
|
||||
output_type = "text"
|
||||
|
||||
def encode(self, audio):
|
||||
return self.pre_processor(audio, return_tensors="pt").input_features
|
||||
return self.pre_processor(audio, return_tensors="pt")
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.model.generate(inputs=inputs)
|
||||
return self.model.generate(inputs["input_features"])
|
||||
|
||||
def decode(self, outputs):
|
||||
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -14,11 +14,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
|
||||
from ..utils import is_datasets_available
|
||||
from .base import PipelineTool
|
||||
from .tools import PipelineTool
|
||||
|
||||
|
||||
if is_datasets_available():
|
||||
@ -28,16 +29,15 @@ if is_datasets_available():
|
||||
class TextToSpeechTool(PipelineTool):
|
||||
default_checkpoint = "microsoft/speecht5_tts"
|
||||
description = (
|
||||
"This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
|
||||
"text to read (in English) and returns a waveform object containing the sound."
|
||||
"This is a tool that reads an English text out loud. It returns a waveform object containing the sound."
|
||||
)
|
||||
name = "text_reader"
|
||||
name = "text_to_speech"
|
||||
pre_processor_class = SpeechT5Processor
|
||||
model_class = SpeechT5ForTextToSpeech
|
||||
post_processor_class = SpeechT5HifiGan
|
||||
|
||||
inputs = ["text"]
|
||||
outputs = ["audio"]
|
||||
inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}}
|
||||
output_type = "audio"
|
||||
|
||||
def setup(self):
|
||||
if self.post_processor is None:
|
@ -16,18 +16,22 @@
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import importlib
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import create_repo, hf_hub_download, metadata_update, upload_folder
|
||||
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
|
||||
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
|
||||
from packaging import version
|
||||
|
||||
from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
|
||||
from ..image_utils import is_pil_image
|
||||
from ..dynamic_module_utils import (
|
||||
custom_object_save,
|
||||
get_class_from_dynamic_module,
|
||||
get_imports,
|
||||
)
|
||||
from ..models.auto import AutoProcessor
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
@ -42,6 +46,11 @@ from .agent_types import handle_agent_inputs, handle_agent_outputs
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
@ -89,30 +98,46 @@ class Tool:
|
||||
returns the text contained in the file'.
|
||||
- **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
|
||||
`"text-classifier"` or `"image_generator"`.
|
||||
- **inputs** (`List[str]`) -- The list of modalities expected for the inputs (in the same order as in the call).
|
||||
Modalitiies should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo` or to make a
|
||||
nice space from your tool.
|
||||
- **outputs** (`List[str]`) -- The list of modalities returned but the tool (in the same order as the return of the
|
||||
call method). Modalitiies should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo`
|
||||
or to make a nice space from your tool.
|
||||
- **inputs** (`Dict[str, Dict[str, Union[str, type]]]`) -- The dict of modalities expected for the inputs.
|
||||
It has one `type`key and a `description`key.
|
||||
This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
|
||||
description for your tool.
|
||||
- **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
|
||||
or to make a nice space from your tool, and also can be used in the generated description for your tool.
|
||||
|
||||
You can also override the method [`~Tool.setup`] if your tool as an expensive operation to perform before being
|
||||
usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
|
||||
instantiation.
|
||||
"""
|
||||
|
||||
description: str = "This is a tool that ..."
|
||||
name: str = ""
|
||||
|
||||
inputs: List[str]
|
||||
outputs: List[str]
|
||||
name: str
|
||||
description: str
|
||||
inputs: Dict[str, Dict[str, Union[str, type]]]
|
||||
output_type: type
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.is_initialized = False
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def validate_attributes(self):
|
||||
required_attributes = {
|
||||
"description": str,
|
||||
"name": str,
|
||||
"inputs": Dict,
|
||||
"output_type": type,
|
||||
}
|
||||
for attr, expected_type in required_attributes.items():
|
||||
attr_value = getattr(self, attr, None)
|
||||
if not isinstance(attr_value, expected_type):
|
||||
raise TypeError(f"Instance attribute {attr} must exist and be of type {expected_type.__name__}")
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return NotImplemented("Write this method in your subclass of `Tool`.")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
||||
outputs = self.forward(*args, **kwargs)
|
||||
return handle_agent_outputs(outputs, self.output_type)
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
Overwrite this method here for any operation that is expensive and needs to be executed before you start using
|
||||
@ -156,7 +181,13 @@ class Tool:
|
||||
else:
|
||||
tool_config = {}
|
||||
|
||||
tool_config = {"tool_class": full_name, "description": self.description, "name": self.name}
|
||||
tool_config = {
|
||||
"tool_class": full_name,
|
||||
"description": self.description,
|
||||
"name": self.name,
|
||||
"inputs": str(self.inputs),
|
||||
"output_type": str(self.output_type),
|
||||
}
|
||||
with open(config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
|
||||
|
||||
@ -180,7 +211,6 @@ class Tool:
|
||||
repo_id: str,
|
||||
model_repo_id: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
remote: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -203,21 +233,11 @@ class Tool:
|
||||
token (`str`, *optional*):
|
||||
The token to identify you on hf.co. If unset, will use the token generated when running
|
||||
`huggingface-cli login` (stored in `~/.huggingface`).
|
||||
remote (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
|
||||
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
|
||||
others will be passed along to its init.
|
||||
"""
|
||||
if remote and model_repo_id is None:
|
||||
endpoints = get_default_endpoints()
|
||||
if repo_id not in endpoints:
|
||||
raise ValueError(
|
||||
f"Could not infer a default endpoint for {repo_id}, you need to pass one using the "
|
||||
"`model_repo_id` argument."
|
||||
)
|
||||
model_repo_id = endpoints[repo_id]
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
@ -290,8 +310,11 @@ class Tool:
|
||||
)
|
||||
tool_class.description = custom_tool["description"]
|
||||
|
||||
if remote:
|
||||
return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
|
||||
if tool_class.inputs != custom_tool["inputs"]:
|
||||
tool_class.inputs = custom_tool["inputs"]
|
||||
if tool_class.output_type != custom_tool["output_type"]:
|
||||
tool_class.output_type = custom_tool["output_type"]
|
||||
|
||||
return tool_class(model_repo_id, token=token, **kwargs)
|
||||
|
||||
def push_to_hub(
|
||||
@ -305,6 +328,14 @@ class Tool:
|
||||
"""
|
||||
Upload the tool to the Hub.
|
||||
|
||||
For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
|
||||
For instance:
|
||||
```
|
||||
from my_tool_module import MyTool
|
||||
my_tool = MyTool()
|
||||
my_tool.push_to_hub("my-username/my-space")
|
||||
```
|
||||
|
||||
Parameters:
|
||||
repo_id (`str`):
|
||||
The name of the repository you want to push your tool to. It should contain your organization name when
|
||||
@ -320,7 +351,12 @@ class Tool:
|
||||
Whether or not to create a PR with the uploaded files or directly commit.
|
||||
"""
|
||||
repo_url = create_repo(
|
||||
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
|
||||
repo_id=repo_id,
|
||||
token=token,
|
||||
private=private,
|
||||
exist_ok=True,
|
||||
repo_type="space",
|
||||
space_sdk="gradio",
|
||||
)
|
||||
repo_id = repo_url.repo_id
|
||||
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
|
||||
@ -343,102 +379,81 @@ class Tool:
|
||||
"""
|
||||
Creates a [`Tool`] from a gradio tool.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
class GradioToolWrapper(Tool):
|
||||
def __init__(self, _gradio_tool):
|
||||
super().__init__()
|
||||
self.name = _gradio_tool.name
|
||||
self.description = _gradio_tool.description
|
||||
self.output_type = "text"
|
||||
self._gradio_tool = _gradio_tool
|
||||
func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
|
||||
self.inputs = {key: "" for key in func_args}
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self._gradio_tool.run(*args, **kwargs)
|
||||
|
||||
GradioToolWrapper.__call__ = gradio_tool.run
|
||||
return GradioToolWrapper(gradio_tool)
|
||||
|
||||
|
||||
class RemoteTool(Tool):
|
||||
"""
|
||||
A [`Tool`] that will make requests to an inference endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_url (`str`, *optional*):
|
||||
The url of the endpoint to use.
|
||||
token (`str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
|
||||
running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
tool_class (`type`, *optional*):
|
||||
The corresponding `tool_class` if this is a remote version of an existing tool. Will help determine when
|
||||
the output should be converted to another type (like images).
|
||||
"""
|
||||
|
||||
def __init__(self, endpoint_url=None, token=None, tool_class=None):
|
||||
self.endpoint_url = endpoint_url
|
||||
self.client = EndpointClient(endpoint_url, token=token)
|
||||
self.tool_class = tool_class
|
||||
|
||||
def prepare_inputs(self, *args, **kwargs):
|
||||
@staticmethod
|
||||
def from_langchain(langchain_tool):
|
||||
"""
|
||||
Prepare the inputs received for the HTTP client sending data to the endpoint. Positional arguments will be
|
||||
matched with the signature of the `tool_class` if it was provided at instantation. Images will be encoded into
|
||||
bytes.
|
||||
|
||||
You can override this method in your custom class of [`RemoteTool`].
|
||||
Creates a [`Tool`] from a langchain tool.
|
||||
"""
|
||||
inputs = kwargs.copy()
|
||||
if len(args) > 0:
|
||||
if self.tool_class is not None:
|
||||
# Match args with the signature
|
||||
if issubclass(self.tool_class, PipelineTool):
|
||||
call_method = self.tool_class.encode
|
||||
else:
|
||||
call_method = self.tool_class.__call__
|
||||
signature = inspect.signature(call_method).parameters
|
||||
parameters = [
|
||||
k
|
||||
for k, p in signature.items()
|
||||
if p.kind not in [inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD]
|
||||
]
|
||||
if parameters[0] == "self":
|
||||
parameters = parameters[1:]
|
||||
if len(args) > len(parameters):
|
||||
raise ValueError(
|
||||
f"{self.tool_class} only accepts {len(parameters)} arguments but {len(args)} were given."
|
||||
)
|
||||
for arg, name in zip(args, parameters):
|
||||
inputs[name] = arg
|
||||
elif len(args) > 1:
|
||||
raise ValueError("A `RemoteTool` can only accept one positional input.")
|
||||
elif len(args) == 1:
|
||||
if is_pil_image(args[0]):
|
||||
return {"inputs": self.client.encode_image(args[0])}
|
||||
return {"inputs": args[0]}
|
||||
|
||||
for key, value in inputs.items():
|
||||
if is_pil_image(value):
|
||||
inputs[key] = self.client.encode_image(value)
|
||||
class LangChainToolWrapper(Tool):
|
||||
def __init__(self, _langchain_tool):
|
||||
super().__init__()
|
||||
self.name = _langchain_tool.name.lower()
|
||||
self.description = _langchain_tool.description
|
||||
self.inputs = parse_langchain_args(_langchain_tool.args)
|
||||
self.output_type = "text"
|
||||
self.langchain_tool = _langchain_tool
|
||||
|
||||
return {"inputs": inputs}
|
||||
def forward(self, *args, **kwargs):
|
||||
tool_input = kwargs.copy()
|
||||
for index, argument in enumerate(args):
|
||||
if index < len(self.inputs):
|
||||
input_key = next(iter(self.inputs))
|
||||
tool_input[input_key] = argument
|
||||
return self.langchain_tool.run(tool_input)
|
||||
|
||||
def extract_outputs(self, outputs):
|
||||
"""
|
||||
You can override this method in your custom class of [`RemoteTool`] to apply some custom post-processing of the
|
||||
outputs of the endpoint.
|
||||
"""
|
||||
return outputs
|
||||
return LangChainToolWrapper(langchain_tool)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
args, kwargs = handle_agent_inputs(*args, **kwargs)
|
||||
|
||||
output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
|
||||
inputs = self.prepare_inputs(*args, **kwargs)
|
||||
if isinstance(inputs, dict):
|
||||
outputs = self.client(**inputs, output_image=output_image)
|
||||
else:
|
||||
outputs = self.client(inputs, output_image=output_image)
|
||||
if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
|
||||
outputs = outputs[0]
|
||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
|
||||
- {{ tool.name }}: {{ tool.description }}
|
||||
Takes inputs: {{tool.inputs}}
|
||||
"""
|
||||
|
||||
outputs = handle_agent_outputs(outputs, self.tool_class.outputs if self.tool_class is not None else None)
|
||||
|
||||
return self.extract_outputs(outputs)
|
||||
def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
|
||||
compiled_template = compile_jinja_template(description_template)
|
||||
rendered = compiled_template.render(
|
||||
tool=tool,
|
||||
)
|
||||
return rendered
|
||||
|
||||
|
||||
@lru_cache
|
||||
def compile_jinja_template(template):
|
||||
try:
|
||||
import jinja2
|
||||
from jinja2.exceptions import TemplateError
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
except ImportError:
|
||||
raise ImportError("template requires jinja2 to be installed.")
|
||||
|
||||
if version.parse(jinja2.__version__) <= version.parse("3.0.0"):
|
||||
raise ImportError("template requires jinja2>=3.0.0 to be installed. Your version is " f"{jinja2.__version__}.")
|
||||
|
||||
def raise_exception(message):
|
||||
raise TemplateError(message)
|
||||
|
||||
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
|
||||
jinja_env.globals["raise_exception"] = raise_exception
|
||||
return jinja_env.from_string(template)
|
||||
|
||||
|
||||
class PipelineTool(Tool):
|
||||
@ -483,6 +498,10 @@ class PipelineTool(Tool):
|
||||
model_class = None
|
||||
post_processor_class = AutoProcessor
|
||||
default_checkpoint = None
|
||||
description = "This is a pipeline tool"
|
||||
name = "pipeline"
|
||||
inputs = {"prompt": str}
|
||||
output_type = str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -573,18 +592,22 @@ class PipelineTool(Tool):
|
||||
self.setup()
|
||||
|
||||
encoded_inputs = self.encode(*args, **kwargs)
|
||||
encoded_inputs = send_to_device(encoded_inputs, self.device)
|
||||
outputs = self.forward(encoded_inputs)
|
||||
|
||||
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
|
||||
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
|
||||
|
||||
encoded_inputs = send_to_device(tensor_inputs, self.device)
|
||||
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
|
||||
outputs = send_to_device(outputs, "cpu")
|
||||
decoded_outputs = self.decode(outputs)
|
||||
|
||||
return handle_agent_outputs(decoded_outputs, self.outputs)
|
||||
return handle_agent_outputs(decoded_outputs, self.output_type)
|
||||
|
||||
|
||||
def launch_gradio_demo(tool_class: Tool):
|
||||
"""
|
||||
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
|
||||
`inputs` and `outputs`.
|
||||
`inputs` and `output_type`.
|
||||
|
||||
Args:
|
||||
tool_class (`type`): The class of the tool for which to launch the demo.
|
||||
@ -599,10 +622,26 @@ def launch_gradio_demo(tool_class: Tool):
|
||||
def fn(*args, **kwargs):
|
||||
return tool(*args, **kwargs)
|
||||
|
||||
gradio_inputs = []
|
||||
for input_type in [tool_input["type"] for tool_input in tool_class.inputs.values()]:
|
||||
if input_type in [str, int, float]:
|
||||
gradio_inputs += "text"
|
||||
elif is_vision_available() and input_type == PIL.Image.Image:
|
||||
gradio_inputs += "image"
|
||||
else:
|
||||
gradio_inputs += "audio"
|
||||
|
||||
if tool_class.output_type in [str, int, float]:
|
||||
gradio_output = "text"
|
||||
elif is_vision_available() and tool_class.output_type == PIL.Image.Image:
|
||||
gradio_output = "image"
|
||||
else:
|
||||
gradio_output = "audio"
|
||||
|
||||
gr.Interface(
|
||||
fn=fn,
|
||||
inputs=tool_class.inputs,
|
||||
outputs=tool_class.outputs,
|
||||
inputs=gradio_inputs,
|
||||
outputs=gradio_output,
|
||||
title=tool_class.__name__,
|
||||
article=tool.description,
|
||||
).launch()
|
||||
@ -610,31 +649,16 @@ def launch_gradio_demo(tool_class: Tool):
|
||||
|
||||
TASK_MAPPING = {
|
||||
"document-question-answering": "DocumentQuestionAnsweringTool",
|
||||
"image-captioning": "ImageCaptioningTool",
|
||||
"image-question-answering": "ImageQuestionAnsweringTool",
|
||||
"image-segmentation": "ImageSegmentationTool",
|
||||
"speech-to-text": "SpeechToTextTool",
|
||||
"summarization": "TextSummarizationTool",
|
||||
"text-classification": "TextClassificationTool",
|
||||
"text-question-answering": "TextQuestionAnsweringTool",
|
||||
"text-to-speech": "TextToSpeechTool",
|
||||
"translation": "TranslationTool",
|
||||
"python_interpreter": "PythonInterpreterTool",
|
||||
"final_answer": "FinalAnswerTool",
|
||||
}
|
||||
|
||||
|
||||
def get_default_endpoints():
|
||||
endpoints_file = cached_file("huggingface-tools/default-endpoints", "default_endpoints.json", repo_type="dataset")
|
||||
with open(endpoints_file, "r", encoding="utf-8") as f:
|
||||
endpoints = json.load(f)
|
||||
return endpoints
|
||||
|
||||
|
||||
def supports_remote(task_or_repo_id):
|
||||
endpoints = get_default_endpoints()
|
||||
return task_or_repo_id in endpoints
|
||||
|
||||
|
||||
def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **kwargs):
|
||||
def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
|
||||
"""
|
||||
Main function to quickly load a tool, be it on the Hub or in the Transformers library.
|
||||
|
||||
@ -652,20 +676,13 @@ def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **k
|
||||
are:
|
||||
|
||||
- `"document-question-answering"`
|
||||
- `"image-captioning"`
|
||||
- `"image-question-answering"`
|
||||
- `"image-segmentation"`
|
||||
- `"speech-to-text"`
|
||||
- `"summarization"`
|
||||
- `"text-classification"`
|
||||
- `"text-question-answering"`
|
||||
- `"text-to-speech"`
|
||||
- `"translation"`
|
||||
|
||||
model_repo_id (`str`, *optional*):
|
||||
Use this argument to use a different model than the default one for the tool you selected.
|
||||
remote (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
|
||||
token (`str`, *optional*):
|
||||
The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
|
||||
login` (stored in `~/.huggingface`).
|
||||
@ -677,21 +694,9 @@ def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **k
|
||||
if task_or_repo_id in TASK_MAPPING:
|
||||
tool_class_name = TASK_MAPPING[task_or_repo_id]
|
||||
main_module = importlib.import_module("transformers")
|
||||
tools_module = main_module.tools
|
||||
tools_module = main_module.agents
|
||||
tool_class = getattr(tools_module, tool_class_name)
|
||||
|
||||
if remote:
|
||||
if model_repo_id is None:
|
||||
endpoints = get_default_endpoints()
|
||||
if task_or_repo_id not in endpoints:
|
||||
raise ValueError(
|
||||
f"Could not infer a default endpoint for {task_or_repo_id}, you need to pass one using the "
|
||||
"`model_repo_id` argument."
|
||||
)
|
||||
model_repo_id = endpoints[task_or_repo_id]
|
||||
return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
|
||||
else:
|
||||
return tool_class(model_repo_id, token=token, **kwargs)
|
||||
return tool_class(model_repo_id, token=token, **kwargs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
|
||||
@ -699,7 +704,7 @@ def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **k
|
||||
f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
|
||||
f"code that you have checked."
|
||||
)
|
||||
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, remote=remote, **kwargs)
|
||||
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
|
||||
|
||||
|
||||
def add_description(description):
|
||||
@ -718,7 +723,10 @@ def add_description(description):
|
||||
## Will move to the Hub
|
||||
class EndpointClient:
|
||||
def __init__(self, endpoint_url: str, token: Optional[str] = None):
|
||||
self.headers = {**build_hf_headers(token=token), "Content-Type": "application/json"}
|
||||
self.headers = {
|
||||
**build_hf_headers(token=token),
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
@staticmethod
|
||||
@ -763,3 +771,44 @@ class EndpointClient:
|
||||
return self.decode_image(response.content)
|
||||
else:
|
||||
return response.json()
|
||||
|
||||
|
||||
def parse_langchain_args(args: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Parse the args attribute of a LangChain tool to create a matching inputs dictionary."""
|
||||
inputs = args.copy()
|
||||
for arg_details in inputs.values():
|
||||
if "title" in arg_details:
|
||||
arg_details.pop("title")
|
||||
return inputs
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""
|
||||
Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
|
||||
|
||||
> [!NOTE]
|
||||
> Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
|
||||
> like for this collection to showcase them.
|
||||
|
||||
Args:
|
||||
collection_slug (str):
|
||||
The collection slug referencing the collection.
|
||||
token (str, *optional*):
|
||||
The authentication token if the collection is private.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> from transformers import ToolCollection, ReactCodeAgent
|
||||
|
||||
>>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
|
||||
>>> agent = ReactCodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
|
||||
|
||||
>>> agent.run("Please draw me a picture of rivers and lakes.")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, collection_slug: str, token: Optional[str] = None):
|
||||
self._collection = get_collection(collection_slug, token=token)
|
||||
self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
|
||||
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
|
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from .base import PipelineTool
|
||||
from .tools import PipelineTool
|
||||
|
||||
|
||||
LANGUAGE_CODES = {
|
||||
@ -231,27 +231,35 @@ class TranslationTool(PipelineTool):
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers.tools import TranslationTool
|
||||
from transformers.agents import TranslationTool
|
||||
|
||||
translator = TranslationTool()
|
||||
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
|
||||
```
|
||||
"""
|
||||
|
||||
lang_to_code = LANGUAGE_CODES
|
||||
default_checkpoint = "facebook/nllb-200-distilled-600M"
|
||||
description = (
|
||||
"This is a tool that translates text from a language to another. It takes three inputs: `text`, which should "
|
||||
"be the text to translate, `src_lang`, which should be the language of the text to translate and `tgt_lang`, "
|
||||
"which should be the language for the desired ouput language. Both `src_lang` and `tgt_lang` are written in "
|
||||
"plain English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
|
||||
"This is a tool that translates text from a language to another."
|
||||
f"Both `src_lang`and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}."
|
||||
)
|
||||
name = "translator"
|
||||
pre_processor_class = AutoTokenizer
|
||||
model_class = AutoModelForSeq2SeqLM
|
||||
lang_to_code = LANGUAGE_CODES
|
||||
|
||||
inputs = ["text", "text", "text"]
|
||||
outputs = ["text"]
|
||||
inputs = {
|
||||
"text": {"type": "text", "description": "The text to translate"},
|
||||
"src_lang": {
|
||||
"type": "text",
|
||||
"description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
|
||||
},
|
||||
"tgt_lang": {
|
||||
"type": "text",
|
||||
"description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
|
||||
},
|
||||
}
|
||||
output_type = "text"
|
||||
|
||||
def encode(self, text, src_lang, tgt_lang):
|
||||
if src_lang not in self.lang_to_code:
|
@ -201,7 +201,7 @@ _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=Fa
|
||||
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
|
||||
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
|
||||
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
|
||||
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)
|
||||
_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
|
||||
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
|
||||
|
||||
|
||||
@ -276,19 +276,19 @@ def is_pipeline_test(test_case):
|
||||
return pytest.mark.is_pipeline_test()(test_case)
|
||||
|
||||
|
||||
def is_tool_test(test_case):
|
||||
def is_agent_test(test_case):
|
||||
"""
|
||||
Decorator marking a test as a tool test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
|
||||
Decorator marking a test as an agent test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
|
||||
"""
|
||||
if not _run_tool_tests:
|
||||
return unittest.skip("test is a tool test")(test_case)
|
||||
if not _run_agent_tests:
|
||||
return unittest.skip("test is an agent test")(test_case)
|
||||
else:
|
||||
try:
|
||||
import pytest # We don't need a hard dependency on pytest in the main library
|
||||
except ImportError:
|
||||
return test_case
|
||||
else:
|
||||
return pytest.mark.is_tool_test()(test_case)
|
||||
return pytest.mark.is_agent_test()(test_case)
|
||||
|
||||
|
||||
def slow(test_case):
|
||||
|
@ -1,778 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfFolder, hf_hub_download, list_spaces
|
||||
|
||||
from ..models.auto import AutoTokenizer
|
||||
from ..utils import is_offline_mode, is_openai_available, is_torch_available, logging
|
||||
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
|
||||
from .prompts import CHAT_MESSAGE_PROMPT, download_prompt
|
||||
from .python_interpreter import evaluate
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_openai_available():
|
||||
import openai
|
||||
|
||||
if is_torch_available():
|
||||
from ..generation import StoppingCriteria, StoppingCriteriaList
|
||||
from ..models.auto import AutoModelForCausalLM
|
||||
else:
|
||||
StoppingCriteria = object
|
||||
|
||||
_tools_are_initialized = False
|
||||
|
||||
|
||||
BASE_PYTHON_TOOLS = {
|
||||
"print": print,
|
||||
"range": range,
|
||||
"float": float,
|
||||
"int": int,
|
||||
"bool": bool,
|
||||
"str": str,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTool:
|
||||
task: str
|
||||
description: str
|
||||
repo_id: str
|
||||
|
||||
|
||||
HUGGINGFACE_DEFAULT_TOOLS = {}
|
||||
|
||||
|
||||
HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
|
||||
"image-transformation",
|
||||
"text-download",
|
||||
"text-to-image",
|
||||
"text-to-video",
|
||||
]
|
||||
|
||||
|
||||
def get_remote_tools(organization="huggingface-tools"):
|
||||
if is_offline_mode():
|
||||
logger.info("You are in offline mode, so remote tools are not available.")
|
||||
return {}
|
||||
|
||||
spaces = list_spaces(author=organization)
|
||||
tools = {}
|
||||
for space_info in spaces:
|
||||
repo_id = space_info.id
|
||||
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
|
||||
with open(resolved_config_file, encoding="utf-8") as reader:
|
||||
config = json.load(reader)
|
||||
|
||||
task = repo_id.split("/")[-1]
|
||||
tools[config["name"]] = PreTool(task=task, description=config["description"], repo_id=repo_id)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
def _setup_default_tools():
|
||||
global HUGGINGFACE_DEFAULT_TOOLS
|
||||
global _tools_are_initialized
|
||||
|
||||
if _tools_are_initialized:
|
||||
return
|
||||
|
||||
main_module = importlib.import_module("transformers")
|
||||
tools_module = main_module.tools
|
||||
|
||||
remote_tools = get_remote_tools()
|
||||
for task_name, tool_class_name in TASK_MAPPING.items():
|
||||
tool_class = getattr(tools_module, tool_class_name)
|
||||
description = tool_class.description
|
||||
HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None)
|
||||
|
||||
if not is_offline_mode():
|
||||
for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
|
||||
found = False
|
||||
for tool_name, tool in remote_tools.items():
|
||||
if tool.task == task_name:
|
||||
HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
raise ValueError(f"{task_name} is not implemented on the Hub.")
|
||||
|
||||
_tools_are_initialized = True
|
||||
|
||||
|
||||
def resolve_tools(code, toolbox, remote=False, cached_tools=None):
|
||||
if cached_tools is None:
|
||||
resolved_tools = BASE_PYTHON_TOOLS.copy()
|
||||
else:
|
||||
resolved_tools = cached_tools
|
||||
for name, tool in toolbox.items():
|
||||
if name not in code or name in resolved_tools:
|
||||
continue
|
||||
|
||||
if isinstance(tool, Tool):
|
||||
resolved_tools[name] = tool
|
||||
else:
|
||||
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
|
||||
_remote = remote and supports_remote(task_or_repo_id)
|
||||
resolved_tools[name] = load_tool(task_or_repo_id, remote=_remote)
|
||||
|
||||
return resolved_tools
|
||||
|
||||
|
||||
def get_tool_creation_code(code, toolbox, remote=False):
|
||||
code_lines = ["from transformers import load_tool", ""]
|
||||
for name, tool in toolbox.items():
|
||||
if name not in code or isinstance(tool, Tool):
|
||||
continue
|
||||
|
||||
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
|
||||
line = f'{name} = load_tool("{task_or_repo_id}"'
|
||||
if remote:
|
||||
line += ", remote=True"
|
||||
line += ")"
|
||||
code_lines.append(line)
|
||||
|
||||
return "\n".join(code_lines) + "\n"
|
||||
|
||||
|
||||
def clean_code_for_chat(result):
|
||||
lines = result.split("\n")
|
||||
idx = 0
|
||||
while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
|
||||
idx += 1
|
||||
explanation = "\n".join(lines[:idx]).strip()
|
||||
if idx == len(lines):
|
||||
return explanation, None
|
||||
|
||||
idx += 1
|
||||
start_idx = idx
|
||||
while not lines[idx].lstrip().startswith("```"):
|
||||
idx += 1
|
||||
code = "\n".join(lines[start_idx:idx]).strip()
|
||||
|
||||
return explanation, code
|
||||
|
||||
|
||||
def clean_code_for_run(result):
|
||||
result = f"I will use the following {result}"
|
||||
explanation, code = result.split("Answer:")
|
||||
explanation = explanation.strip()
|
||||
code = code.strip()
|
||||
|
||||
code_lines = code.split("\n")
|
||||
if code_lines[0] in ["```", "```py", "```python"]:
|
||||
code_lines = code_lines[1:]
|
||||
if code_lines[-1] == "```":
|
||||
code_lines = code_lines[:-1]
|
||||
code = "\n".join(code_lines)
|
||||
|
||||
return explanation, code
|
||||
|
||||
|
||||
class Agent:
|
||||
"""
|
||||
Base class for all agents which contains the main API methods.
|
||||
|
||||
Args:
|
||||
chat_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`chat_prompt_template.txt` in this repo in this case.
|
||||
run_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`run_prompt_template.txt` in this repo in this case.
|
||||
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
|
||||
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
|
||||
one of the default tools, that default tool will be overridden.
|
||||
"""
|
||||
|
||||
def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
|
||||
_setup_default_tools()
|
||||
|
||||
agent_name = self.__class__.__name__
|
||||
self.chat_prompt_template = download_prompt(chat_prompt_template, agent_name, mode="chat")
|
||||
self.run_prompt_template = download_prompt(run_prompt_template, agent_name, mode="run")
|
||||
self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
|
||||
self.log = print
|
||||
if additional_tools is not None:
|
||||
if isinstance(additional_tools, (list, tuple)):
|
||||
additional_tools = {t.name: t for t in additional_tools}
|
||||
elif not isinstance(additional_tools, dict):
|
||||
additional_tools = {additional_tools.name: additional_tools}
|
||||
|
||||
replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS}
|
||||
self._toolbox.update(additional_tools)
|
||||
if len(replacements) > 1:
|
||||
names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
|
||||
logger.warning(
|
||||
f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}."
|
||||
)
|
||||
elif len(replacements) == 1:
|
||||
name = list(replacements.keys())[0]
|
||||
logger.warning(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")
|
||||
|
||||
self.prepare_for_new_chat()
|
||||
|
||||
@property
|
||||
def toolbox(self) -> Dict[str, Tool]:
|
||||
"""Get all tool currently available to the agent"""
|
||||
return self._toolbox
|
||||
|
||||
def format_prompt(self, task, chat_mode=False):
|
||||
description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
|
||||
if chat_mode:
|
||||
if self.chat_history is None:
|
||||
prompt = self.chat_prompt_template.replace("<<all_tools>>", description)
|
||||
else:
|
||||
prompt = self.chat_history
|
||||
prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
|
||||
else:
|
||||
prompt = self.run_prompt_template.replace("<<all_tools>>", description)
|
||||
prompt = prompt.replace("<<prompt>>", task)
|
||||
return prompt
|
||||
|
||||
def set_stream(self, streamer):
|
||||
"""
|
||||
Set the function use to stream results (which is `print` by default).
|
||||
|
||||
Args:
|
||||
streamer (`callable`): The function to call when streaming results from the LLM.
|
||||
"""
|
||||
self.log = streamer
|
||||
|
||||
def chat(self, task, *, return_code=False, remote=False, **kwargs):
|
||||
"""
|
||||
Sends a new request to the agent in a chat. Will use the previous ones in its history.
|
||||
|
||||
Args:
|
||||
task (`str`): The task to perform
|
||||
return_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether to just return code and not evaluate it.
|
||||
remote (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use remote tools (inference endpoints) instead of local ones.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Any keyword argument to send to the agent when evaluating the code.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
agent.chat("Draw me a picture of rivers and lakes")
|
||||
|
||||
agent.chat("Transform the picture so that there is a rock in there")
|
||||
```
|
||||
"""
|
||||
prompt = self.format_prompt(task, chat_mode=True)
|
||||
result = self.generate_one(prompt, stop=["Human:", "====="])
|
||||
self.chat_history = prompt + result.strip() + "\n"
|
||||
explanation, code = clean_code_for_chat(result)
|
||||
|
||||
self.log(f"==Explanation from the agent==\n{explanation}")
|
||||
|
||||
if code is not None:
|
||||
self.log(f"\n\n==Code generated by the agent==\n{code}")
|
||||
if not return_code:
|
||||
self.log("\n\n==Result==")
|
||||
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
|
||||
self.chat_state.update(kwargs)
|
||||
return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
|
||||
else:
|
||||
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
|
||||
return f"{tool_code}\n{code}"
|
||||
|
||||
def prepare_for_new_chat(self):
|
||||
"""
|
||||
Clears the history of prior calls to [`~Agent.chat`].
|
||||
"""
|
||||
self.chat_history = None
|
||||
self.chat_state = {}
|
||||
self.cached_tools = None
|
||||
|
||||
def clean_code_for_run(self, result):
|
||||
"""
|
||||
Override this method if you want to change the way the code is
|
||||
cleaned for the `run` method.
|
||||
"""
|
||||
return clean_code_for_run(result)
|
||||
|
||||
def run(self, task, *, return_code=False, remote=False, **kwargs):
|
||||
"""
|
||||
Sends a request to the agent.
|
||||
|
||||
Args:
|
||||
task (`str`): The task to perform
|
||||
return_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether to just return code and not evaluate it.
|
||||
remote (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use remote tools (inference endpoints) instead of local ones.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Any keyword argument to send to the agent when evaluating the code.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
agent.run("Draw me a picture of rivers and lakes")
|
||||
```
|
||||
"""
|
||||
prompt = self.format_prompt(task)
|
||||
result = self.generate_one(prompt, stop=["Task:"])
|
||||
explanation, code = self.clean_code_for_run(result)
|
||||
|
||||
self.log(f"==Explanation from the agent==\n{explanation}")
|
||||
|
||||
self.log(f"\n\n==Code generated by the agent==\n{code}")
|
||||
if not return_code:
|
||||
self.log("\n\n==Result==")
|
||||
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
|
||||
return evaluate(code, self.cached_tools, state=kwargs.copy())
|
||||
else:
|
||||
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
|
||||
return f"{tool_code}\n{code}"
|
||||
|
||||
def generate_one(self, prompt, stop):
|
||||
# This is the method to implement in your custom agent.
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_many(self, prompts, stop):
|
||||
# Override if you have a way to do batch generation faster than one by one
|
||||
return [self.generate_one(prompt, stop) for prompt in prompts]
|
||||
|
||||
|
||||
class OpenAiAgent(Agent):
|
||||
"""
|
||||
Agent that uses the openai API to generate code.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
|
||||
`"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
model (`str`, *optional*, defaults to `"text-davinci-003"`):
|
||||
The name of the OpenAI model to use.
|
||||
api_key (`str`, *optional*):
|
||||
The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`.
|
||||
chat_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`chat_prompt_template.txt` in this repo in this case.
|
||||
run_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`run_prompt_template.txt` in this repo in this case.
|
||||
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
|
||||
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
|
||||
one of the default tools, that default tool will be overridden.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers import OpenAiAgent
|
||||
|
||||
agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
|
||||
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model="text-davinci-003",
|
||||
api_key=None,
|
||||
chat_prompt_template=None,
|
||||
run_prompt_template=None,
|
||||
additional_tools=None,
|
||||
):
|
||||
if not is_openai_available():
|
||||
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")
|
||||
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("OPENAI_API_KEY", None)
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"You need an openai key to use `OpenAIAgent`. You can get one here: Get one here "
|
||||
"https://openai.com/api/`. If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = "
|
||||
"xxx."
|
||||
)
|
||||
else:
|
||||
openai.api_key = api_key
|
||||
self.model = model
|
||||
super().__init__(
|
||||
chat_prompt_template=chat_prompt_template,
|
||||
run_prompt_template=run_prompt_template,
|
||||
additional_tools=additional_tools,
|
||||
)
|
||||
|
||||
def generate_many(self, prompts, stop):
|
||||
if "gpt" in self.model:
|
||||
return [self._chat_generate(prompt, stop) for prompt in prompts]
|
||||
else:
|
||||
return self._completion_generate(prompts, stop)
|
||||
|
||||
def generate_one(self, prompt, stop):
|
||||
if "gpt" in self.model:
|
||||
return self._chat_generate(prompt, stop)
|
||||
else:
|
||||
return self._completion_generate([prompt], stop)[0]
|
||||
|
||||
def _chat_generate(self, prompt, stop):
|
||||
result = openai.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0,
|
||||
stop=stop,
|
||||
)
|
||||
return result.choices[0].message.content
|
||||
|
||||
def _completion_generate(self, prompts, stop):
|
||||
result = openai.Completion.create(
|
||||
model=self.model,
|
||||
prompt=prompts,
|
||||
temperature=0,
|
||||
stop=stop,
|
||||
max_tokens=200,
|
||||
)
|
||||
return [answer["text"] for answer in result["choices"]]
|
||||
|
||||
|
||||
class AzureOpenAiAgent(Agent):
|
||||
"""
|
||||
Agent that uses Azure OpenAI to generate code. See the [official
|
||||
documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an openAI
|
||||
model on Azure
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
|
||||
`"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
deployment_id (`str`):
|
||||
The name of the deployed Azure openAI model to use.
|
||||
api_key (`str`, *optional*):
|
||||
The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
|
||||
resource_name (`str`, *optional*):
|
||||
The name of your Azure OpenAI Resource. If unset, will look for the environment variable
|
||||
`"AZURE_OPENAI_RESOURCE_NAME"`.
|
||||
api_version (`str`, *optional*, default to `"2022-12-01"`):
|
||||
The API version to use for this agent.
|
||||
is_chat_mode (`bool`, *optional*):
|
||||
Whether you are using a completion model or a chat model (see note above, chat models won't be as
|
||||
efficient). Will default to `gpt` being in the `deployment_id` or not.
|
||||
chat_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`chat_prompt_template.txt` in this repo in this case.
|
||||
run_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`run_prompt_template.txt` in this repo in this case.
|
||||
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
|
||||
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
|
||||
one of the default tools, that default tool will be overridden.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers import AzureOpenAiAgent
|
||||
|
||||
agent = AzureAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
|
||||
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
deployment_id,
|
||||
api_key=None,
|
||||
resource_name=None,
|
||||
api_version="2022-12-01",
|
||||
is_chat_model=None,
|
||||
chat_prompt_template=None,
|
||||
run_prompt_template=None,
|
||||
additional_tools=None,
|
||||
):
|
||||
if not is_openai_available():
|
||||
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")
|
||||
|
||||
self.deployment_id = deployment_id
|
||||
openai.api_type = "azure"
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
|
||||
if api_key is None:
|
||||
raise ValueError(
|
||||
"You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with "
|
||||
"`os.environ['AZURE_OPENAI_API_KEY'] = xxx."
|
||||
)
|
||||
else:
|
||||
openai.api_key = api_key
|
||||
if resource_name is None:
|
||||
resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME", None)
|
||||
if resource_name is None:
|
||||
raise ValueError(
|
||||
"You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with "
|
||||
"`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx."
|
||||
)
|
||||
else:
|
||||
openai.api_base = f"https://{resource_name}.openai.azure.com"
|
||||
openai.api_version = api_version
|
||||
|
||||
if is_chat_model is None:
|
||||
is_chat_model = "gpt" in deployment_id.lower()
|
||||
self.is_chat_model = is_chat_model
|
||||
|
||||
super().__init__(
|
||||
chat_prompt_template=chat_prompt_template,
|
||||
run_prompt_template=run_prompt_template,
|
||||
additional_tools=additional_tools,
|
||||
)
|
||||
|
||||
def generate_many(self, prompts, stop):
|
||||
if self.is_chat_model:
|
||||
return [self._chat_generate(prompt, stop) for prompt in prompts]
|
||||
else:
|
||||
return self._completion_generate(prompts, stop)
|
||||
|
||||
def generate_one(self, prompt, stop):
|
||||
if self.is_chat_model:
|
||||
return self._chat_generate(prompt, stop)
|
||||
else:
|
||||
return self._completion_generate([prompt], stop)[0]
|
||||
|
||||
def _chat_generate(self, prompt, stop):
|
||||
result = openai.ChatCompletion.create(
|
||||
engine=self.deployment_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0,
|
||||
stop=stop,
|
||||
)
|
||||
return result["choices"][0]["message"]["content"]
|
||||
|
||||
def _completion_generate(self, prompts, stop):
|
||||
result = openai.Completion.create(
|
||||
engine=self.deployment_id,
|
||||
prompt=prompts,
|
||||
temperature=0,
|
||||
stop=stop,
|
||||
max_tokens=200,
|
||||
)
|
||||
return [answer["text"] for answer in result["choices"]]
|
||||
|
||||
|
||||
class HfAgent(Agent):
|
||||
"""
|
||||
Agent that uses an inference endpoint to generate code.
|
||||
|
||||
Args:
|
||||
url_endpoint (`str`):
|
||||
The name of the url endpoint to use.
|
||||
token (`str`, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
|
||||
running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
chat_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`chat_prompt_template.txt` in this repo in this case.
|
||||
run_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`run_prompt_template.txt` in this repo in this case.
|
||||
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
|
||||
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
|
||||
one of the default tools, that default tool will be overridden.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers import HfAgent
|
||||
|
||||
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
|
||||
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
|
||||
):
|
||||
self.url_endpoint = url_endpoint
|
||||
if token is None:
|
||||
self.token = f"Bearer {HfFolder().get_token()}"
|
||||
elif token.startswith("Bearer") or token.startswith("Basic"):
|
||||
self.token = token
|
||||
else:
|
||||
self.token = f"Bearer {token}"
|
||||
super().__init__(
|
||||
chat_prompt_template=chat_prompt_template,
|
||||
run_prompt_template=run_prompt_template,
|
||||
additional_tools=additional_tools,
|
||||
)
|
||||
|
||||
def generate_one(self, prompt, stop):
|
||||
headers = {"Authorization": self.token}
|
||||
inputs = {
|
||||
"inputs": prompt,
|
||||
"parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
|
||||
}
|
||||
|
||||
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
|
||||
if response.status_code == 429:
|
||||
logger.info("Getting rate-limited, waiting a tiny bit before trying again.")
|
||||
time.sleep(1)
|
||||
return self._generate_one(prompt)
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Error {response.status_code}: {response.json()}")
|
||||
|
||||
result = response.json()[0]["generated_text"]
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq):
|
||||
return result[: -len(stop_seq)]
|
||||
return result
|
||||
|
||||
|
||||
class LocalAgent(Agent):
|
||||
"""
|
||||
Agent that uses a local model and tokenizer to generate code.
|
||||
|
||||
Args:
|
||||
model ([`PreTrainedModel`]):
|
||||
The model to use for the agent.
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
The tokenizer to use for the agent.
|
||||
chat_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`chat_prompt_template.txt` in this repo in this case.
|
||||
run_prompt_template (`str`, *optional*):
|
||||
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
|
||||
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
|
||||
`run_prompt_template.txt` in this repo in this case.
|
||||
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
|
||||
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
|
||||
one of the default tools, that default tool will be overridden.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent
|
||||
|
||||
checkpoint = "bigcode/starcoder"
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
|
||||
agent = LocalAgent(model, tokenizer)
|
||||
agent.run("Draw me a picture of rivers and lakes.")
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
super().__init__(
|
||||
chat_prompt_template=chat_prompt_template,
|
||||
run_prompt_template=run_prompt_template,
|
||||
additional_tools=additional_tools,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
"""
|
||||
Convenience method to build a `LocalAgent` from a pretrained checkpoint.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import LocalAgent
|
||||
|
||||
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
|
||||
agent.run("Draw me a picture of rivers and lakes.")
|
||||
```
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
return cls(model, tokenizer)
|
||||
|
||||
@property
|
||||
def _model_device(self):
|
||||
if hasattr(self.model, "hf_device_map"):
|
||||
return list(self.model.hf_device_map.values())[0]
|
||||
for param in self.model.parameters():
|
||||
return param.device
|
||||
|
||||
def generate_one(self, prompt, stop):
|
||||
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
|
||||
src_len = encoded_inputs["input_ids"].shape[1]
|
||||
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
|
||||
outputs = self.model.generate(
|
||||
encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
|
||||
)
|
||||
|
||||
result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq):
|
||||
result = result[: -len(stop_seq)]
|
||||
return result
|
||||
|
||||
|
||||
class StopSequenceCriteria(StoppingCriteria):
|
||||
"""
|
||||
This class can be used to stop generation whenever a sequence of tokens is encountered.
|
||||
|
||||
Args:
|
||||
stop_sequences (`str` or `List[str]`):
|
||||
The sequence (or list of sequences) on which to stop execution.
|
||||
tokenizer:
|
||||
The tokenizer used to decode the model outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, stop_sequences, tokenizer):
|
||||
if isinstance(stop_sequences, str):
|
||||
stop_sequences = [stop_sequences]
|
||||
self.stop_sequences = stop_sequences
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(self, input_ids, scores, **kwargs) -> bool:
|
||||
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
|
||||
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..models.auto import AutoModelForVision2Seq
|
||||
from ..utils import requires_backends
|
||||
from .base import PipelineTool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageCaptioningTool(PipelineTool):
|
||||
default_checkpoint = "Salesforce/blip-image-captioning-base"
|
||||
description = (
|
||||
"This is a tool that generates a description of an image. It takes an input named `image` which should be the "
|
||||
"image to caption, and returns a text that contains the description in English."
|
||||
)
|
||||
name = "image_captioner"
|
||||
model_class = AutoModelForVision2Seq
|
||||
|
||||
inputs = ["image"]
|
||||
outputs = ["text"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def encode(self, image: "Image"):
|
||||
return self.pre_processor(images=image, return_tensors="pt")
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.model.generate(**inputs)
|
||||
|
||||
def decode(self, outputs):
|
||||
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
|
@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..models.clipseg import CLIPSegForImageSegmentation
|
||||
from ..utils import is_vision_available, requires_backends
|
||||
from .base import PipelineTool
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageSegmentationTool(PipelineTool):
|
||||
description = (
|
||||
"This is a tool that creates a segmentation mask of an image according to a label. It cannot create an image. "
|
||||
"It takes two arguments named `image` which should be the original image, and `label` which should be a text "
|
||||
"describing the elements what should be identified in the segmentation mask. The tool returns the mask."
|
||||
)
|
||||
default_checkpoint = "CIDAS/clipseg-rd64-refined"
|
||||
name = "image_segmenter"
|
||||
model_class = CLIPSegForImageSegmentation
|
||||
|
||||
inputs = ["image", "text"]
|
||||
outputs = ["image"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def encode(self, image: "Image", label: str):
|
||||
return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")
|
||||
|
||||
def forward(self, inputs):
|
||||
with torch.no_grad():
|
||||
logits = self.model(**inputs).logits
|
||||
return logits
|
||||
|
||||
def decode(self, outputs):
|
||||
array = outputs.cpu().detach().numpy()
|
||||
array[array <= 0] = 0
|
||||
array[array > 0] = 1
|
||||
return Image.fromarray((array * 255).astype(np.uint8))
|
@ -1,48 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
|
||||
from ..utils import cached_file
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
CHAT_MESSAGE_PROMPT = """
|
||||
Human: <<task>>
|
||||
|
||||
Assistant: """
|
||||
|
||||
|
||||
DEFAULT_PROMPTS_REPO = "huggingface-tools/default-prompts"
|
||||
PROMPT_FILES = {"chat": "chat_prompt_template.txt", "run": "run_prompt_template.txt"}
|
||||
|
||||
|
||||
def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
|
||||
"""
|
||||
Downloads and caches the prompt from a repo and returns it contents (if necessary)
|
||||
"""
|
||||
if prompt_or_repo_id is None:
|
||||
prompt_or_repo_id = DEFAULT_PROMPTS_REPO
|
||||
|
||||
# prompt is considered a repo ID when it does not contain any kind of space
|
||||
if re.search("\\s", prompt_or_repo_id) is not None:
|
||||
return prompt_or_repo_id
|
||||
|
||||
prompt_file = cached_file(
|
||||
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
|
||||
)
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
@ -1,253 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import ast
|
||||
import difflib
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
|
||||
class InterpretorError(ValueError):
|
||||
"""
|
||||
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
|
||||
operations.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def evaluate(code: str, tools: Dict[str, Callable], state=None, chat_mode=False):
|
||||
"""
|
||||
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
||||
of functions.
|
||||
|
||||
This function will recurse through the nodes of the tree provided.
|
||||
|
||||
Args:
|
||||
code (`str`):
|
||||
The code to evaluate.
|
||||
tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||
`InterpretorError`.
|
||||
state (`Dict[str, Any]`):
|
||||
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
|
||||
updated by this function to contain all variables as they are evaluated.
|
||||
chat_mode (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the function is called from `Agent.chat`.
|
||||
"""
|
||||
try:
|
||||
expression = ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
print("The code generated by the agent is not valid.\n", e)
|
||||
return
|
||||
if state is None:
|
||||
state = {}
|
||||
result = None
|
||||
for idx, node in enumerate(expression.body):
|
||||
try:
|
||||
line_result = evaluate_ast(node, state, tools)
|
||||
except InterpretorError as e:
|
||||
msg = f"Evaluation of the code stopped at line {idx} before the end because of the following error"
|
||||
if chat_mode:
|
||||
msg += (
|
||||
f". Copy paste the following error message and send it back to the agent:\nI get an error: '{e}'"
|
||||
)
|
||||
else:
|
||||
msg += f":\n{e}"
|
||||
print(msg)
|
||||
break
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
|
||||
"""
|
||||
Evaluate an absract syntax tree using the content of the variables stored in a state and only evaluating a given
|
||||
set of functions.
|
||||
|
||||
This function will recurse trough the nodes of the tree provided.
|
||||
|
||||
Args:
|
||||
expression (`ast.AST`):
|
||||
The code to evaluate, as an abastract syntax tree.
|
||||
state (`Dict[str, Any]`):
|
||||
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
|
||||
encounters assignements.
|
||||
tools (`Dict[str, Callable]`):
|
||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||
`InterpretorError`.
|
||||
"""
|
||||
if isinstance(expression, ast.Assign):
|
||||
# Assignement -> we evaluate the assignement which should update the state
|
||||
# We return the variable assigned as it may be used to determine the final result.
|
||||
return evaluate_assign(expression, state, tools)
|
||||
elif isinstance(expression, ast.Call):
|
||||
# Function call -> we return the value of the function call
|
||||
return evaluate_call(expression, state, tools)
|
||||
elif isinstance(expression, ast.Constant):
|
||||
# Constant -> just return the value
|
||||
return expression.value
|
||||
elif isinstance(expression, ast.Dict):
|
||||
# Dict -> evaluate all keys and values
|
||||
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
|
||||
values = [evaluate_ast(v, state, tools) for v in expression.values]
|
||||
return dict(zip(keys, values))
|
||||
elif isinstance(expression, ast.Expr):
|
||||
# Expression -> evaluate the content
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
elif isinstance(expression, ast.For):
|
||||
# For loop -> execute the loop
|
||||
return evaluate_for(expression, state, tools)
|
||||
elif isinstance(expression, ast.FormattedValue):
|
||||
# Formatted value (part of f-string) -> evaluate the content and return
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
elif isinstance(expression, ast.If):
|
||||
# If -> execute the right branch
|
||||
return evaluate_if(expression, state, tools)
|
||||
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
|
||||
return evaluate_ast(expression.value, state, tools)
|
||||
elif isinstance(expression, ast.JoinedStr):
|
||||
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
|
||||
elif isinstance(expression, ast.List):
|
||||
# List -> evaluate all elements
|
||||
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
|
||||
elif isinstance(expression, ast.Name):
|
||||
# Name -> pick up the value in the state
|
||||
return evaluate_name(expression, state, tools)
|
||||
elif isinstance(expression, ast.Subscript):
|
||||
# Subscript -> return the value of the indexing
|
||||
return evaluate_subscript(expression, state, tools)
|
||||
else:
|
||||
# For now we refuse anything else. Let's add things as we need them.
|
||||
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
|
||||
|
||||
|
||||
def evaluate_assign(assign, state, tools):
|
||||
var_names = assign.targets
|
||||
result = evaluate_ast(assign.value, state, tools)
|
||||
|
||||
if len(var_names) == 1:
|
||||
state[var_names[0].id] = result
|
||||
else:
|
||||
if len(result) != len(var_names):
|
||||
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
|
||||
for var_name, r in zip(var_names, result):
|
||||
state[var_name.id] = r
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_call(call, state, tools):
|
||||
if not isinstance(call.func, ast.Name):
|
||||
raise InterpretorError(
|
||||
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
|
||||
f"type {type(call.func)}."
|
||||
)
|
||||
func_name = call.func.id
|
||||
if func_name not in tools:
|
||||
raise InterpretorError(
|
||||
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
|
||||
)
|
||||
|
||||
func = tools[func_name]
|
||||
# Todo deal with args
|
||||
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def evaluate_subscript(subscript, state, tools):
|
||||
index = evaluate_ast(subscript.slice, state, tools)
|
||||
value = evaluate_ast(subscript.value, state, tools)
|
||||
if isinstance(value, (list, tuple)):
|
||||
return value[int(index)]
|
||||
if index in value:
|
||||
return value[index]
|
||||
if isinstance(index, str) and isinstance(value, Mapping):
|
||||
close_matches = difflib.get_close_matches(index, list(value.keys()))
|
||||
if len(close_matches) > 0:
|
||||
return value[close_matches[0]]
|
||||
|
||||
raise InterpretorError(f"Could not index {value} with '{index}'.")
|
||||
|
||||
|
||||
def evaluate_name(name, state, tools):
|
||||
if name.id in state:
|
||||
return state[name.id]
|
||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||
if len(close_matches) > 0:
|
||||
return state[close_matches[0]]
|
||||
raise InterpretorError(f"The variable `{name.id}` is not defined.")
|
||||
|
||||
|
||||
def evaluate_condition(condition, state, tools):
|
||||
if len(condition.ops) > 1:
|
||||
raise InterpretorError("Cannot evaluate conditions with multiple operators")
|
||||
|
||||
left = evaluate_ast(condition.left, state, tools)
|
||||
comparator = condition.ops[0]
|
||||
right = evaluate_ast(condition.comparators[0], state, tools)
|
||||
|
||||
if isinstance(comparator, ast.Eq):
|
||||
return left == right
|
||||
elif isinstance(comparator, ast.NotEq):
|
||||
return left != right
|
||||
elif isinstance(comparator, ast.Lt):
|
||||
return left < right
|
||||
elif isinstance(comparator, ast.LtE):
|
||||
return left <= right
|
||||
elif isinstance(comparator, ast.Gt):
|
||||
return left > right
|
||||
elif isinstance(comparator, ast.GtE):
|
||||
return left >= right
|
||||
elif isinstance(comparator, ast.Is):
|
||||
return left is right
|
||||
elif isinstance(comparator, ast.IsNot):
|
||||
return left is not right
|
||||
elif isinstance(comparator, ast.In):
|
||||
return left in right
|
||||
elif isinstance(comparator, ast.NotIn):
|
||||
return left not in right
|
||||
else:
|
||||
raise InterpretorError(f"Operator not supported: {comparator}")
|
||||
|
||||
|
||||
def evaluate_if(if_statement, state, tools):
|
||||
result = None
|
||||
if evaluate_condition(if_statement.test, state, tools):
|
||||
for line in if_statement.body:
|
||||
line_result = evaluate_ast(line, state, tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
else:
|
||||
for line in if_statement.orelse:
|
||||
line_result = evaluate_ast(line, state, tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
return result
|
||||
|
||||
|
||||
def evaluate_for(for_loop, state, tools):
|
||||
result = None
|
||||
iterator = evaluate_ast(for_loop.iter, state, tools)
|
||||
for counter in iterator:
|
||||
state[for_loop.target.id] = counter
|
||||
for expression in for_loop.body:
|
||||
line_result = evaluate_ast(expression, state, tools)
|
||||
if line_result is not None:
|
||||
result = line_result
|
||||
return result
|
@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from .base import PipelineTool
|
||||
|
||||
|
||||
class TextClassificationTool(PipelineTool):
|
||||
"""
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers.tools import TextClassificationTool
|
||||
|
||||
classifier = TextClassificationTool()
|
||||
classifier("This is a super nice API!", labels=["positive", "negative"])
|
||||
```
|
||||
"""
|
||||
|
||||
default_checkpoint = "facebook/bart-large-mnli"
|
||||
description = (
|
||||
"This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which "
|
||||
"should be the text to classify, and `labels`, which should be the list of labels to use for classification. "
|
||||
"It returns the most likely label in the list of provided `labels` for the input text."
|
||||
)
|
||||
name = "text_classifier"
|
||||
pre_processor_class = AutoTokenizer
|
||||
model_class = AutoModelForSequenceClassification
|
||||
|
||||
inputs = ["text", ["text"]]
|
||||
outputs = ["text"]
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
config = self.model.config
|
||||
self.entailment_id = -1
|
||||
for idx, label in config.id2label.items():
|
||||
if label.lower().startswith("entail"):
|
||||
self.entailment_id = int(idx)
|
||||
if self.entailment_id == -1:
|
||||
raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")
|
||||
|
||||
def encode(self, text, labels):
|
||||
self._labels = labels
|
||||
return self.pre_processor(
|
||||
[text] * len(labels),
|
||||
[f"This example is {label}" for label in labels],
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
def decode(self, outputs):
|
||||
logits = outputs.logits
|
||||
label_id = torch.argmax(logits[:, 2]).item()
|
||||
return self._labels[label_id]
|
@ -1,52 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from .base import PipelineTool
|
||||
|
||||
|
||||
QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.
|
||||
|
||||
Can you answer this question about the text: '{question}'"""
|
||||
|
||||
|
||||
class TextQuestionAnsweringTool(PipelineTool):
|
||||
default_checkpoint = "google/flan-t5-base"
|
||||
description = (
|
||||
"This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
|
||||
"text where to find the answer, and `question`, which is the question, and returns the answer to the question."
|
||||
)
|
||||
name = "text_qa"
|
||||
pre_processor_class = AutoTokenizer
|
||||
model_class = AutoModelForSeq2SeqLM
|
||||
|
||||
inputs = ["text", "text"]
|
||||
outputs = ["text"]
|
||||
|
||||
def encode(self, text: str, question: str):
|
||||
prompt = QA_PROMPT.format(text=text, question=question)
|
||||
return self.pre_processor(prompt, return_tensors="pt")
|
||||
|
||||
def forward(self, inputs):
|
||||
output_ids = self.model.generate(**inputs)
|
||||
|
||||
in_b, _ = inputs["input_ids"].shape
|
||||
out_b = output_ids.shape[0]
|
||||
|
||||
return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]
|
||||
|
||||
def decode(self, outputs):
|
||||
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
@ -1,52 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from .base import PipelineTool
|
||||
|
||||
|
||||
class TextSummarizationTool(PipelineTool):
|
||||
"""
|
||||
Example:
|
||||
|
||||
```py
|
||||
from transformers.tools import TextSummarizationTool
|
||||
|
||||
summarizer = TextSummarizationTool()
|
||||
summarizer(long_text)
|
||||
```
|
||||
"""
|
||||
|
||||
default_checkpoint = "philschmid/bart-large-cnn-samsum"
|
||||
description = (
|
||||
"This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
|
||||
"and returns a summary of the text."
|
||||
)
|
||||
name = "summarizer"
|
||||
pre_processor_class = AutoTokenizer
|
||||
model_class = AutoModelForSeq2SeqLM
|
||||
|
||||
inputs = ["text"]
|
||||
outputs = ["text"]
|
||||
|
||||
def encode(self, text):
|
||||
return self.pre_processor(text, return_tensors="pt", truncation=True)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.model.generate(**inputs)[0]
|
||||
|
||||
def decode(self, outputs):
|
||||
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
@ -142,6 +142,7 @@ _phonemizer_available = _is_package_available("phonemizer")
|
||||
_psutil_available = _is_package_available("psutil")
|
||||
_py3nvml_available = _is_package_available("py3nvml")
|
||||
_pyctcdecode_available = _is_package_available("pyctcdecode")
|
||||
_pygments_available = _is_package_available("pygments")
|
||||
_pytesseract_available = _is_package_available("pytesseract")
|
||||
_pytest_available = _is_package_available("pytest")
|
||||
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
|
||||
@ -297,6 +298,10 @@ def is_hqq_available():
|
||||
return _hqq_available
|
||||
|
||||
|
||||
def is_pygments_available():
|
||||
return _pygments_available
|
||||
|
||||
|
||||
def get_torch_version():
|
||||
return _torch_version
|
||||
|
||||
@ -1294,6 +1299,11 @@ shi-labs.com/natten . You can also install it with pip (may take longer to build
|
||||
`pip install natten`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
NUMEXPR_IMPORT_ERROR = """
|
||||
{0} requires the numexpr library but it was not found in your environment. You can install it by referring to:
|
||||
https://numexpr.readthedocs.io/en/latest/index.html.
|
||||
"""
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
NLTK_IMPORT_ERROR = """
|
||||
|
@ -18,8 +18,8 @@ import unittest
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
|
||||
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
|
||||
from transformers.tools.agent_types import AgentAudio, AgentImage, AgentText
|
||||
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
|
||||
|
||||
|
161
tests/agents/test_agents.py
Normal file
161
tests/agents/test_agents.py
Normal file
@ -0,0 +1,161 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers.agents.agent_types import AgentText
|
||||
from transformers.agents.agents import AgentMaxIterationsError, CodeAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
|
||||
from transformers.agents.default_tools import PythonInterpreterTool
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
def get_new_path(suffix="") -> str:
|
||||
directory = tempfile.mkdtemp()
|
||||
return os.path.join(directory, str(uuid.uuid4()) + suffix)
|
||||
|
||||
|
||||
def fake_react_json_llm(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Action:
|
||||
{
|
||||
"action": "python_interpreter",
|
||||
"action_input": {"code": "2*3.6452"}
|
||||
}
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "7.2904"}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def fake_react_code_llm(messages, stop_sequences=None) -> str:
|
||||
prompt = str(messages)
|
||||
if "special_marker" not in prompt:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = 2**3.6452
|
||||
print(result)
|
||||
```<end_code>
|
||||
"""
|
||||
else: # We're at step 2
|
||||
return """
|
||||
Thought: I can now answer the initial question
|
||||
Code:
|
||||
```py
|
||||
final_answer(7.2904)
|
||||
```<end_code>
|
||||
"""
|
||||
|
||||
|
||||
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
|
||||
return """
|
||||
Thought: I should multiply 2 by 3.6452. special_marker
|
||||
Code:
|
||||
```py
|
||||
result = python_interpreter(code="2*3.6452")
|
||||
print(result)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class AgentTests(unittest.TestCase):
|
||||
def test_fake_code_agent(self):
|
||||
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, str)
|
||||
assert output == "7.2904"
|
||||
|
||||
def test_fake_react_json_agent(self):
|
||||
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, str)
|
||||
assert output == "7.2904"
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert agent.logs[1]["observation"] == "7.2904"
|
||||
assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
|
||||
assert (
|
||||
agent.logs[2]["llm_output"]
|
||||
== """
|
||||
Thought: I can now answer the initial question
|
||||
Action:
|
||||
{
|
||||
"action": "final_answer",
|
||||
"action_input": {"answer": "7.2904"}
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_fake_react_code_agent(self):
|
||||
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
|
||||
output = agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert isinstance(output, AgentText)
|
||||
assert output == "7.2904"
|
||||
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
|
||||
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
|
||||
assert agent.logs[2]["tool_call"] == {
|
||||
"tool_arguments": "final_answer(7.2904)",
|
||||
"tool_name": "code interpreter",
|
||||
}
|
||||
|
||||
def test_setup_agent_with_empty_toolbox(self):
|
||||
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
|
||||
|
||||
def test_react_fails_max_iterations(self):
|
||||
agent = ReactCodeAgent(
|
||||
tools=[PythonInterpreterTool()],
|
||||
llm_engine=fake_code_llm_oneshot, # use this callable because it never ends
|
||||
max_iterations=5,
|
||||
)
|
||||
agent.run("What is 2 multiplied by 3.6452?")
|
||||
assert len(agent.logs) == 7
|
||||
assert type(agent.logs[-1]["error"]) == AgentMaxIterationsError
|
||||
|
||||
@require_torch
|
||||
def test_init_agent_with_different_toolsets(self):
|
||||
toolset_1 = []
|
||||
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 1 # contains only final_answer tool
|
||||
|
||||
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
|
||||
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 2 # added final_answer tool
|
||||
|
||||
toolset_3 = Toolbox(toolset_2)
|
||||
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
|
||||
assert len(agent.toolbox.tools) == 2 # added final_answer tool
|
||||
|
||||
# check that add_base_tools will not interfere with existing tools
|
||||
with pytest.raises(KeyError) as e:
|
||||
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
|
||||
assert "python_interpreter already exists in the toolbox" in str(e)
|
||||
|
||||
# check that python_interpreter base tool does not get added to code agents
|
||||
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
|
||||
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
|
@ -26,7 +26,6 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("document-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("document-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
@ -35,22 +34,8 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
result = self.tool(document, "When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
result = self.remote_tool(document, "When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
self.tool(document=document, question="When is the coffee break?")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
document = dataset[0]["image"]
|
||||
|
||||
result = self.remote_tool(document=document, question="When is the coffee break?")
|
||||
self.assertEqual(result, "11-14 to 11:39 a.m.")
|
71
tests/agents/test_final_answer.py
Normal file
71
tests/agents/test_final_answer.py
Normal file
@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from transformers import is_torch_available, load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.inputs = {"answer": "Final answer"}
|
||||
self.tool = load_tool("final_answer")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Final answer")
|
||||
self.assertEqual(result, "Final answer")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(answer=self.inputs["answer"])
|
||||
self.assertEqual(result, "Final answer")
|
||||
|
||||
def create_inputs(self):
|
||||
inputs_text = {"answer": "Text input"}
|
||||
inputs_image = {
|
||||
"answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
|
||||
(512, 512)
|
||||
)
|
||||
}
|
||||
inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
|
||||
return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio}
|
||||
|
||||
@require_torch
|
||||
def test_agent_type_output(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
@require_torch
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = self.create_inputs()
|
||||
for input_type, input in inputs.items():
|
||||
output = self.tool(**input)
|
||||
agent_type = AGENT_TYPE_MAPPING[input_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
@ -30,24 +30,13 @@ class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image, "How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image, "How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image, question="How many cats are sleeping on the couch?")
|
||||
self.assertEqual(result, "2")
|
355
tests/agents/test_python_interpreter.py
Normal file
355
tests/agents/test_python_interpreter.py
Normal file
@ -0,0 +1,355 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
from transformers.agents.default_tools import BASE_PYTHON_TOOLS
|
||||
from transformers.agents.python_interpreter import InterpretorError, evaluate_python_code
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
def add_two(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("python_interpreter")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(code="(2 / 2) * 4")
|
||||
self.assertEqual(result, "4.0")
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["2 * 2"]
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = ["2 * 2"]
|
||||
_inputs = []
|
||||
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(*inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
|
||||
class PythonInterpreterTester(unittest.TestCase):
|
||||
def test_evaluate_assign(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||
|
||||
code = "x = y"
|
||||
state = {"y": 5}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
# Should not work without the tool
|
||||
with pytest.raises(InterpretorError) as e:
|
||||
evaluate_python_code(code, {}, state=state)
|
||||
assert "tried to execute add_two" in str(e.value)
|
||||
|
||||
def test_evaluate_constant(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_dict(self):
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_expression(self):
|
||||
code = "x = 3\ny = 5"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_f_string(self):
|
||||
code = "text = f'This is x: {x}.'"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == "This is x: 3."
|
||||
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3.", "print_outputs": ""})
|
||||
|
||||
def test_evaluate_if(self):
|
||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 3, "y": 2, "print_outputs": ""})
|
||||
|
||||
state = {"x": 8}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 8, "y": 5, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_list(self):
|
||||
code = "test_list = [x, add_two(x)]"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
self.assertListEqual(result, [3, 5])
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||
|
||||
def test_evaluate_name(self):
|
||||
code = "y = x"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "y": 3, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_subscript(self):
|
||||
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5], "print_outputs": ""})
|
||||
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
||||
state = {"x": 3}
|
||||
result = evaluate_python_code(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_for(self):
|
||||
code = "x = 0\nfor i in range(3):\n x = i"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, {"range": range}, state=state)
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 2, "i": 2, "print_outputs": ""})
|
||||
|
||||
def test_evaluate_binop(self):
|
||||
code = "y + x"
|
||||
state = {"x": 3, "y": 6}
|
||||
result = evaluate_python_code(code, {}, state=state)
|
||||
assert result == 9
|
||||
self.assertDictEqual(state, {"x": 3, "y": 6, "print_outputs": ""})
|
||||
|
||||
def test_recursive_function(self):
|
||||
code = """
|
||||
def recur_fibo(n):
|
||||
if n <= 1:
|
||||
return n
|
||||
else:
|
||||
return(recur_fibo(n-1) + recur_fibo(n-2))
|
||||
recur_fibo(6)"""
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 8
|
||||
|
||||
def test_evaluate_string_methods(self):
|
||||
code = "'hello'.replace('h', 'o').split('e')"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == ["o", "llo"]
|
||||
|
||||
def test_evaluate_slicing(self):
|
||||
code = "'hello'[1:3][::-1]"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == "le"
|
||||
|
||||
def test_access_attributes(self):
|
||||
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == int
|
||||
|
||||
def test_list_comprehension(self):
|
||||
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == "t-h-e-s-e-a-g-u-l-l"
|
||||
|
||||
def test_string_indexing(self):
|
||||
code = """text_block = [
|
||||
"THESE",
|
||||
"AGULL"
|
||||
]
|
||||
sentence = ""
|
||||
for block in text_block:
|
||||
for col in range(len(text_block[0])):
|
||||
sentence += block[col]
|
||||
"""
|
||||
result = evaluate_python_code(code, {"len": len, "range": range}, state={})
|
||||
assert result == "THESEAGULL"
|
||||
|
||||
def test_tuples(self):
|
||||
code = "x = (1, 2, 3)\nx[1]"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 2
|
||||
|
||||
def test_listcomp(self):
|
||||
code = "x = [i for i in range(3)]"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == [0, 1, 2]
|
||||
|
||||
def test_break_continue(self):
|
||||
code = "for i in range(10):\n if i == 5:\n break\ni"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == 5
|
||||
|
||||
code = "for i in range(10):\n if i == 5:\n continue\ni"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == 9
|
||||
|
||||
def test_call_int(self):
|
||||
code = "import math\nstr(math.ceil(149))"
|
||||
result = evaluate_python_code(code, {"str": lambda x: str(x)}, state={})
|
||||
assert result == "149"
|
||||
|
||||
def test_lambda(self):
|
||||
code = "f = lambda x: x + 2\nf(3)"
|
||||
result = evaluate_python_code(code, {}, state={})
|
||||
assert result == 5
|
||||
|
||||
def test_dictcomp(self):
|
||||
code = "x = {i: i**2 for i in range(3)}"
|
||||
result = evaluate_python_code(code, {"range": range}, state={})
|
||||
assert result == {0: 0, 1: 1, 2: 4}
|
||||
|
||||
def test_tuple_assignment(self):
|
||||
code = "a, b = 0, 1\nb"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 1
|
||||
|
||||
def test_while(self):
|
||||
code = "i = 0\nwhile i < 3:\n i += 1\ni"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 3
|
||||
|
||||
# test infinite loop
|
||||
code = "i = 0\nwhile i < 3:\n i -= 1\ni"
|
||||
with pytest.raises(InterpretorError) as e:
|
||||
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert "iterations in While loop exceeded" in str(e)
|
||||
|
||||
def test_generator(self):
|
||||
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == [1, 4, 9, 16, 25]
|
||||
|
||||
def test_boolops(self):
|
||||
code = """if (not (a > b and a > c)) or d > e:
|
||||
best_city = "Brooklyn"
|
||||
else:
|
||||
best_city = "Manhattan"
|
||||
best_city
|
||||
"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||
assert result == "Brooklyn"
|
||||
|
||||
code = """if d > e and a < b:
|
||||
best_city = "Brooklyn"
|
||||
elif d < e and a < b:
|
||||
best_city = "Sacramento"
|
||||
else:
|
||||
best_city = "Manhattan"
|
||||
best_city
|
||||
"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
|
||||
assert result == "Sacramento"
|
||||
|
||||
def test_if_conditions(self):
|
||||
code = """char='a'
|
||||
if char.isalpha():
|
||||
print('2')"""
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "2"
|
||||
|
||||
def test_imports(self):
|
||||
code = "import math\nmath.sqrt(4)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 2.0
|
||||
|
||||
code = "from random import choice, seed\nseed(12)\nchoice(['win', 'lose', 'draw'])"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "lose"
|
||||
|
||||
code = "import time\ntime.sleep(0.1)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result is None
|
||||
|
||||
code = "from queue import Queue\nq = Queue()\nq.put(1)\nq.get()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 1
|
||||
|
||||
code = "import itertools\nlist(itertools.islice(range(10), 3))"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == [0, 1, 2]
|
||||
|
||||
code = "import re\nre.search('a', 'abc').group()"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "a"
|
||||
|
||||
code = "import stat\nstat.S_ISREG(0o100644)"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result
|
||||
|
||||
code = "import statistics\nstatistics.mean([1, 2, 3, 4, 4])"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == 2.8
|
||||
|
||||
code = "import unicodedata\nunicodedata.name('A')"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result == "LATIN CAPITAL LETTER A"
|
||||
|
||||
def test_multiple_comparators(self):
|
||||
code = "0x30A0 <= ord('a') <= 0x30FF"
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||
assert result
|
||||
|
||||
def test_print_output(self):
|
||||
code = "print('Hello world!')\nprint('Ok no one cares')"
|
||||
state = {}
|
||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||
assert result == "Ok no one cares"
|
||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
@ -15,24 +15,22 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available, load_tool
|
||||
import numpy as np
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("speech-to-text")
|
||||
self.tool.setup()
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(torch.ones(3000))
|
||||
self.assertEqual(result, " you")
|
||||
result = self.tool(np.ones(3000))
|
||||
self.assertEqual(result, " Thank you.")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(audio=torch.ones(3000))
|
||||
self.assertEqual(result, " you")
|
||||
result = self.tool(audio=np.ones(3000))
|
||||
self.assertEqual(result, " Thank you.")
|
@ -38,21 +38,13 @@ class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
resulting_tensor = result.to_raw()
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
resulting_tensor[:3],
|
||||
torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]),
|
||||
)
|
||||
)
|
||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
# SpeechT5 isn't deterministic
|
||||
torch.manual_seed(0)
|
||||
result = self.tool("hey")
|
||||
resulting_tensor = result.to_raw()
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
resulting_tensor[:3],
|
||||
torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]),
|
||||
)
|
||||
)
|
||||
self.assertTrue(len(resulting_tensor.detach().shape) == 1)
|
||||
self.assertTrue(resulting_tensor.detach().shape[0] > 1000)
|
107
tests/agents/test_tools_common.py
Normal file
107
tests/agents/test_tools_common.py
Normal file
@ -0,0 +1,107 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||
from transformers.testing_utils import get_tests_dir, is_agent_test
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
AUTHORIZED_TYPES = ["text", "audio", "image", "any"]
|
||||
|
||||
|
||||
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
|
||||
inputs = {}
|
||||
|
||||
for input_name, input_desc in tool_inputs.items():
|
||||
input_type = input_desc["type"]
|
||||
|
||||
if input_type == "text":
|
||||
inputs[input_name] = "Text input"
|
||||
elif input_type == "image":
|
||||
inputs[input_name] = Image.open(
|
||||
Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
|
||||
).resize((512, 512))
|
||||
elif input_type == "audio":
|
||||
inputs[input_name] = np.ones(3000)
|
||||
else:
|
||||
raise ValueError(f"Invalid type requested: {input_type}")
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def output_type(output):
|
||||
if isinstance(output, (str, AgentText)):
|
||||
return "text"
|
||||
elif isinstance(output, (Image.Image, AgentImage)):
|
||||
return "image"
|
||||
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||
return "audio"
|
||||
else:
|
||||
raise ValueError(f"Invalid output: {output}")
|
||||
|
||||
|
||||
@is_agent_test
|
||||
class ToolTesterMixin:
|
||||
def test_inputs_output(self):
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "output_type"))
|
||||
|
||||
inputs = self.tool.inputs
|
||||
self.assertTrue(isinstance(inputs, dict))
|
||||
|
||||
for _, input_spec in inputs.items():
|
||||
self.assertTrue("type" in input_spec)
|
||||
self.assertTrue("description" in input_spec)
|
||||
self.assertTrue(input_spec["type"] in AUTHORIZED_TYPES)
|
||||
self.assertTrue(isinstance(input_spec["description"], str))
|
||||
|
||||
output_type = self.tool.output_type
|
||||
self.assertTrue(output_type in AUTHORIZED_TYPES)
|
||||
|
||||
def test_common_attributes(self):
|
||||
self.assertTrue(hasattr(self.tool, "description"))
|
||||
self.assertTrue(hasattr(self.tool, "name"))
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "output_type"))
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
output = self.tool(**inputs)
|
||||
agent_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
_inputs = []
|
||||
for _input, expected_input in zip(inputs, self.tool.inputs.values()):
|
||||
input_type = expected_input["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(**inputs)
|
||||
self.assertTrue(isinstance(output, output_type))
|
68
tests/agents/test_translation.py
Normal file
68
tests/agents/test_translation.py
Normal file
@ -0,0 +1,68 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.agents.agent_types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools_common import ToolTesterMixin, output_type
|
||||
|
||||
|
||||
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("translation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("translation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_call(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
|
||||
self.assertEqual(output_type(output), self.tool.output_type)
|
||||
|
||||
def test_agent_type_output(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
output = self.tool(*inputs)
|
||||
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
example_inputs = {
|
||||
"text": "Hey, what's up?",
|
||||
"src_lang": "English",
|
||||
"tgt_lang": "Spanish",
|
||||
}
|
||||
|
||||
_inputs = []
|
||||
for input_name in example_inputs.keys():
|
||||
example_input = example_inputs[input_name]
|
||||
input_description = self.tool.inputs[input_name]
|
||||
input_type = input_description["type"]
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](example_input))
|
||||
|
||||
# Should not raise an error
|
||||
output = self.tool(**example_inputs)
|
||||
output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
|
||||
self.assertTrue(isinstance(output, output_type))
|
@ -1,53 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-captioning")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-captioning", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image)
|
||||
self.assertEqual(result, "two cats sleeping on a couch")
|
@ -1,53 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_vision_available, load_tool
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("image-segmentation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("image-segmentation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image, "cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image, "cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.tool(image=image, label="cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
|
||||
result = self.remote_tool(image=image, label="cat")
|
||||
self.assertTrue(isinstance(result, Image.Image))
|
@ -1,131 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import CaptureStdout
|
||||
from transformers.tools.python_interpreter import evaluate
|
||||
|
||||
|
||||
# Fake function we will use as tool
|
||||
def add_two(x):
|
||||
return x + 2
|
||||
|
||||
|
||||
class PythonInterpreterTester(unittest.TestCase):
|
||||
def test_evaluate_assign(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3})
|
||||
|
||||
code = "x = y"
|
||||
state = {"y": 5}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 5, "y": 5})
|
||||
|
||||
def test_evaluate_call(self):
|
||||
code = "y = add_two(x)"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5})
|
||||
|
||||
# Won't work without the tool
|
||||
with CaptureStdout() as out:
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result is None
|
||||
assert "tried to execute add_two" in out.out
|
||||
|
||||
def test_evaluate_constant(self):
|
||||
code = "x = 3"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3})
|
||||
|
||||
def test_evaluate_dict(self):
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
self.assertDictEqual(result, {"x": 3, "y": 5})
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
|
||||
|
||||
def test_evaluate_expression(self):
|
||||
code = "x = 3\ny = 5"
|
||||
state = {}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "y": 5})
|
||||
|
||||
def test_evaluate_f_string(self):
|
||||
code = "text = f'This is x: {x}.'"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == "This is x: 3."
|
||||
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3."})
|
||||
|
||||
def test_evaluate_if(self):
|
||||
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 3, "y": 2})
|
||||
|
||||
state = {"x": 8}
|
||||
result = evaluate(code, {}, state=state)
|
||||
# evaluate returns the value of the last assignment.
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 8, "y": 5})
|
||||
|
||||
def test_evaluate_list(self):
|
||||
code = "test_list = [x, add_two(x)]"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
self.assertListEqual(result, [3, 5])
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
|
||||
|
||||
def test_evaluate_name(self):
|
||||
code = "y = x"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {}, state=state)
|
||||
assert result == 3
|
||||
self.assertDictEqual(state, {"x": 3, "y": 3})
|
||||
|
||||
def test_evaluate_subscript(self):
|
||||
code = "test_list = [x, add_two(x)]\ntest_list[1]"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
|
||||
|
||||
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
|
||||
state = {"x": 3}
|
||||
result = evaluate(code, {"add_two": add_two}, state=state)
|
||||
assert result == 5
|
||||
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
|
||||
|
||||
def test_evaluate_for(self):
|
||||
code = "x = 0\nfor i in range(3):\n x = i"
|
||||
state = {}
|
||||
result = evaluate(code, {"range": range}, state=state)
|
||||
assert result == 2
|
||||
self.assertDictEqual(state, {"x": 2, "i": 2})
|
@ -1,43 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-classification")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("text-classification", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("That's quite cool", ["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool("That's quite cool", ["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="That's quite cool", labels=["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text="That's quite cool", labels=["positive", "negative"])
|
||||
self.assertEqual(result, "positive")
|
@ -1,52 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
TEXT = """
|
||||
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
|
||||
|
||||
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
|
||||
|
||||
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
|
||||
"""
|
||||
|
||||
|
||||
class TextQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("text-question-answering")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("text-question-answering", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool(TEXT, "What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text=TEXT, question="What did Hugging Face do in April 2021?")
|
||||
self.assertEqual(result, "launched the BigScience Research Workshop")
|
@ -1,64 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
|
||||
from .test_tools_common import ToolTesterMixin
|
||||
|
||||
|
||||
TEXT = """
|
||||
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
|
||||
|
||||
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
|
||||
|
||||
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
|
||||
"""
|
||||
|
||||
|
||||
class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("summarization")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("summarization", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool(TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool(TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text=TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text=TEXT)
|
||||
self.assertEqual(
|
||||
result,
|
||||
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
|
||||
)
|
@ -1,133 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import get_tests_dir, is_tool_test
|
||||
from transformers.tools.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
authorized_types = ["text", "image", "audio"]
|
||||
|
||||
|
||||
def create_inputs(input_types: List[str]):
|
||||
inputs = []
|
||||
|
||||
for input_type in input_types:
|
||||
if input_type == "text":
|
||||
inputs.append("Text input")
|
||||
elif input_type == "image":
|
||||
inputs.append(
|
||||
Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
|
||||
)
|
||||
elif input_type == "audio":
|
||||
inputs.append(torch.ones(3000))
|
||||
elif isinstance(input_type, list):
|
||||
inputs.append(create_inputs(input_type))
|
||||
else:
|
||||
raise ValueError(f"Invalid type requested: {input_type}")
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def output_types(outputs: List):
|
||||
output_types = []
|
||||
|
||||
for output in outputs:
|
||||
if isinstance(output, (str, AgentText)):
|
||||
output_types.append("text")
|
||||
elif isinstance(output, (Image.Image, AgentImage)):
|
||||
output_types.append("image")
|
||||
elif isinstance(output, (torch.Tensor, AgentAudio)):
|
||||
output_types.append("audio")
|
||||
else:
|
||||
raise ValueError(f"Invalid output: {output}")
|
||||
|
||||
return output_types
|
||||
|
||||
|
||||
@is_tool_test
|
||||
class ToolTesterMixin:
|
||||
def test_inputs_outputs(self):
|
||||
self.assertTrue(hasattr(self.tool, "inputs"))
|
||||
self.assertTrue(hasattr(self.tool, "outputs"))
|
||||
|
||||
inputs = self.tool.inputs
|
||||
for _input in inputs:
|
||||
if isinstance(_input, list):
|
||||
for __input in _input:
|
||||
self.assertTrue(__input in authorized_types)
|
||||
else:
|
||||
self.assertTrue(_input in authorized_types)
|
||||
|
||||
outputs = self.tool.outputs
|
||||
for _output in outputs:
|
||||
self.assertTrue(_output in authorized_types)
|
||||
|
||||
def test_call(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
# There is a single output
|
||||
if len(self.tool.outputs) == 1:
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertListEqual(output_types(outputs), self.tool.outputs)
|
||||
|
||||
def test_common_attributes(self):
|
||||
self.assertTrue(hasattr(self.tool, "description"))
|
||||
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
|
||||
self.assertTrue(self.tool.description.startswith("This is a tool that"))
|
||||
|
||||
def test_agent_types_outputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||
|
||||
for output, output_type in zip(outputs, self.tool.outputs):
|
||||
agent_type = AGENT_TYPE_MAPPING[output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = create_inputs(self.tool.inputs)
|
||||
|
||||
_inputs = []
|
||||
|
||||
for _input, input_type in zip(inputs, self.tool.inputs):
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
@ -1,86 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import load_tool
|
||||
from transformers.tools.agent_types import AGENT_TYPE_MAPPING
|
||||
|
||||
from .test_tools_common import ToolTesterMixin, output_types
|
||||
|
||||
|
||||
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
|
||||
def setUp(self):
|
||||
self.tool = load_tool("translation")
|
||||
self.tool.setup()
|
||||
self.remote_tool = load_tool("translation", remote=True)
|
||||
|
||||
def test_exact_match_arg(self):
|
||||
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_arg_remote(self):
|
||||
result = self.remote_tool("Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg(self):
|
||||
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_exact_match_kwarg_remote(self):
|
||||
result = self.remote_tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
|
||||
self.assertEqual(result, "- Hé, comment ça va?")
|
||||
|
||||
def test_call(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
# There is a single output
|
||||
if len(self.tool.outputs) == 1:
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertListEqual(output_types(outputs), self.tool.outputs)
|
||||
|
||||
def test_agent_types_outputs(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
||||
|
||||
for output, output_type in zip(outputs, self.tool.outputs):
|
||||
agent_type = AGENT_TYPE_MAPPING[output_type]
|
||||
self.assertTrue(isinstance(output, agent_type))
|
||||
|
||||
def test_agent_types_inputs(self):
|
||||
inputs = ["Hey, what's up?", "English", "Spanish"]
|
||||
|
||||
_inputs = []
|
||||
|
||||
for _input, input_type in zip(inputs, self.tool.inputs):
|
||||
if isinstance(input_type, list):
|
||||
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
|
||||
else:
|
||||
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
|
||||
|
||||
# Should not raise an error
|
||||
outputs = self.tool(*inputs)
|
||||
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
self.assertEqual(len(outputs), len(self.tool.outputs))
|
@ -2,6 +2,8 @@ docs/source/en/_config.py
|
||||
docs/source/en/accelerate.md
|
||||
docs/source/en/add_new_model.md
|
||||
docs/source/en/add_new_pipeline.md
|
||||
docs/source/en/agents.md
|
||||
docs/source/en/agents.md
|
||||
docs/source/en/attention.md
|
||||
docs/source/en/benchmarks.md
|
||||
docs/source/en/bertology.md
|
||||
@ -10,7 +12,6 @@ docs/source/en/community.md
|
||||
docs/source/en/contributing.md
|
||||
docs/source/en/create_a_model.md
|
||||
docs/source/en/custom_models.md
|
||||
docs/source/en/custom_tools.md
|
||||
docs/source/en/debugging.md
|
||||
docs/source/en/fast_tokenizers.md
|
||||
docs/source/en/glossary.md
|
||||
@ -324,10 +325,20 @@ docs/source/en/tflite.md
|
||||
docs/source/en/tokenizer_summary.md
|
||||
docs/source/en/torchscript.md
|
||||
docs/source/en/training.md
|
||||
docs/source/en/transformers_agents.md
|
||||
docs/source/en/troubleshooting.md
|
||||
src/transformers/activations.py
|
||||
src/transformers/activations_tf.py
|
||||
src/transformers/agents/agent_types.py
|
||||
src/transformers/agents/agents.py
|
||||
src/transformers/agents/document_question_answering.py
|
||||
src/transformers/agents/evaluate_agent.py
|
||||
src/transformers/agents/image_question_answering.py
|
||||
src/transformers/agents/prompts.py
|
||||
src/transformers/agents/python_interpreter.py
|
||||
src/transformers/agents/speech_to_text.py
|
||||
src/transformers/agents/text_to_speech.py
|
||||
src/transformers/agents/tools.py
|
||||
src/transformers/agents/translation.py
|
||||
src/transformers/audio_utils.py
|
||||
src/transformers/benchmark/benchmark.py
|
||||
src/transformers/benchmark/benchmark_args.py
|
||||
@ -974,22 +985,6 @@ src/transformers/time_series_utils.py
|
||||
src/transformers/tokenization_utils.py
|
||||
src/transformers/tokenization_utils_base.py
|
||||
src/transformers/tokenization_utils_fast.py
|
||||
src/transformers/tools/agent_types.py
|
||||
src/transformers/tools/agents.py
|
||||
src/transformers/tools/base.py
|
||||
src/transformers/tools/document_question_answering.py
|
||||
src/transformers/tools/evaluate_agent.py
|
||||
src/transformers/tools/image_captioning.py
|
||||
src/transformers/tools/image_question_answering.py
|
||||
src/transformers/tools/image_segmentation.py
|
||||
src/transformers/tools/prompts.py
|
||||
src/transformers/tools/python_interpreter.py
|
||||
src/transformers/tools/speech_to_text.py
|
||||
src/transformers/tools/text_classification.py
|
||||
src/transformers/tools/text_question_answering.py
|
||||
src/transformers/tools/text_summarization.py
|
||||
src/transformers/tools/text_to_speech.py
|
||||
src/transformers/tools/translation.py
|
||||
src/transformers/trainer.py
|
||||
src/transformers/trainer_callback.py
|
||||
src/transformers/trainer_pt_utils.py
|
||||
|
Loading…
Reference in New Issue
Block a user