[generate] Run custom generation code from the Hub (#36405)

* mvp

* remove trust_remote_code

* generate_from_hub

* handle requirements; docs

* english

* doc PR suggestions

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* changed remote code path to generate/generate.py

* model repo has custom generate -> override base generate

* check for proper inheritance

* some doc updates (missing: tag-related docs)

* update docs to model repo

* nit

* nit

* nits

* Update src/transformers/dynamic_module_utils.py

* Apply suggestions from code review

* Update docs/source/en/generation_strategies.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* trust remote code is required

* use new import utils for requirements version parsing

* use  org examples

* add tests

* Apply suggestions from code review

Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>

* ascii file structure; tag instructions on readme.md

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>
This commit is contained in:
Joao Gante 2025-05-15 10:35:54 +01:00 committed by GitHub
parent 955e61b0da
commit 0e0e5c1044
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 522 additions and 97 deletions

View File

@ -20,11 +20,15 @@ A decoding strategy informs how a model should select the next generated token.
This guide will help you understand the different decoding strategies available in Transformers and how and when to use them.
## Greedy search
## Basic decoding methods
Greedy search is the default decoding strategy. It selects the next most likely token at each step. Unless specified in [`GenerationConfig`], this strategy generates a maximum of 20 tokens.
These are well established decoding methods, and should be your starting point for text generation tasks.
Greedy search works well for tasks with relatively short outputs. However, it breaks down when generating longer sequences because it begins to repeat itself.
### Greedy search
Greedy search is the default decoding strategy. It selects the next most likely token at each step. Unless specified in [`GenerationConfig`], this strategy generates a maximum of 20 new tokens.
Greedy search works well for tasks with relatively short outputs where creativity is not a priority. However, it breaks down when generating longer sequences because it begins to repeat itself.
```py
import torch
@ -40,11 +44,11 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company that provides a suite of tools and services for building, deploying, and maintaining natural language processing'
```
## Contrastive search
### Sampling
[Contrastive search](https://huggingface.co/papers/2202.06417) is a decoding strategy that aims to reduce repetition even while generating longer sequences. This strategy compares how similar a generated token is against previous tokens, and if they're more similar, a penalty is applied.
Sampling, or multinomial sampling, randomly selects a token based on the probability distribution over the entire model's vocabulary (as opposed to the most likely token, as in greedy search). This means every token with a non-zero probability has a chance to be selected. Sampling strategies reduce repetition and can generate more creative and diverse outputs.
Enable contrastive search with the `penalty_alpha` and `top_k` parameters. The `penalty_alpha` manages the penalty applied and `top_k` is the number of most likely tokens to return.
Enable multinomial sampling with `do_sample=True` and `num_beams=1`.
```py
import torch
@ -55,14 +59,14 @@ inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt"
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4)
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=1)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company that provides a platform for building and deploying AI models.\nHugging Face is an open-source company that provides a platform for building and deploying AI models. The platform allows developers to build and deploy AI models, as well as collaborate with other developers.\nHugging Face was founded in 2019 by Thibault Wittemberg and Clément Delangue. The company is based in Paris, France.\nHugging Face has'
'Hugging Face is an open-source company 🤗\nWe are open-source and believe that open-source is the best way to build technology. Our mission is to make AI accessible to everyone, and we believe that open-source is the best way to achieve that.'
```
## Beam search
### Beam search
Beam search keeps track of several generated sequences (beams) at each time step. After a certain number of steps, it selects the sequence with the highest *overall* probability. Unlike greedy search, this strategy can "look ahead" and pick a sequence with a higher probability overall even if the initial tokens have a lower probability.
Beam search keeps track of several generated sequences (beams) at each time step. After a certain number of steps, it selects the sequence with the highest *overall* probability. Unlike greedy search, this strategy can "look ahead" and pick a sequence with a higher probability overall even if the initial tokens have a lower probability. It is best suited for input-grounded tasks, like describing an image or speech recognition. You can also use `do_sample=True` with beam search to sample at each step, but beam search will still greedily prune out low probability sequences between steps.
> [!TIP]
> Check out the [beam search visualizer](https://huggingface.co/spaces/m-ric/beam_search_visualizer) to see how beam search works.
@ -83,66 +87,11 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
"['Hugging Face is an open-source company that develops and maintains the Hugging Face platform, which is a collection of tools and libraries for building and deploying natural language processing (NLP) models. Hugging Face was founded in 2018 by Thomas Wolf']"
```
## Diverse beam search
## Advanced decoding methods
[Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups.
Advanced decoding methods aim at either tackling specific generation quality issues (e.g. repetition) or at improving the generation throughput in certain situations. These techniques are more complex, and may not work correctly with all models.
Enable diverse beam search with the `num_beams`, `num_beam_groups` and `diversity_penalty` parameters (the `num_beams` parameter should be divisible by `num_beam_groups`).
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=6, num_beam_groups=3, diversity_penalty=1.0, do_sample=False)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company 🤗\nWe are an open-source company. Our mission is to democratize AI and make it accessible to everyone. We believe that AI should be used for the benefit of humanity, not for the benefit of a'
```
## Multinomial sampling
Search methods selects the most likely tokens. Sampling, or multinomial sampling, randomly selects a token based on the probability distribution over the entire models vocabulary. This means every token with a non-zero probability has a chance to be selected. Sampling strategies reduce repetition and can generate more creative and diverse outputs.
Enable multinomial sampling with `do_sample=True` and `num_beams=1`.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=1)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company 🤗\nWe are open-source and believe that open-source is the best way to build technology. Our mission is to make AI accessible to everyone, and we believe that open-source is the best way to achieve that.'
```
## Beam search multinomial sampling
This decoding strategy is a combination of beam search and multinomial sampling. It generates multiple beams and uses a sampling strategy for each beam.
Enable beam search multinomial sampling by setting `num_beams` to a value greater than 1 and `do_sample=True`.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=4)
'Hugging Face is an open-source company 100% dedicated to making AI more accessible. We believe that AI should be available to everyone, and were working hard to make that a reality.\nWere a team of passionate engineers, designers,'
```
## Speculative decoding
### Speculative decoding
[Speculative](https://hf.co/papers/2211.17192) or assistive decoding isn't a search or sampling strategy. Instead, speculative decoding adds a second smaller model to generate candidate tokens. The main model verifies the candidate tokens in a single `forward` pass, which speeds up the decoding process overall. This method is especially useful for LLMs where it can be more costly and slower to generate tokens. Refer to the [speculative decoding](./llm_optims#speculative-decoding) guide to learn more.
@ -203,7 +152,7 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
</hfoption>
</hfoptions>
### Prompt lookup decoding
#### Prompt lookup decoding
[Prompt lookup decoding](./llm_optims#prompt-lookup-decoding) is a variant of speculative decoding that uses overlapping n-grams as the candidate tokens. It works well for input-grounded tasks such as summarization. Refer to the [prompt lookup decoding](./llm_optims#prompt-lookup-decoding) guide to learn more.
@ -245,7 +194,7 @@ outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_
tokenizer.batch_decode(outputs, skip_special_tokens=True)
```
### Universal assisted decoding
#### Universal assisted decoding
Universal assisted decoding (UAD) enables the main and assistant models to use different tokenizers. The main models input tokens are re-encoded into assistant model tokens. Candidate tokens are generated in the assistant encoding which are re-encoded into the main model candidate tokens. The candidate tokens are verified as explained in [speculative decoding](#speculative-decoding).
@ -269,7 +218,27 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
## DoLa
### Contrastive search
[Contrastive search](https://huggingface.co/papers/2202.06417) is a decoding strategy that aims to reduce repetition even while generating longer sequences. This strategy compares how similar a generated token is against previous tokens, and if they're more similar, a penalty is applied.
Enable contrastive search with the `penalty_alpha` and `top_k` parameters. The `penalty_alpha` manages the penalty applied and `top_k` is the number of most likely tokens to return.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company that provides a platform for building and deploying AI models.\nHugging Face is an open-source company that provides a platform for building and deploying AI models. The platform allows developers to build and deploy AI models, as well as collaborate with other developers.\nHugging Face was founded in 2019 by Thibault Wittemberg and Clément Delangue. The company is based in Paris, France.\nHugging Face has'
```
### DoLa
[Decoding by Contrasting Layers (DoLa)](https://hf.co/papers/2309.03883) is a contrastive decoding strategy for improving factuality and reducing hallucination. This strategy works by contrasting the logit differences between the final and early layers. As a result, factual knowledge localized to particular layers are amplified. DoLa is not recommended for smaller models like GPT-2.
@ -325,6 +294,210 @@ tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:], skip_special_tok
</hfoption>
</hfoptions>
### Diverse beam search
[Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups.
Enable diverse beam search with the `num_beams`, `num_beam_groups` and `diversity_penalty` parameters (the `num_beams` parameter should be divisible by `num_beam_groups`).
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=6, num_beam_groups=3, diversity_penalty=1.0, do_sample=False)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company 🤗\nWe are an open-source company. Our mission is to democratize AI and make it accessible to everyone. We believe that AI should be used for the benefit of humanity, not for the benefit of a'
```
## Custom decoding methods
Custom decoding methods enable specialized generation behavior such as the following:
- have the model continue thinking if it is uncertain;
- roll back generation if the model gets stuck;
- handle special tokens with custom logic;
- enhanced input preparation for advanced models;
We enable custom decoding methods through model repositories, assuming a specific model tag and file structure (see subsection below). This feature is an extension of [custom modeling code](./models.md#custom-models) and, like such, requires setting `trust_remote_code=True`.
If a model repository holds a custom decoding method, the easiest way to try it out is to load the model and generate with it:
<!-- TODO before merging: 1) better repo name (use a `generate-community` org?) 2) prettify the repo -->
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
# `transformers-community/custom_generate_example` holds a copy of `Qwen/Qwen2.5-0.5B-Instruct`, but
# with custom generation code -> calling `generate` uses the custom decoding method!
tokenizer = AutoTokenizer.from_pretrained("transformers-community/custom_generate_example")
model = AutoModelForCausalLM.from_pretrained(
"transformers-community/custom_generate_example", device_map="auto", trust_remote_code=True
)
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
# The custom decoding method is a minimal greedy decoding implementation. It also prints a custom message at run time.
gen_out = model.generate(**inputs)
# you should now see its custom message, "✨ using a custom generation method ✨"
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
'The quick brown fox jumps over a lazy dog, and the dog is a type of animal. Is'
```
Model repositories with custom decoding methods have a special property: their decoding method can be loaded from **any** model through [`~GenerationMixin.generate`]'s `custom_generate` argument. This means anyone can create and share their custom generation method to potentially work with any Transformers model, without requiring users to install additional Python packages.
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
# `custom_generate` replaces the original `generate` by the custom decoding method defined in
# `transformers-community/custom_generate_example`
gen_out = model.generate(**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
'The quick brown fox jumps over a lazy dog, and the dog is a type of animal. Is'
```
You should read the `README.md` file of the repository containing the custom generation strategy to see what the new arguments and output type differences are, if they exist. Otherwise, you can assume it works like the base [`~GenerationMixin.generate`] method.
> [!TIP]
> You can find all custom decoding methods by [searching for their custom tag.](https://huggingface.co/models?other=custom_generate), `custom_generate`
Consider the Hub repository [transformers-community/custom_generate_example](https://huggingface.co/transformers-community/custom_generate_example) as an example. The `README.md` states that it has an additional input argument, `left_padding`, which adds a number of padding tokens before the prompt.
```py
gen_out = model.generate(
**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True, left_padding=5
)
print(tokenizer.batch_decode(gen_out)[0])
'<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>The quick brown fox jumps over the lazy dog.\n\nThe sentence "The quick'
```
If the custom method has pinned Python requirements that your environment doesn't meet, you'll get an exception about missing requirements. For instance, [transformers-community/custom_generate_bad_requirements](https://huggingface.co/transformers-community/custom_generate_bad_requirements) has an impossible set of requirements defined in its `custom_generate/requirements.txt` file, and you'll see the error message below if you try to run it.
```
ImportError: Missing requirements in your local environment for `transformers-community/custom_generate_bad_requirements`:
foo (installed: None)
bar==0.0.0 (installed: None)
torch>=99.0 (installed: 2.6.0)
```
Updating your Python requirements accordingly will remove this error message.
### Creating a custom decoding method
To create a new decoding method, you need to create a new [**Model**](https://huggingface.co/new) repository and push a few files into it.
1. The model you've designed your decoding method with.
2. `custom_generate/generate.py`, which contains all the logic for your custom decoding method.
3. `custom_generate/requirements.txt`, used to optionally add new Python requirements and/or lock specific versions to correctly use your method.
4. `README.md`, where you should add the `custom_generate` tag and document any new arguments or output type differences of your custom method here.
After you've added all required files, your repository should look like this
```
your_repo/
├── README.md # include the 'custom_generate' tag
├── config.json
├── ...
└── custom_generate/
├── generate.py
└── requirements.txt
```
#### Adding the base model
The starting point for your custom decoding method is a model repository just like any other. The model to add to this repository should be the model you've designed your method with, and it is meant to be part of a working self-contained model-generate pair. When the model in this repository is loaded, your custom decoding method will override `generate`. Don't worry -- your decoding method can still be loaded with any other Transformers model, as explained in the section above.
If you simply want to copy an existing model, you can do
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("source/model_repo")
model = AutoModelForCausalLM.from_pretrained("source/model_repo")
tokenizer.save_pretrained("your/decoding_method", push_to_hub=True)
model.save_pretrained("your/decoding_method", push_to_hub=True)
```
#### generate.py
This is the core of your decoding method. It *must* contain a method named `generate`, and this method *must* contain a `model` argument as its first argument. `model` is the model instance, which means you have access to all attributes and methods in the model, including the ones defined in [`GenerationMixin`] (like the base `generate` method).
> [!WARNING]
> `generate.py` must be placed in a folder named `custom_generate`, and not at the root level of the repository. The file paths for this feature are hardcoded.
Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method.
This means your `generate` can have a mix of original and custom arguments (as well as a different output type) as shown below.
```py
import torch
def generate(model, input_ids, generation_config=None, left_padding=None, **kwargs):
generation_config = generation_config or model.generation_config # default to the model generation config
cur_length = input_ids.shape[1]
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
# Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
if left_padding is not None:
if not isinstance(left_padding, int) or left_padding < 0:
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
if pad_token is None:
raise ValueError("pad_token is not defined")
batch_size = input_ids.shape[0]
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
cur_length = input_ids.shape[1]
# Simple greedy decoding loop
while cur_length < max_length:
logits = model(input_ids).logits
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
cur_length += 1
return input_ids
```
Follow the recommended practices below to ensure your custom decoding method works as expected.
- Feel free to reuse the logic for validation and input preparation in the original [`~GenerationMixin.generate`].
- Pin the `transformers` version in the requirements if you use any private method/attribute in `model`.
- You can add other files in the `custom_generate` folder, and use relative imports.
- Consider adding model validation, input validation, or even a separate test file to help users sanity-check your code in their environment.
#### requirements.txt
You can optionally specify additional Python requirements in a `requirements.txt` file inside the `custom_generate` folder. These are checked at runtime and an exception will be thrown if they're missing, nudging users to update their environment accordingly.
#### README.md
The root level `README.md` in the model repository usually describes the model therein. However, since the focus of the repository is the custom decoding method, we highly recommend to shift its focus towards describing the custom decoding method. In addition to a description of the method, we recommend documenting any input and/or output differences to the original [`~GenerationMixin.generate`]. This way, users can focus on what's new, and rely on Transformers docs for generic implementation details.
For discoverability, we highly recommend you to add the `custom_generate` tag to your repository. To do so, the top of your `README.md` file should look like the example below. After you push the file, you should see the tag in your repository!
```
---
library_name: transformers
tags:
- custom_generate
---
(your markdown content here)
```
Recommended practices:
- Document input and output differences in [`~GenerationMixin.generate`].
- Add self-contained examples to enable quick experimentation.
- Describe soft-requirements such as if the method only works well with a certain family of models.
## Resources
Read the [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate) blog post for an explanation of how common decoding strategies work.

View File

@ -17,6 +17,7 @@ import ast
import filecmp
import hashlib
import importlib
import importlib.metadata
import importlib.util
import os
import re
@ -30,6 +31,7 @@ from types import ModuleType
from typing import Any, Optional, Union
from huggingface_hub import try_to_load_from_cache
from packaging import version
from .utils import (
HF_MODULES_CACHE,
@ -39,6 +41,7 @@ from .utils import (
is_offline_mode,
logging,
)
from .utils.import_utils import VersionComparison, split_package_version
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -383,7 +386,7 @@ def get_cached_module_file(
new_files.append(module_file)
except OSError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
raise
# Check we have all the requirements in our environment
@ -417,7 +420,8 @@ def get_cached_module_file(
# benefit of versioning.
submodule_path = submodule_path / commit_hash
full_submodule = full_submodule + os.path.sep + commit_hash
create_dynamic_module(full_submodule)
full_submodule_module_file_path = os.path.join(full_submodule, module_file)
create_dynamic_module(Path(full_submodule_module_file_path).parent)
if not (submodule_path / module_file).exists():
shutil.copy(resolved_module_file, submodule_path / module_file)
@ -663,7 +667,33 @@ def _raise_timeout_error(signum, frame):
TIME_OUT_REMOTE_CODE = 15
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None):
"""
Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading
it.
Args:
trust_remote_code (`bool` or `None`):
User-defined `trust_remote_code` value.
model_name (`str`):
The name of the model repository in huggingface.co.
has_local_code (`bool`):
Whether the model has local code.
has_remote_code (`bool`):
Whether the model has remote code.
error_message (`str`, *optional*):
Custom error message to display if there is remote code to load and the user didn't opt-in. If unset, the error
message will be regarding loading a model with custom code.
Returns:
The resolved `trust_remote_code` value.
"""
# Originally, `trust_remote_code` was used to load models with custom code.
error_message = (
error_message
or f"The repository `{model_name}` contains custom code which must be executed to correctly load the model."
)
if trust_remote_code is None:
if has_local_code:
trust_remote_code = False
@ -674,8 +704,7 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
@ -687,8 +716,7 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
@ -701,9 +729,64 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
if has_remote_code and not has_local_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
return trust_remote_code
def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs):
"""
Tries to locate `requirements_file` in a local folder or repo, and confirms that the environment has all the
python dependencies installed.
Args:
path_or_repo_id (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file.
kwargs (`Dict[str, Any]`, *optional*):
Additional arguments to pass to `cached_file`.
"""
failed = [] # error messages regarding requirements
try:
requirements = cached_file(path_or_repo_id=path_or_repo_id, filename=requirements_file, **kwargs)
with open(requirements, "r") as f:
requirements = f.readlines()
for requirement in requirements:
requirement = requirement.strip()
if not requirement or requirement.startswith("#"): # skip empty lines and comments
continue
try:
# e.g. "torch>2.6.0" -> "torch", ">", "2.6.0"
package_name, delimiter, version_number = split_package_version(requirement)
except ValueError: # e.g. "torch", as opposed to "torch>2.6.0"
package_name = requirement
delimiter, version_number = None, None
try:
local_package_version = importlib.metadata.version(package_name)
except importlib.metadata.PackageNotFoundError:
failed.append(f"{requirement} (installed: None)")
continue
if delimiter is not None and version_number is not None:
is_satisfied = VersionComparison.from_string(delimiter)(
version.parse(local_package_version), version.parse(version_number)
)
else:
is_satisfied = True
if not is_satisfied:
failed.append(f"{requirement} (installed: {local_package_version})")
except OSError: # no requirements.txt
pass
if failed:
raise ImportError(
f"Missing requirements in your local environment for `{path_or_repo_id}`:\n" + "\n".join(failed)
)

View File

@ -23,12 +23,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import file_exists
from packaging import version
from torch import nn
from torch.nn import functional as F
from transformers.generation.candidate_generator import AssistantVocabTranslatorCache
from ..cache_utils import (
Cache,
DynamicCache,
@ -39,6 +38,12 @@ from ..cache_utils import (
QuantizedCacheConfig,
)
from ..configuration_utils import PretrainedConfig
from ..dynamic_module_utils import (
check_python_requirements,
get_cached_module_file,
get_class_in_module,
resolve_trust_remote_code,
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
@ -55,6 +60,7 @@ from ..utils import (
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
AssistantVocabTranslatorCache,
AssistedCandidateGenerator,
AssistedCandidateGeneratorDifferentTokenizers,
CandidateGenerator,
@ -376,6 +382,73 @@ class GenerationMixin:
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""
def load_custom_generate(
self,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
trust_remote_code: Optional[bool] = None,
**kwargs,
) -> Callable:
"""
Loads and returns a custom generate function, given a model repo.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
trust_remote_code (`bool`, *optional*):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to `True` for repositories you trust and in which you have read the code, as it will
execute code present on the Hub on your local machine.
**kwargs:
Additional keyword arguments for remote code loading.
Raises:
OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory.
Returns:
A callable that can be used to generate text.
"""
# Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
is_local_code = os.path.exists(pretrained_model_name_or_path)
has_custom_generate_folder = True
if is_local_code:
if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")):
has_custom_generate_folder = False
else:
if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"):
has_custom_generate_folder = False
if not has_custom_generate_folder:
raise OSError(
f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
"`generate.py` file, can't load the custom generate function."
)
# Handle opt-in `trust_remote_code` and related exceptions
error_message = (
f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
"the default `generate` method."
)
resolve_trust_remote_code(
trust_remote_code,
pretrained_model_name_or_path,
has_local_code=is_local_code,
has_remote_code=not is_local_code,
error_message=error_message,
)
# Load the custom generate function
check_python_requirements(
pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
)
module = get_cached_module_file(
pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
)
custom_generate_function = get_class_in_module("generate", module)
return custom_generate_function
def _cache_dependant_input_preparation(
self,
input_ids: torch.LongTensor,
@ -2158,6 +2231,7 @@ class GenerationMixin:
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
use_model_defaults: Optional[bool] = None,
custom_generate: Optional[str] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
@ -2227,6 +2301,11 @@ class GenerationMixin:
generation configuration (`model.generation_config`), as opposed to the global defaults
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
`True`.
custom_generate (`str`, *optional*):
A string containing the name of a huggingface.co repository. If provided, the custom `generate`
function defined in that reposity's `custom_generate/generate.py` file will be executed instead of the
standard `generate` method. Note that the logic is for generation is entirely defined in that
repository, and the return type may be different from the standard `generate` method.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@ -2248,6 +2327,20 @@ class GenerationMixin:
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
if custom_generate is not None:
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
# they receive the same inputs as `generate`, only with `model` instead of `self`. They can access to
# methods from `GenerationMixin` through `model`.
global_keys_to_exclude = {"self", "kwargs"}
generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
generate_arguments.update(kwargs)
custom_generate_function = self.load_custom_generate(
custom_generate, trust_remote_code=trust_remote_code, **kwargs
)
return custom_generate_function(model=self, **generate_arguments)
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria

View File

@ -4104,6 +4104,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
@ -4113,7 +4114,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("trust_remote_code", None)
_ = kwargs.pop("mirror", None)
_ = kwargs.pop("_fast_init", True)
_ = kwargs.pop("low_cpu_mem_usage", None)
@ -4591,30 +4591,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
# If it is a model with generation capabilities, attempt to load the generation config
# If it is a model with generation capabilities, attempt to load generation files (generation config,
# custom generate function)
if model.can_generate() and generation_config is not None:
logger.info("The user-defined `generation_config` will be used to override the default generation config.")
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
elif model.can_generate() and pretrained_model_name_or_path is not None:
repo_loading_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"subfolder": subfolder,
**kwargs,
}
# Load generation config
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
**repo_loading_kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(model, "load_custom_generate"):
try:
custom_generate = model.load_custom_generate(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
)
model.generate = functools.partial(custom_generate, model=model)
except OSError: # there is no custom generate function
pass
# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harm performances)

View File

@ -452,7 +452,7 @@ class _BaseAutoModelClass:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
trust_remote_code = kwargs.get("trust_remote_code", None)
kwargs["_from_auto"] = True
hub_kwargs_names = [
"cache_dir",
@ -531,7 +531,6 @@ class _BaseAutoModelClass:
config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
return_unused_kwargs=True,
trust_remote_code=trust_remote_code,
code_revision=code_revision,
_commit_hash=commit_hash,
**hub_kwargs,
@ -549,6 +548,7 @@ class _BaseAutoModelClass:
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
kwargs["trust_remote_code"] = trust_remote_code
# Set the adapter kwargs
kwargs["adapter_kwargs"] = adapter_kwargs
@ -730,13 +730,13 @@ def add_generation_mixin_to_remote_model(model_class):
# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
# `prepare_inputs_for_generation` method.
has_custom_generate = hasattr(model_class, "generate") and "GenerationMixin" not in str(
has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str(
getattr(model_class, "generate")
)
has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str(
getattr(model_class, "prepare_inputs_for_generation")
)
if has_custom_generate or has_custom_prepare_inputs:
if has_custom_generate_in_class or has_custom_prepare_inputs:
model_class_with_generation_mixin = type(
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
)

View File

@ -4954,6 +4954,68 @@ class GenerationIntegrationTests(unittest.TestCase):
_ = model_cpu.generate(input_ids, **generate_kwargs)
self.assertFalse(hasattr(model_cpu, "_compiled_call"))
def test_custom_generate_from_argument_in_generate(self):
"""Tests that the `custom_generate` argument is used when passed to `generate`"""
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
# Note: `transformers-community/custom_generate_example` has a custom decoding method with a `left_padding`
# argument (int), which prepends as many pad tokens.
gen_out = model.generate(
**model_inputs,
left_padding=5,
max_new_tokens=5,
custom_generate="transformers-community/custom_generate_example",
trust_remote_code=True,
)
text_output = tokenizer.decode(gen_out[0])
self.assertTrue(text_output.startswith("<unk><unk><unk><unk><unk>")) # <unk> is the pad token
def test_custom_generate_from_model_repo_with_custom_generate_code(self):
"""
Tests that models from model repos containing custom generation code override `generate` with the custom code
"""
model = AutoModelForCausalLM.from_pretrained(
"transformers-community/custom_generate_example", device_map="auto", trust_remote_code=True
)
generate_signature = inspect.signature(model.generate)
# `left_padding` is a custom argument, doesn't exist in the base `generate` method
self.assertTrue(generate_signature.parameters.get("left_padding"))
def test_custom_generate_bad_requirements(self):
"""Tests that we check the `requirements.txt` file from custom generation repos"""
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
with self.assertRaises(ImportError):
# Note: `transformers-community/custom_generate_bad_requirements` has a `requirements.txt` with
# impossible requirements
model.generate(
**model_inputs,
custom_generate="transformers-community/custom_generate_bad_requirements",
trust_remote_code=True,
)
def test_custom_generate_requires_trust_remote_code(self):
"""Tests that `trust_remote_code` is required when using `custom_generate`"""
# Case 1: A model from a repo containing custom generation code must be loaded with `trust_remote_code`
with self.assertRaises(ValueError):
AutoModelForCausalLM.from_pretrained("transformers-community/custom_generate_example", device_map="auto")
# Case 2: Using the `custom_generate` argument in `generate` requires `trust_remote_code` if the code is not
# local
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
with self.assertRaises(ValueError):
model.generate(**model_inputs, custom_generate="transformers-community/custom_generate_example")
@require_torch
class TokenHealingTestCase(unittest.TestCase):