<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Generation

Each framework has a generate method for auto-regressive text generation, implemented in its respective `GenerationMixin` class:

- PyTorch [`~generation.GenerationMixin.generate`] is implemented in [`~generation.GenerationMixin`].
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].

Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
of the generation method.
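
As a minimal sketch of passing a configuration to `generate` in PyTorch (the `gpt2` checkpoint and the prompt below are illustrative placeholders, not part of the original example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Configuration files are", return_tensors="pt")

# Greedy decoding, capped at 20 newly generated tokens
generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```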

All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.

```python
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("my_account/my_model")

# Inspect the default generation configuration
print(model.generation_config)

# Set a new default generation configuration
generation_config = GenerationConfig(
    max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```
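
When `config_file_name` is not set, `save_pretrained` writes the configuration to `generation_config.json`, the file that `from_pretrained` loads automatically as the model's default.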

<Tip>

If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running
into resource limitations. Make sure you double-check the defaults in the documentation.

</Tip>
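
For instance, a freshly constructed [`~generation.GenerationConfig`] carries the library-wide defaults, so you can check any of them programmatically (a small illustrative snippet, not from the original docs):

```python
from transformers import GenerationConfig

default_config = GenerationConfig()
print(default_config.max_length)  # a conservative cap on total sequence length (20 at the time of writing)
```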

You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
argument in `save_pretrained`. You can later instantiate them with `from_pretrained`. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
another for summarization with beam search).

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

translation_generation_config = GenerationConfig(
    num_beams=4,
    early_stopping=True,
    decoder_start_token_id=0,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
)
# If you were working on a model for which you had the right Hub permissions, you could store a named generation
# config as follows
translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)

# You could then use the named generation config file to parameterize generation
generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['Les fichiers de configuration sont faciles à utiliser !']
```

Finally, you can specify ad hoc modifications to the generation configuration in use by passing any attribute you
wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
framework's `generate` method docstring (available below) has a few illustrative examples of the different strategies
to parameterize it.
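
Continuing the translation example above, a keyword argument passed directly to `generate` takes precedence over the matching attribute in the generation config (a brief sketch of the override behavior):

```python
# The keyword argument wins over the value stored in the generation config:
# this call decodes with 2 beams instead of the configured 4, for this call only
outputs = model.generate(**inputs, generation_config=generation_config, num_beams=2)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```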
## GenerationConfig

[[autodoc]] generation.GenerationConfig
- from_pretrained
- from_model_config
- save_pretrained

## GenerationMixin

[[autodoc]] generation.GenerationMixin
- generate
- greedy_search
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search

## TFGenerationMixin

[[autodoc]] generation.TFGenerationMixin
- generate

## FlaxGenerationMixin

[[autodoc]] generation.FlaxGenerationMixin
- generate