Generate: inner decoding methods are no longer public (#29437)

Joao Gante 2024-03-05 10:27:36 +00:00 committed by GitHub
parent 4d892b7297
commit 87a0783dde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 117 additions and 104 deletions


@ -389,3 +389,6 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are going to the same party. It is a small party, in a small']
```
Alternatively, you can also set `prompt_lookup_num_tokens` to trigger n-gram-based assisted decoding, as opposed
to model-based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
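The n-gram lookup that `prompt_lookup_num_tokens` enables can be sketched in plain Python. This is a simplified toy, not the library's implementation: match the trailing n-gram of the sequence against earlier positions and propose the tokens that followed the match as draft candidates.

```python
def prompt_lookup_candidates(input_ids, ngram_size=2, num_candidates=3):
    """Propose draft tokens by matching the trailing n-gram earlier in the sequence."""
    tail = input_ids[-ngram_size:]
    # Scan left-to-right for an earlier occurrence of the trailing n-gram.
    for start in range(len(input_ids) - ngram_size):
        if input_ids[start:start + ngram_size] == tail:
            follow = input_ids[start + ngram_size:start + ngram_size + num_candidates]
            if follow:
                return follow
    return []  # no match: fall back to normal decoding

# Token ids standing in for a repeating prompt such as "Alice and ... Alice and"
seq = [5, 9, 2, 7, 5, 9]
print(prompt_lookup_candidates(seq, ngram_size=2))  # [2, 7, 5]
```

When the match succeeds, the drafted tokens are verified by the model in a single forward pass, which is why repetitive inputs (summarization, code editing) benefit most.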


@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
# Utilities for Generation
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
[`~generation.GenerationMixin.greedy_search`],
[`~generation.GenerationMixin.contrastive_search`],
[`~generation.GenerationMixin.sample`],
[`~generation.GenerationMixin.beam_search`],
[`~generation.GenerationMixin.beam_sample`],
[`~generation.GenerationMixin.group_beam_search`], and
[`~generation.GenerationMixin.constrained_beam_search`].
Most of those are only useful if you are studying the code of the generate methods in the library.
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
## Generate Outputs
@ -376,4 +367,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length
- get_seq_length


@ -43,13 +43,6 @@ like token streaming.
[[autodoc]] generation.GenerationMixin
- generate
- compute_transition_scores
- greedy_search
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search
## TFGenerationMixin


@ -17,15 +17,6 @@ rendered properly in your Markdown viewer.
# Utilities for Generation
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
[`~generation.GenerationMixin.greedy_search`],
[`~generation.GenerationMixin.contrastive_search`],
[`~generation.GenerationMixin.sample`],
[`~generation.GenerationMixin.beam_search`],
[`~generation.GenerationMixin.beam_sample`],
[`~generation.GenerationMixin.group_beam_search`], and
[`~generation.GenerationMixin.constrained_beam_search`].
Most of these are only useful if you are studying the code of the generate methods in the library.
## Generate Outputs


@ -43,13 +43,6 @@ rendered properly in your Markdown viewer.
[[autodoc]] generation.GenerationMixin
- generate
- compute_transition_scores
- greedy_search
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search
## TFGenerationMixin


@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
# Utilities for Generation
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
[`~generation.GenerationMixin.greedy_search`],
[`~generation.GenerationMixin.contrastive_search`],
[`~generation.GenerationMixin.sample`],
[`~generation.GenerationMixin.beam_search`],
[`~generation.GenerationMixin.beam_sample`],
[`~generation.GenerationMixin.group_beam_search`], and
[`~generation.GenerationMixin.constrained_beam_search`].
Most of these are only useful if you are studying the code of the generate methods in the library.
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
## Generate Outputs


@ -38,13 +38,6 @@ rendered properly in your Markdown viewer.
[[autodoc]] generation.GenerationMixin
- generate
- compute_transition_scores
- greedy_search
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search
## TFGenerationMixin


@ -43,22 +43,22 @@ class GenerationConfig(PushToHubMixin):
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
`do_sample=False`
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
and `top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
`do_sample=True`
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
`do_sample=False`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
`num_beams>1` and `do_sample=True`
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
`assistant_model` is passed to `.generate()`
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to `.generate()`. To learn
more about decoding strategies, refer to the [text generation strategies guide](../generation_strategies).
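The flag-to-strategy mapping in the docstring above can be sketched as a small dispatch function. This toy mirrors the bullet list only; the library's real selection logic covers more cases and flag combinations:

```python
def select_generation_mode(num_beams=1, do_sample=False, penalty_alpha=None,
                           top_k=None, num_beam_groups=1, constraints=None,
                           assistant_model=None):
    """Toy dispatch mirroring the docstring's bullet list (not the library's code)."""
    if assistant_model is not None:
        return "assisted_decoding"
    if constraints is not None:
        return "constrained_beam_search"
    if num_beam_groups > 1:
        return "group_beam_search"
    if num_beams > 1:
        return "beam_sample" if do_sample else "beam_search"
    if penalty_alpha is not None and penalty_alpha > 0 and top_k is not None and top_k > 1:
        return "contrastive_search"
    return "sample" if do_sample else "greedy_search"

print(select_generation_mode())                            # greedy_search
print(select_generation_mode(do_sample=True))              # sample
print(select_generation_mode(num_beams=4))                 # beam_search
print(select_generation_mode(penalty_alpha=0.6, top_k=4))  # contrastive_search
```

This is why the inner methods no longer need to be public: `generate` routes to the right one from the configuration alone.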


@ -347,20 +347,22 @@ class GenerationMixin:
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
`do_sample=False`
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
`top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
`do_sample=True`
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
`do_sample=False`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
and `do_sample=True`
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1`
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to `generate` instead. To
learn more about decoding strategies, refer to the [text generation strategies guide](../generation_strategies).
@ -1547,7 +1549,7 @@ class GenerationMixin:
)
if generation_mode == GenerationMode.GREEDY_SEARCH:
# 11. run greedy search
result = self.greedy_search(
result = self._greedy_search(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
@ -1565,7 +1567,7 @@ class GenerationMixin:
if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`")
result = self.contrastive_search(
result = self._contrastive_search(
input_ids,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
@ -1595,7 +1597,7 @@ class GenerationMixin:
)
# 13. run sample
result = self.sample(
result = self._sample(
input_ids,
logits_processor=prepared_logits_processor,
logits_warper=logits_warper,
@ -1629,7 +1631,7 @@ class GenerationMixin:
**model_kwargs,
)
# 13. run beam search
result = self.beam_search(
result = self._beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@ -1668,7 +1670,7 @@ class GenerationMixin:
)
# 14. run beam sample
result = self.beam_sample(
result = self._beam_sample(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@ -1703,7 +1705,7 @@ class GenerationMixin:
**model_kwargs,
)
# 13. run beam search
result = self.group_beam_search(
result = self._group_beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
@ -1777,7 +1779,7 @@ class GenerationMixin:
**model_kwargs,
)
# 13. run beam search
result = self.constrained_beam_search(
result = self._constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
logits_processor=prepared_logits_processor,
@ -1801,8 +1803,15 @@ class GenerationMixin:
return result
def contrastive_search(self, *args, **kwargs):
logger.warning_once(
"Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._contrastive_search(*args, **kwargs)
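The deprecation pattern used throughout this diff, a thin public wrapper that warns once and forwards to the now-private method, can be sketched in isolation. This toy uses the standard `warnings` module instead of the library's logger:

```python
import warnings

class Greeter:
    def _hello(self, name):
        return f"hello, {name}"

    def hello(self, *args, **kwargs):
        # Deprecated public alias: warn, then forward to the private method.
        warnings.warn(
            "Calling `hello` directly is deprecated; use the public entry point instead.",
            FutureWarning,
        )
        return self._hello(*args, **kwargs)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = Greeter().hello("world")
print(result)              # hello, world
print(caught[0].category)  # <class 'FutureWarning'>
```

Keeping the old name as a forwarding alias lets callers migrate over a release cycle before the alias is removed (v4.41 in this PR).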
@torch.no_grad()
def contrastive_search(
def _contrastive_search(
self,
input_ids: torch.LongTensor,
top_k: Optional[int] = 1,
@ -1828,7 +1837,7 @@ class GenerationMixin:
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use
In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@ -1902,7 +1911,7 @@ class GenerationMixin:
>>> input_prompt = "DeepMind Company is"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt")
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
>>> outputs = model.contrastive_search(
>>> outputs = model._contrastive_search(
... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
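Contrastive search scores each candidate by its model confidence minus a degeneration penalty: the maximum similarity between the candidate's representation and the context's. A toy sketch with hand-picked probabilities and 2-D vectors (illustrative only, not the library's implementation):

```python
def contrastive_step(candidates, context_vecs, alpha=0.6):
    """Toy contrastive scoring: balance model confidence against similarity
    to the context representations (the degeneration penalty)."""
    def cos(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        return dot / ((sum(a * a for a in u) ** 0.5) * (sum(b * b for b in v) ** 0.5))

    def score(cand):
        prob, vec = cand
        penalty = max(cos(vec, c) for c in context_vecs)
        return (1 - alpha) * prob - alpha * penalty

    return max(candidates, key=score)

context = [(1.0, 0.0)]
# The most probable candidate points the same way as the context (repetitive);
# contrastive scoring prefers the slightly less probable but novel one.
cands = [(0.6, (1.0, 0.0)), (0.4, (0.0, 1.0))]
print(contrastive_step(cands, context))  # (0.4, (0.0, 1.0))
```

The `penalty_alpha` knob plays the role of `alpha` here: at `alpha=0` this degenerates to greedy search.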
@ -2243,7 +2252,14 @@ class GenerationMixin:
else:
return input_ids
def greedy_search(
def greedy_search(self, *args, **kwargs):
logger.warning_once(
"Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._greedy_search(*args, **kwargs)
def _greedy_search(
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
@ -2266,7 +2282,7 @@ class GenerationMixin:
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@ -2348,7 +2364,7 @@ class GenerationMixin:
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> outputs = model.greedy_search(
>>> outputs = model._greedy_search(
... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
... )
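Greedy search itself reduces to a simple loop: at each step, take the highest-scoring token and stop at EOS or the length budget. A minimal sketch with a toy scoring function (not the library's implementation):

```python
def greedy_decode(next_token_logits, input_ids, max_new_tokens, eos_token_id):
    """Toy greedy loop: append the argmax token until EOS or the length budget."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)  # one score per vocabulary id
        next_id = max(range(len(logits)), key=logits.__getitem__)
        ids.append(next_id)
        if next_id == eos_token_id:
            break
    return ids

# Toy "model": always favours token (last_id + 1) % 5.
def toy_logits(ids):
    scores = [0.0] * 5
    scores[(ids[-1] + 1) % 5] = 1.0
    return scores

print(greedy_decode(toy_logits, [0], max_new_tokens=10, eos_token_id=3))
# [0, 1, 2, 3]
```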
@ -2514,7 +2530,14 @@ class GenerationMixin:
else:
return input_ids
def sample(
def sample(self, *args, **kwargs):
logger.warning_once(
"Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._sample(*args, **kwargs)
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
@ -2538,7 +2561,7 @@ class GenerationMixin:
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead.
For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@ -2635,7 +2658,7 @@ class GenerationMixin:
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
>>> outputs = model.sample(
>>> outputs = model._sample(
... input_ids,
... logits_processor=logits_processor,
... logits_warper=logits_warper,
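Multinomial sampling replaces the argmax with a draw from the temperature-warped distribution. A minimal sketch (toy code, not the library's implementation):

```python
import math
import random

def sample_next_token(logits, temperature=1.0, rng=random):
    """Toy multinomial sampling: softmax over temperature-scaled logits, then draw."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

rng = random.Random(0)
draws = [sample_next_token([2.0, 0.5, 0.1], temperature=0.7, rng=rng) for _ in range(1000)]
print(draws.count(0) > draws.count(1) > draws.count(2))  # True
```

Lower temperatures sharpen the distribution toward the argmax; `logits_warper` in the real code applies exactly this kind of transformation before the draw.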
@ -2832,7 +2855,14 @@ class GenerationMixin:
past_key_values.reorder_cache(beam_idx)
return past_key_values
def beam_search(
def beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._beam_search(*args, **kwargs)
def _beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
@ -2856,7 +2886,7 @@ class GenerationMixin:
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@ -2958,7 +2988,7 @@ class GenerationMixin:
... ]
... )
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
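Beam search keeps the `num_beams` best partial sequences by cumulative log-probability instead of committing to one token per step. A toy sketch showing how it can recover a sequence that greedy search would miss:

```python
import math

def beam_search(step_logprobs, num_beams, max_steps):
    """Toy beam search: expand every beam, keep the `num_beams` best by total score."""
    beams = [([], 0.0)]  # (token sequence, cumulative log-prob)
    for _ in range(max_steps):
        candidates = []
        for seq, score in beams:
            for token, logp in step_logprobs(seq).items():
                candidates.append((seq + [token], score + logp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams[0][0]

# Toy scores: token "b" looks locally worse but pays off at the next step.
def toy_scores(seq):
    if not seq:
        return {"a": math.log(0.6), "b": math.log(0.4)}
    if seq[-1] == "a":
        return {"x": math.log(0.3), "y": math.log(0.3)}
    return {"x": math.log(0.9), "y": math.log(0.1)}

print(beam_search(toy_scores, num_beams=2, max_steps=2))  # ['b', 'x']
```

Greedy decoding would pick "a" first (0.6 > 0.4) and end with total probability 0.18; the beam keeps "b" alive and finds the 0.36 path.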
@ -3214,7 +3244,14 @@ class GenerationMixin:
else:
return sequence_outputs["sequences"]
def beam_sample(
def beam_sample(self, *args, **kwargs):
logger.warning_once(
"Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._beam_sample(*args, **kwargs)
def _beam_sample(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
@ -3238,7 +3275,7 @@ class GenerationMixin:
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate()
In most cases, you do not need to call [`~generation.GenerationMixin._beam_sample`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@ -3346,7 +3383,7 @@ class GenerationMixin:
... ]
... )
>>> outputs = model.beam_sample(
>>> outputs = model._beam_sample(
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
... )
@ -3561,7 +3598,14 @@ class GenerationMixin:
else:
return sequence_outputs["sequences"]
def group_beam_search(
def group_beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._group_beam_search(*args, **kwargs)
def _group_beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
@ -3584,7 +3628,7 @@ class GenerationMixin:
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use
In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@ -3686,7 +3730,7 @@ class GenerationMixin:
... ]
... )
>>> outputs = model.group_beam_search(
>>> outputs = model._group_beam_search(
... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
... )
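Diverse (group) beam search splits the beams into groups and penalizes a group for picking tokens that earlier groups already chose in the same step. A toy sketch of that diversity penalty (illustrative only, not the library's implementation):

```python
def diverse_group_tokens(logprobs, num_groups, diversity_penalty=1.5):
    """Toy diverse selection: each group picks its best token after penalizing
    tokens already chosen by earlier groups in the same decoding step."""
    chosen = []
    counts = {}
    for _ in range(num_groups):
        adjusted = {t: lp - diversity_penalty * counts.get(t, 0)
                    for t, lp in logprobs.items()}
        best = max(adjusted, key=adjusted.get)
        chosen.append(best)
        counts[best] = counts.get(best, 0) + 1
    return chosen

scores = {"the": -0.1, "a": -0.4, "an": -2.0}
print(diverse_group_tokens(scores, num_groups=2))  # ['the', 'a']
```

Without the penalty both groups would pick "the"; the penalty forces the second group onto a different token, which is what `num_beam_groups>1` buys you.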
@ -3958,7 +4002,14 @@ class GenerationMixin:
else:
return sequence_outputs["sequences"]
def constrained_beam_search(
def constrained_beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._constrained_beam_search(*args, **kwargs)
def _constrained_beam_search(
self,
input_ids: torch.LongTensor,
constrained_beam_scorer: ConstrainedBeamSearchScorer,
@ -3981,7 +4032,7 @@ class GenerationMixin:
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use
In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@ -4088,7 +4139,7 @@ class GenerationMixin:
... ]
... )
>>> outputs = model.constrained_beam_search(
>>> outputs = model._constrained_beam_search(
... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
... )
@ -4311,7 +4362,14 @@ class GenerationMixin:
else:
return sequence_outputs["sequences"]
def assisted_decoding(
def assisted_decoding(self, *args, **kwargs):
logger.warning_once(
"Calling `assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._assisted_decoding(*args, **kwargs)
def _assisted_decoding(
self,
input_ids: torch.LongTensor,
candidate_generator: Optional["CandidateGenerator"] = None,
@ -4338,7 +4396,7 @@ class GenerationMixin:
<Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
In most cases, you do not need to call [`~generation.GenerationMixin._assisted_decoding`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
@ -4429,7 +4487,7 @@ class GenerationMixin:
... logits_processor=logits_processor,
... model_kwargs={},
... )
>>> outputs = model.assisted_decoding(
>>> outputs = model._assisted_decoding(
... input_ids,
... candidate_generator=candidate_generator,
... logits_processor=logits_processor,
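Assisted decoding has a cheap drafter propose several tokens and the target model verify them greedily, keeping the longest agreeing prefix plus one token of the target's own. A toy sketch of the verification step, with a stand-in "model" (illustrative only):

```python
def verify_draft(target_next_token, prefix, draft):
    """Toy greedy verification: accept draft tokens while the target model agrees,
    then append the target's own token at the first disagreement (or at the end)."""
    accepted = list(prefix)
    for drafted in draft:
        expected = target_next_token(accepted)
        if drafted != expected:
            accepted.append(expected)  # target overrides the drafter
            return accepted
        accepted.append(drafted)
    accepted.append(target_next_token(accepted))  # bonus token after a full accept
    return accepted

# Toy target model: always continues with (last token + 1).
target = lambda ids: ids[-1] + 1

print(verify_draft(target, [0], draft=[1, 2, 9]))  # [0, 1, 2, 3]
print(verify_draft(target, [0], draft=[1, 2]))     # [0, 1, 2, 3]
```

Because verification of all drafted tokens costs one target forward pass, every accepted draft token is roughly one decoding step saved; the output distribution matches what the target model alone would produce.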


@ -1336,7 +1336,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
)
# 11. run greedy search
outputs = self.greedy_search(
outputs = self._greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
@ -1361,7 +1361,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
)
# 12. run sample
outputs = self.sample(
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
@ -2402,7 +2402,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
)
# 11. run greedy search
outputs = self.greedy_search(
outputs = self._greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
@ -2428,7 +2428,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
)
# 12. run sample
outputs = self.sample(
outputs = self._sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,


@ -1539,7 +1539,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search."
)
return self.greedy_search(
return self._greedy_search(
input_ids,
logits_processor=pre_processor,
max_length=generation_config.max_length,
@ -1559,7 +1559,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
return self.beam_search(
return self._beam_search(
input_ids,
beam_scorer,
logits_processor=pre_processor,