[Docs] Improve PyTorch, Flax generate API (#15988)
* Move generate docs
* up
* Update docs/source/_toctree.yml
* correct
* correct some stuff
* correct tests
* more fixes
* finish generate
* add to doc tests
* finish
* finalize
* add warning to generate method
This commit is contained in:
parent
0951d31788
commit
6ce11c2c0f
docs/source/_toctree.yml
@@ -114,6 +114,8 @@
     title: Logging
   - local: main_classes/model
     title: Models
+  - local: main_classes/text_generation
+    title: Text Generation
   - local: main_classes/onnx
     title: ONNX
   - local: main_classes/optimizer_schedules
docs/source/main_classes/model.mdx
@@ -86,14 +86,6 @@ Due to Pytorch design, this functionality is only available for floating dtypes.
     - push_to_hub
     - all

-## Generation
-
-[[autodoc]] generation_utils.GenerationMixin
-
-[[autodoc]] generation_tf_utils.TFGenerationMixin
-
-[[autodoc]] generation_flax_utils.FlaxGenerationMixin
-
 ## Pushing to the Hub

 [[autodoc]] file_utils.PushToHubMixin
docs/source/main_classes/text_generation.mdx (new file, 39 lines)
@@ -0,0 +1,39 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Generation

The methods for auto-regressive text generation, namely [`~generation_utils.GenerationMixin.generate`] (for the PyTorch models), [`~generation_tf_utils.TFGenerationMixin.generate`] (for the TensorFlow models) and [`~generation_flax_utils.FlaxGenerationMixin.generate`] (for the Flax/JAX models), are implemented in [`~generation_utils.GenerationMixin`], [`~generation_tf_utils.TFGenerationMixin`] and [`~generation_flax_utils.FlaxGenerationMixin`] respectively.

The `GenerationMixin` classes are inherited by the corresponding base model classes, *e.g.* [`PreTrainedModel`], [`TFPreTrainedModel`], and [`FlaxPreTrainedModel`] respectively, thereby exposing all
methods for auto-regressive text generation to every model class.

## GenerationMixin

[[autodoc]] generation_utils.GenerationMixin
    - generate
    - greedy_search
    - sample
    - beam_search
    - beam_sample
    - group_beam_search
    - constrained_beam_search

## TFGenerationMixin

[[autodoc]] generation_tf_utils.TFGenerationMixin
    - generate

## FlaxGenerationMixin

[[autodoc]] generation_flax_utils.FlaxGenerationMixin
    - generate
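As quick orientation for this new page, a minimal sketch (not part of the commit; gpt2 is an assumed example checkpoint) of calling the mixin-provided `generate` on a PyTorch model:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# generate() is inherited from GenerationMixin via PreTrainedModel
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The dog", return_tensors="pt").input_ids
outputs = model.generate(input_ids, max_length=20)  # greedy decoding by default
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```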
src/transformers/generation_flax_utils.py
@@ -118,7 +118,16 @@ class BeamSearchState:

 class FlaxGenerationMixin:
     """
-    A class containing all of the functions supporting generation, to be used as a mixin in [`FlaxPreTrainedModel`].
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in
+    [`FlaxPreTrainedModel`].
+
+    The class exposes [`~generation_flax_utils.FlaxGenerationMixin.generate`], which can be used for:
+        - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
+          `num_beams=1` and `do_sample=False`.
+        - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
+          and `do_sample=True`.
+        - *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if
+          `num_beams>1` and `do_sample=False`.
     """

     @staticmethod
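A minimal sketch of the Flax dispatch rules above (not part of the diff; assumes the gpt2 checkpoint and the `FlaxAutoModelForCausalLM` auto class):

```python
import jax
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = FlaxAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="np").input_ids

# num_beams=1, do_sample=False -> _greedy_search
greedy = model.generate(input_ids, max_length=20)
# num_beams=1, do_sample=True -> _sample; Flax sampling takes an explicit PRNG key
sampled = model.generate(input_ids, max_length=20, do_sample=True, prng_key=jax.random.PRNGKey(0))
print(tokenizer.batch_decode(greedy.sequences, skip_special_tokens=True))
```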
@@ -176,12 +185,23 @@ class FlaxGenerationMixin:
         **model_kwargs,
     ):
         r"""
-        Generates sequences for models with a language modeling head. The method currently supports greedy decoding
-        and multinomial sampling.
+        Generates sequences of token ids for models with a language modeling head. The method supports the following
+        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

-        Apart from `input_ids`, all the arguments below will default to the value of the attribute of the same name
-        inside the [`PretrainedConfig`] of the model. The default values indicated are the defaults of that config.
+        - *greedy decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._greedy_search`] if
+          `num_beams=1` and `do_sample=False`.
+        - *multinomial sampling* by calling [`~generation_flax_utils.FlaxGenerationMixin._sample`] if `num_beams=1`
+          and `do_sample=True`.
+        - *beam-search decoding* by calling [`~generation_flax_utils.FlaxGenerationMixin._beam_search`] if
+          `num_beams>1` and `do_sample=False`.
+
+        <Tip warning={true}>
+
+        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
+        defined in the model's config (`config.json`), which in turn defaults to the
+        [`~modeling_utils.PretrainedConfig`] of the model.
+
+        </Tip>

         Most of these parameters are explained in more detail in [this blog
         post](https://huggingface.co/blog/how-to-generate).
@@ -236,7 +256,7 @@ class FlaxGenerationMixin:
         >>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
         >>> # generate candidates using sampling
         >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ```"""
         # set init values
         max_length = max_length if max_length is not None else self.config.max_length
src/transformers/generation_utils.py
@@ -377,7 +377,21 @@ BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]

 class GenerationMixin:
     """
-    A class containing all of the functions supporting generation, to be used as a mixin in [`PreTrainedModel`].
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
+
+    The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
+        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
+          `do_sample=False`.
+        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
+          `do_sample=True`.
+        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
+          `do_sample=False`.
+        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
+          `num_beams>1` and `do_sample=True`.
+        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
+          `num_beams>1` and `num_beam_groups>1`.
+        - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
+          if `constraints!=None` or `force_words_ids!=None`.
     """

     def _prepare_model_inputs(
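To make the dispatch table above concrete, a sketch (not part of the diff; gpt2 is an assumed checkpoint) of how `generate` kwargs select a decoding method:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="pt").input_ids

model.generate(input_ids, max_length=20)                                  # greedy_search
model.generate(input_ids, max_length=20, do_sample=True)                  # sample
model.generate(input_ids, max_length=20, num_beams=4)                     # beam_search
model.generate(input_ids, max_length=20, num_beams=4, do_sample=True)     # beam_sample
model.generate(input_ids, max_length=20, num_beams=4, num_beam_groups=2)  # group_beam_search
# constrained_beam_search: triggered by constraints / force_words_ids
force_words_ids = tokenizer(["dog"], add_special_tokens=False).input_ids
model.generate(input_ids, max_length=20, num_beams=4, force_words_ids=force_words_ids)
```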
@@ -847,18 +861,37 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
         r"""
-        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
-        multinomial sampling, beam-search decoding, and beam-search multinomial sampling.
-
-        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name inside
-        the [`PretrainedConfig`] of the model. The default values indicated are the defaults of that config.
+        Generates sequences of token ids for models with a language modeling head. The method supports the following
+        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
+
+        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
+          `do_sample=False`.
+        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
+          `do_sample=True`.
+        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
+          `do_sample=False`.
+        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
+          `num_beams>1` and `do_sample=True`.
+        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
+          `num_beams>1` and `num_beam_groups>1`.
+        - *constrained beam-search decoding* by calling
+          [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
+          `force_words_ids!=None`.
+
+        <Tip warning={true}>
+
+        Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
+        defined in the model's config (`config.json`), which in turn defaults to the
+        [`~modeling_utils.PretrainedConfig`] of the model.
+
+        </Tip>

         Most of these parameters are explained in more detail in [this blog
         post](https://huggingface.co/blog/how-to-generate).

         Parameters:
-            inputs (`torch.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
-            feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
+            inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
                 The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                 method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                 should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
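The `<Tip>` above describes argument defaulting; a sketch of that behavior (not part of the diff; gpt2 assumed):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="pt").input_ids

# any argument left unset falls back to the attribute of the same name on the config
model.config.max_length = 25
outputs = model.generate(input_ids)                 # runs with max_length == 25
outputs = model.generate(input_ids, max_length=50)  # explicit kwarg overrides the config
```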
@@ -997,66 +1030,56 @@ class GenerationMixin:

         Examples:

+        Greedy Decoding:
+
         ```python
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

-        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-        >>> # do greedy decoding without providing a prompt
-        >>> outputs = model.generate(max_length=40)
-        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
-        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
-        >>> document = (
-        ...     "at least two people were killed in a suspected bomb attack on a passenger bus "
-        ...     "in the strife-torn southern philippines on monday , the military said."
-        ... )
-        >>> # encode input context
-        >>> input_ids = tokenizer(document, return_tensors="pt").input_ids
-        >>> # generate 3 independent sequences using beam search decoding (5 beams)
-        >>> # with T5 encoder-decoder model conditioned on short news article.
-        >>> outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3)
-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-        >>> input_context = "The dog"
-        >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
-        >>> # generate 3 candidates using sampling
-        >>> outputs = model.generate(input_ids=input_ids, max_length=20, num_return_sequences=3, do_sample=True)
-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("ctrl")
-        >>> model = AutoModelForCausalLM.from_pretrained("ctrl")
-        >>> # "Legal" is one of the control codes for ctrl
-        >>> input_context = "Legal My neighbor is"
-        >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
-        >>> outputs = model.generate(input_ids=input_ids, max_length=20, repetition_penalty=1.2)
-        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
         >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
-        >>> input_context = "My cute dog"
-        >>> # get tokens of words that should not be generated
-        >>> bad_words_ids = tokenizer(
-        ...     ["idiot", "stupid", "shut up"], add_prefix_space=True, add_special_tokens=False
-        ... ).input_ids
-        >>> # get tokens of words that we want generated
-        >>> force_words_ids = tokenizer(["runs", "loves"], add_prefix_space=True, add_special_tokens=False).input_ids
-        >>> # encode input context
-        >>> input_ids = tokenizer(input_context, return_tensors="pt").input_ids
-        >>> # generate sequences without allowing bad_words to be generated
-        >>> outputs = model.generate(
-        ...     input_ids=input_ids,
-        ...     max_length=20,
-        ...     do_sample=True,
-        ...     bad_words_ids=bad_words_ids,
-        ...     force_words_ids=force_words_ids,
-        ... )
-        >>> print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+        >>> # generate up to 30 tokens
+        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
+        ```
+
+        Multinomial Sampling:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+        >>> # sample up to 30 tokens
+        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
+        >>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
+        ```
+
+        Beam-search decoding:
+
+        ```python
+        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+        >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+
+        >>> sentence = "Paris is one of the densest populated areas in Europe."
+        >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
+
+        >>> outputs = model.generate(input_ids)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
         # 1. Set generation parameters if not already defined
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
@@ -1457,7 +1480,8 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[GreedySearchOutput, torch.LongTensor]:
         r"""
-        Generates sequences for models with a language modeling head using greedy decoding.
+        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
+        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

         Parameters:
@@ -1508,6 +1532,8 @@ class GenerationMixin:
         ...     AutoModelForCausalLM,
         ...     LogitsProcessorList,
         ...     MinLengthLogitsProcessor,
+        ...     StoppingCriteriaList,
+        ...     MaxLengthCriteria,
         ... )

         >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -1516,26 +1542,30 @@ class GenerationMixin:
         >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
         >>> model.config.pad_token_id = model.config.eos_token_id

-        >>> input_prompt = "Today is a beautiful day, and"
+        >>> input_prompt = "It might be possible to"
         >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

         >>> # instantiate logits processors
         >>> logits_processor = LogitsProcessorList(
         ...     [
-        ...         MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
+        ...         MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
         ...     ]
         ... )
+        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

-        >>> outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
+        >>> outputs = model.greedy_search(
+        ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
+        ... )

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         if max_length is not None:
             warnings.warn(
-                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                 UserWarning,
             )
             stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
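The rewritten warning points at this migration; a sketch of the replacement call (not part of the diff; gpt2 assumed):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList, MaxLengthCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.pad_token_id = model.config.eos_token_id
input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids

# deprecated: model.greedy_search(input_ids, max_length=20)
# preferred: pass an explicit stopping criterion
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
outputs = model.greedy_search(input_ids, stopping_criteria=stopping_criteria)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```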
@@ -1683,7 +1713,8 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[SampleOutput, torch.LongTensor]:
         r"""
-        Generates sequences for models with a language modeling head using multinomial sampling.
+        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

         Parameters:
@@ -1739,7 +1770,10 @@ class GenerationMixin:
         ...     MinLengthLogitsProcessor,
         ...     TopKLogitsWarper,
         ...     TemperatureLogitsWarper,
+        ...     StoppingCriteriaList,
+        ...     MaxLengthCriteria,
         ... )
+        >>> import torch

         >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
         >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
@@ -1764,9 +1798,18 @@ class GenerationMixin:
         ...     ]
         ... )

-        >>> outputs = model.sample(input_ids, logits_processor=logits_processor, logits_warper=logits_warper)
+        >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
+        >>> outputs = model.sample(
+        ...     input_ids,
+        ...     logits_processor=logits_processor,
+        ...     logits_warper=logits_warper,
+        ...     stopping_criteria=stopping_criteria,
+        ... )
+
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
         ```"""

         # init values
@@ -1926,7 +1969,8 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[BeamSearchOutput, torch.LongTensor]:
         r"""
-        Generates sequences for models with a language modeling head using beam search decoding.
+        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
+        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

         Parameters:
@@ -2020,7 +2064,8 @@ class GenerationMixin:

         >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Wie alt bist du?']
         ```"""
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
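The hunk above shows only the tail of the `beam_search` doctest; for orientation, a fuller sketch of driving `beam_search` by hand (not part of the diff; t5-base assumed, mirroring the style of the surrounding examples):

```python
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    BeamSearchScorer,
)

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids

# run beam search with 3 beams: one decoder start token per beam
num_beams = 3
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# the encoder runs once; its outputs are repeated across the beams
model_kwargs = {
    "encoder_outputs": model.get_encoder()(
        encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

beam_scorer = BeamSearchScorer(batch_size=1, num_beams=num_beams, device=model.device)
logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
)

outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))  # ['Wie alt bist du?']
```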
@@ -2237,7 +2282,8 @@ class GenerationMixin:
         **model_kwargs,
     ) -> Union[BeamSampleOutput, torch.LongTensor]:
         r"""
-        Generates sequences for models with a language modeling head using beam search with multinomial sampling.
+        Generates sequences of token ids for models with a language modeling head using **beam search multinomial
+        sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

         Parameters:
@@ -2343,7 +2389,8 @@ class GenerationMixin:
         ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
         ... )

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Wie alt bist du?']
         ```"""
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2556,7 +2603,8 @@ class GenerationMixin:
         **model_kwargs,
     ):
         r"""
-        Generates sequences for models with a language modeling head using beam search decoding.
+        Generates sequences of token ids for models with a language modeling head using **diverse beam search
+        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

         Parameters:
@@ -2656,7 +2704,8 @@ class GenerationMixin:
         ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
         ... )

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Wie alt bist du?']
         ```"""
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
@@ -2920,7 +2969,8 @@ class GenerationMixin:
     ) -> Union[BeamSearchOutput, torch.LongTensor]:
         r"""
-        Generates sequences for models with a language modeling head using beam search decoding.
+        Generates sequences of token ids for models with a language modeling head using **constrained beam search
+        decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

         Parameters:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -3024,8 +3074,8 @@ class GenerationMixin:
         ...     input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
         ... )

-        >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
-        # => ['Wie alter sind Sie?']
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Wie alt sind Sie?']
         ```"""
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
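The same constraint machinery is reachable from `generate` itself; a sketch using `force_words_ids` (not part of the diff; t5-base and the forced word "Sie" are assumptions):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

input_ids = tokenizer("translate English to German: How old are you?", return_tensors="pt").input_ids
# force the token(s) for "Sie" to appear in every finished hypothesis
force_words_ids = tokenizer(["Sie"], add_special_tokens=False).input_ids

outputs = model.generate(input_ids, num_beams=5, force_words_ids=force_words_ids)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))  # ['Wie alt sind Sie?']
```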
@@ -28,5 +28,6 @@ src/transformers/models/pegasus/modeling_pegasus.py
 src/transformers/models/blenderbot/modeling_blenderbot.py
 src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
 src/transformers/models/plbart/modeling_plbart.py
+src/transformers/generation_utils.py
 docs/source/quicktour.mdx
 docs/source/task_summary.mdx