mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Refactoring of the text generate API docs (#21112)
* initial commit, refactoring the text generation api reference * removed repetitive code examples * Refactoring the text generation docs to reduce repetition * make style
This commit is contained in:
parent
d386fd646a
commit
0248810300
@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Generation
|
||||
|
||||
Each framework has a generate method for auto-regressive text generation implemented in their respective `GenerationMixin` class:
|
||||
Each framework has a generate method for text generation implemented in their respective `GenerationMixin` class:
|
||||
|
||||
- PyTorch [`~generation.GenerationMixin.generate`] is implemented in [`~generation.GenerationMixin`].
|
||||
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
|
||||
@ -22,69 +22,9 @@ Regardless of your framework of choice, you can parameterize the generate method
|
||||
class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
|
||||
of the generation method.
|
||||
|
||||
All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
|
||||
model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
|
||||
like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
|
||||
store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, GenerationConfig
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("my_account/my_model")
|
||||
|
||||
# Inspect the default generation configuration
|
||||
print(model.generation_config)
|
||||
|
||||
# Set a new default generation configuration
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
|
||||
)
|
||||
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
|
||||
default values are omitted. Some attributes, like `max_length`, have a conservative default value, to avoid running
|
||||
into resource limitations. Make sure you double-check the defaults in the documentation.
|
||||
|
||||
</Tip>
|
||||
|
||||
You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
|
||||
argument in `save_pretrained`. You can latter instantiate them with `from_pretrained`. This is useful if you want to
|
||||
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
|
||||
other for summarization with beam search).
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
|
||||
|
||||
translation_generation_config = GenerationConfig(
|
||||
num_beams=4,
|
||||
early_stopping=True,
|
||||
decoder_start_token_id=0,
|
||||
eos_token_id=model.config.eos_token_id,
|
||||
pad_token=model.config.pad_token_id,
|
||||
)
|
||||
# If you were working on a model for which your had the right Hub permissions, you could store a named generation
|
||||
# config as follows
|
||||
translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)
|
||||
|
||||
# You could then use the named generation config file to parameterize generation
|
||||
generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
|
||||
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
|
||||
outputs = model.generate(**inputs, generation_config=generation_config)
|
||||
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
# ['Les fichiers de configuration sont faciles à utiliser !']
|
||||
```
|
||||
|
||||
Finally, you can specify ad hoc modifications to the used generation configuration by passing the attribute you
|
||||
wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
|
||||
framework's `generate` method docstring (available below) has a few illustrative examples on the different strategies
|
||||
to parameterize it.
|
||||
|
||||
To learn how to inspect a model's generation configuration, what are the defaults, how to change the parameters ad hoc,
|
||||
and how to create and save a customized generation configuration, refer to the
|
||||
[text generation strategies guide](./generation_strategies).
|
||||
|
||||
## GenerationConfig
|
||||
|
||||
|
@ -41,29 +41,22 @@ class GenerationConfig(PushToHubMixin):
|
||||
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
|
||||
|
||||
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
|
||||
`do_sample=False`.
|
||||
`do_sample=False`
|
||||
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
|
||||
and `top_k>1`
|
||||
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
|
||||
`do_sample=True`.
|
||||
`do_sample=True`
|
||||
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
|
||||
`do_sample=False`.
|
||||
`do_sample=False`
|
||||
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
|
||||
`num_beams>1` and `do_sample=True`.
|
||||
`num_beams>1` and `do_sample=True`
|
||||
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
|
||||
`num_beams>1` and `num_beam_groups>1`.
|
||||
`num_beams>1` and `num_beam_groups>1`
|
||||
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
|
||||
`constraints!=None` or `force_words_ids!=None`.
|
||||
`constraints!=None` or `force_words_ids!=None`
|
||||
|
||||
<Tip>
|
||||
|
||||
A generation configuration file can be loaded and saved to disk. Loading and using a generation configuration file
|
||||
does **not** change a model configuration or weights. It only affects the model's behavior at generation time.
|
||||
|
||||
</Tip>
|
||||
|
||||
Most of these parameters are explained in more detail in [this blog
|
||||
post](https://huggingface.co/blog/how-to-generate).
|
||||
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate'. To learn
|
||||
more about decoding strategies refer to the [text generation strategies guide](./generation_strategies).
|
||||
|
||||
Arg:
|
||||
> Parameters that control the length of the output
|
||||
|
@ -131,11 +131,14 @@ class FlaxGenerationMixin:
|
||||
|
||||
The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for:
|
||||
- *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and
|
||||
`do_sample=False`.
|
||||
`do_sample=False`
|
||||
- *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and
|
||||
`do_sample=True`.
|
||||
`do_sample=True`
|
||||
- *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and
|
||||
`do_sample=False`.
|
||||
`do_sample=False`
|
||||
|
||||
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
|
||||
learn more about decoding strategies refer to the [text generation strategies guide](./generation_strategies).
|
||||
"""
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
@ -225,26 +228,7 @@ class FlaxGenerationMixin:
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Generates sequences of token ids for models with a language modeling head. The method supports the following
|
||||
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
|
||||
|
||||
- *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and
|
||||
`do_sample=False`.
|
||||
- *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and
|
||||
`do_sample=True`.
|
||||
- *beam-search decoding* by calling [`~generation.FlaxGenerationMixin._beam_search`] if `num_beams>1` and
|
||||
`do_sample=False`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
|
||||
model's default generation configuration. You can override any `generation_config` by passing the corresponding
|
||||
parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
|
||||
|
||||
For a complete overview of generate, check the [following
|
||||
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
|
||||
|
||||
</Tip>
|
||||
Generates sequences of token ids for models with a language modeling head.
|
||||
|
||||
Parameters:
|
||||
input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
|
||||
@ -269,72 +253,7 @@ class FlaxGenerationMixin:
|
||||
Return:
|
||||
[`~utils.ModelOutput`].
|
||||
|
||||
Examples:
|
||||
|
||||
Greedy decoding, using the default generation configuration and ad hoc modifications:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
>>> prompt = "Today I believe we can finally"
|
||||
>>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
|
||||
|
||||
>>> # Generate up to 30 tokens
|
||||
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
|
||||
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
|
||||
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
|
||||
```
|
||||
|
||||
Multinomial sampling, modifying an existing generation configuration:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, GenerationConfig
|
||||
>>> import numpy as np
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
>>> prompt = "Today I believe we can finally"
|
||||
>>> input_ids = tokenizer(prompt, return_tensors="np").input_ids
|
||||
|
||||
>>> # Sample up to 30 tokens
|
||||
>>> generation_config = GenerationConfig.from_pretrained("gpt2")
|
||||
>>> generation_config.max_length = 30
|
||||
>>> generation_config.do_sample = True
|
||||
>>> outputs = model.generate(
|
||||
... input_ids, generation_config=generation_config, prng_key=np.asarray([0, 0], dtype=np.uint32)
|
||||
... )
|
||||
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
|
||||
['Today I believe we can finally get a change in that system. The way I saw it was this: a few years ago, this company would not']
|
||||
```
|
||||
|
||||
Beam-search decoding, using a freshly initialized generation configuration:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM, GenerationConfig
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
|
||||
>>> model = FlaxAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
|
||||
|
||||
>>> sentence = "Paris is one of the densest populated areas in Europe."
|
||||
>>> input_ids = tokenizer(sentence, return_tensors="np").input_ids
|
||||
|
||||
>>> generation_config = GenerationConfig(
|
||||
... max_length=64,
|
||||
... num_beams=5,
|
||||
... bos_token_id=0,
|
||||
... eos_token_id=0,
|
||||
... decoder_start_token_id=58100,
|
||||
... pad_token_id=58100,
|
||||
... bad_words_ids=[[58100]],
|
||||
... )
|
||||
>>> outputs = model.generate(input_ids, generation_config=generation_config)
|
||||
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
|
||||
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
|
||||
```"""
|
||||
"""
|
||||
# Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
|
||||
|
@ -432,12 +432,15 @@ class TFGenerationMixin:
|
||||
|
||||
The class exposes [`~generation.TFGenerationMixin.generate`], which can be used for:
|
||||
- *greedy decoding* by calling [`~generation.TFGenerationMixin.greedy_search`] if `num_beams=1` and
|
||||
`do_sample=False`.
|
||||
`do_sample=False`
|
||||
- *contrastive search* by calling [`~generation.TFGenerationMixin.contrastive_search`] if `penalty_alpha>0` and
|
||||
`top_k>1`
|
||||
- *multinomial sampling* by calling [`~generation.TFGenerationMixin.sample`] if `num_beams=1` and
|
||||
`do_sample=True`.
|
||||
- *beam-search decoding* by calling [`~generation.TFGenerationMixin.beam_search`] if `num_beams>1`.
|
||||
`do_sample=True`
|
||||
- *beam-search decoding* by calling [`~generation.TFGenerationMixin.beam_search`] if `num_beams>1`
|
||||
|
||||
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
|
||||
learn more about decoding strategies refer to the [text generation strategies guide](./generation_strategies).
|
||||
"""
|
||||
|
||||
_seed_generator = None
|
||||
@ -541,8 +544,8 @@ class TFGenerationMixin:
|
||||
model's default generation configuration. You can override any `generation_config` by passing the corresponding
|
||||
parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
|
||||
|
||||
For a complete overview of generate, check the [following
|
||||
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
|
||||
For an overview of generation strategies and code examples, check out the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
@ -585,69 +588,7 @@ class TFGenerationMixin:
|
||||
- [`~generation.TFBeamSearchEncoderDecoderOutput`],
|
||||
- [`~generation.TFBeamSampleEncoderDecoderOutput`]
|
||||
|
||||
Examples:
|
||||
|
||||
Greedy decoding, using the default generation configuration and ad hoc modifications:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, TFAutoModelForCausalLM
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
>>> prompt = "Today I believe we can finally"
|
||||
>>> input_ids = tokenizer(prompt, return_tensors="tf").input_ids
|
||||
|
||||
>>> # Generate up to 30 tokens
|
||||
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
|
||||
```
|
||||
|
||||
Multinomial sampling, modifying an existing generation configuration:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, TFAutoModelForCausalLM, GenerationConfig
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
>>> prompt = "Today I believe we can finally"
|
||||
>>> input_ids = tokenizer(prompt, return_tensors="tf").input_ids
|
||||
|
||||
>>> # Sample up to 30 tokens
|
||||
>>> generation_config = GenerationConfig.from_pretrained("gpt2")
|
||||
>>> generation_config.max_length = 30
|
||||
>>> generation_config.do_sample = True
|
||||
>>> outputs = model.generate(input_ids, generation_config=generation_config, seed=[0, 0])
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
["Today I believe we can finally start taking a bold stand against climate change and climate change mitigation efforts such as President Obama's climate ban and President Trump's"]
|
||||
```
|
||||
|
||||
Beam-search decoding, using a freshly initialized generation configuration:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM, GenerationConfig
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
|
||||
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
|
||||
|
||||
>>> sentence = "Paris is one of the densest populated areas in Europe."
|
||||
>>> input_ids = tokenizer(sentence, return_tensors="tf").input_ids
|
||||
|
||||
>>> generation_config = GenerationConfig(
|
||||
... max_length=64,
|
||||
... num_beams=5,
|
||||
... bos_token_id=0,
|
||||
... eos_token_id=0,
|
||||
... decoder_start_token_id=58100,
|
||||
... pad_token_id=58100,
|
||||
... bad_words_ids=[[58100]],
|
||||
... )
|
||||
>>> outputs = model.generate(input_ids, generation_config=generation_config)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
|
||||
```"""
|
||||
"""
|
||||
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
|
@ -466,19 +466,22 @@ class GenerationMixin:
|
||||
|
||||
The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
|
||||
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
|
||||
`do_sample=False`.
|
||||
`do_sample=False`
|
||||
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and
|
||||
`top_k>1`
|
||||
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
|
||||
`do_sample=True`.
|
||||
`do_sample=True`
|
||||
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
|
||||
`do_sample=False`.
|
||||
`do_sample=False`
|
||||
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1`
|
||||
and `do_sample=True`.
|
||||
and `do_sample=True`
|
||||
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1`
|
||||
and `num_beam_groups>1`.
|
||||
and `num_beam_groups>1`
|
||||
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
|
||||
`constraints!=None` or `force_words_ids!=None`.
|
||||
`constraints!=None` or `force_words_ids!=None`
|
||||
|
||||
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
|
||||
learn more about decoding strategies refer to the [text generation strategies guide](./generation_strategies).
|
||||
"""
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, **kwargs):
|
||||
@ -1018,10 +1021,10 @@ class GenerationMixin:
|
||||
|
||||
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
|
||||
model's default generation configuration. You can override any `generation_config` by passing the corresponding
|
||||
parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
|
||||
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
|
||||
|
||||
For a complete overview of generate, check the [following
|
||||
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
|
||||
For an overview of generation strategies and code examples, check out the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
@ -1079,72 +1082,7 @@ class GenerationMixin:
|
||||
- [`~generation.SampleEncoderDecoderOutput`],
|
||||
- [`~generation.BeamSearchEncoderDecoderOutput`],
|
||||
- [`~generation.BeamSampleEncoderDecoderOutput`]
|
||||
|
||||
Examples:
|
||||
|
||||
Greedy decoding, using the default generation configuration and ad hoc modifications:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
>>> prompt = "Today I believe we can finally"
|
||||
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
>>> # Generate up to 30 tokens
|
||||
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
|
||||
```
|
||||
|
||||
Multinomial sampling, modifying an existing generation configuration:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
|
||||
>>> prompt = "Today I believe we can finally"
|
||||
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
>>> # Sample up to 30 tokens
|
||||
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
|
||||
>>> generation_config = GenerationConfig.from_pretrained("gpt2")
|
||||
>>> generation_config.max_length = 30
|
||||
>>> generation_config.do_sample = True
|
||||
>>> outputs = model.generate(input_ids, generation_config=generation_config)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
|
||||
```
|
||||
|
||||
Beam-search decoding, using a freshly initialized generation configuration:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
|
||||
>>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
|
||||
|
||||
>>> sentence = "Paris is one of the densest populated areas in Europe."
|
||||
>>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
|
||||
|
||||
>>> generation_config = GenerationConfig(
|
||||
... max_length=64,
|
||||
... num_beams=5,
|
||||
... bos_token_id=0,
|
||||
... eos_token_id=0,
|
||||
... decoder_start_token_id=58100,
|
||||
... pad_token_id=58100,
|
||||
... bad_words_ids=[[58100]],
|
||||
... )
|
||||
>>> outputs = model.generate(input_ids, generation_config=generation_config)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
|
||||
```"""
|
||||
"""
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
|
||||
@ -1650,6 +1588,14 @@ class GenerationMixin:
|
||||
Generates sequences of token ids for models with a language modeling head using **contrastive search** and can
|
||||
be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use
|
||||
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
@ -1998,6 +1944,15 @@ class GenerationMixin:
|
||||
Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
|
||||
used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
|
||||
instead. For an overview of generation strategies and code examples, check the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
@ -2233,6 +2188,14 @@ class GenerationMixin:
|
||||
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
|
||||
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
|
||||
For an overview of generation strategies and code examples, check the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
@ -2491,6 +2454,14 @@ class GenerationMixin:
|
||||
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
|
||||
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
|
||||
instead. For an overview of generation strategies and code examples, check the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
@ -2807,6 +2778,14 @@ class GenerationMixin:
|
||||
Generates sequences of token ids for models with a language modeling head using **beam search multinomial
|
||||
sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate()
|
||||
instead. For an overview of generation strategies and code examples, check the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
@ -3128,6 +3107,14 @@ class GenerationMixin:
|
||||
Generates sequences of token ids for models with a language modeling head using **diverse beam search
|
||||
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use
|
||||
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
@ -3498,6 +3485,14 @@ class GenerationMixin:
|
||||
Generates sequences of token ids for models with a language modeling head using **constrained beam search
|
||||
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use
|
||||
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||
guide](./generation_strategies).
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The sequence used as a prompt for the generation.
|
||||
|
Loading…
Reference in New Issue
Block a user