mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: inner decoding methods are no longer public (#29437)
This commit is contained in:
parent
4d892b7297
commit
87a0783dde
@@ -389,3 +389,6 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['Alice and Bob are going to the same party. It is a small party, in a small']
 ```
+
+Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
+to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
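The n-gram variant mentioned in the added lines works by looking the model's current suffix up in the prompt itself. Below is a toy sketch of that idea only — the function name and token handling are invented for illustration and are not the library's implementation:

```python
def prompt_lookup_candidates(tokens, ngram_size=2, num_pred=3):
    """Toy sketch of prompt-lookup decoding: find the most recent earlier
    occurrence of the trailing `ngram_size` tokens and propose the tokens
    that followed it as draft candidates."""
    ngram = tokens[-ngram_size:]
    # Search right-to-left, excluding the trailing n-gram itself.
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start : start + ngram_size] == ngram:
            follow = tokens[start + ngram_size : start + ngram_size + num_pred]
            if follow:
                return follow
    return []  # no repeated n-gram: nothing to draft

# The prompt repeats "the cat", so the tokens after its first occurrence
# are proposed as draft tokens.
tokens = ["the", "cat", "sat", "on", "the", "mat", "and", "the", "cat"]
print(prompt_lookup_candidates(tokens, ngram_size=2))  # ['sat', 'on', 'the']
```

Because drafting is a lookup rather than a model forward pass, it is essentially free; the main `generate` loop still verifies every proposed token.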
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
 
 # Utilities for Generation
 
-This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-
-Most of those are only useful if you are studying the code of the generate methods in the library.
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
 
 ## Generate Outputs
 
@@ -376,4 +367,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] StaticCache
     - update
-    - get_seq_length
+    - get_seq_length
@@ -43,13 +43,6 @@ like token streaming.
 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search
 
 ## TFGenerationMixin
 
@@ -17,15 +17,6 @@ rendered properly in your Markdown viewer.
 # Utilities for Generation
 
 This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-
-Most of these are only useful if you are studying the code of the generation methods in the library.
 
 ## Generate Outputs
 
@@ -43,13 +43,6 @@ rendered properly in your Markdown viewer.
 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search
 
 ## TFGenerationMixin
 
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
 
 # Utilities for Generation
 
-This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-
-Most of these are only useful if you are studying the code of the generation methods in the library.
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
 
 ## Generate Outputs
 
@@ -38,13 +38,6 @@ rendered properly in your Markdown viewer.
 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search
 
 ## TFGenerationMixin
 
@@ -43,22 +43,22 @@ class GenerationConfig(PushToHubMixin):
 Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
 for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
 
-- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
 `do_sample=False`
-- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
+- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
 and `top_k>1`
-- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
 `do_sample=True`
-- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
 `do_sample=False`
-- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
+- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
 `num_beams>1` and `do_sample=True`
-- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
+- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
 `num_beams>1` and `num_beam_groups>1`
-- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
 `constraints!=None` or `force_words_ids!=None`
-- *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
-`assistant_model` is passed to `.generate()`
+- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
+`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
 
 You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
 more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
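The selection rules listed in the docstring above can be restated as a small pure-Python sketch. This is a simplification for illustration: the function name is invented, and the real dispatch lives inside `generate`/`GenerationConfig` and handles many more options.

```python
def pick_generation_mode(num_beams=1, do_sample=False, penalty_alpha=None,
                         top_k=50, num_beam_groups=1, constraints=None,
                         assistant_model=None, prompt_lookup_num_tokens=None):
    """Toy restatement of the docstring's dispatch rules."""
    if assistant_model is not None or prompt_lookup_num_tokens is not None:
        return "assisted_decoding"
    if constraints is not None:
        return "constrained_beam_search"
    if num_beams == 1:
        # single-sequence modes
        if penalty_alpha is not None and penalty_alpha > 0 and top_k > 1:
            return "contrastive_search"
        return "sample" if do_sample else "greedy_search"
    # beam modes
    if num_beam_groups > 1:
        return "group_beam_search"
    return "beam_sample" if do_sample else "beam_search"

print(pick_generation_mode())                            # greedy_search
print(pick_generation_mode(num_beams=4))                 # beam_search
print(pick_generation_mode(penalty_alpha=0.6, top_k=4))  # contrastive_search
```

The point of the commit is that all of these targets become private (`_greedy_search`, `_beam_search`, ...): users select a mode through `generate` arguments, never by calling the target method.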
@@ -347,20 +347,22 @@ class GenerationMixin:
 A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
 
 The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
-- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
 `do_sample=False`
-- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and
+- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
 `top_k>1`
-- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
 `do_sample=True`
-- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
 `do_sample=False`
-- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1`
+- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
 and `do_sample=True`
-- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1`
+- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
 and `num_beam_groups>1`
-- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
 `constraints!=None` or `force_words_ids!=None`
+- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
+`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
 
 You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
 learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@@ -1547,7 +1549,7 @@ class GenerationMixin:
 )
 if generation_mode == GenerationMode.GREEDY_SEARCH:
 # 11. run greedy search
-result = self.greedy_search(
+result = self._greedy_search(
 input_ids,
 logits_processor=prepared_logits_processor,
 stopping_criteria=prepared_stopping_criteria,
@@ -1565,7 +1567,7 @@ class GenerationMixin:
 if not model_kwargs["use_cache"]:
 raise ValueError("Contrastive search requires `use_cache=True`")
 
-result = self.contrastive_search(
+result = self._contrastive_search(
 input_ids,
 top_k=generation_config.top_k,
 penalty_alpha=generation_config.penalty_alpha,
@@ -1595,7 +1597,7 @@ class GenerationMixin:
 )
 
 # 13. run sample
-result = self.sample(
+result = self._sample(
 input_ids,
 logits_processor=prepared_logits_processor,
 logits_warper=logits_warper,
@@ -1629,7 +1631,7 @@ class GenerationMixin:
 **model_kwargs,
 )
 # 13. run beam search
-result = self.beam_search(
+result = self._beam_search(
 input_ids,
 beam_scorer,
 logits_processor=prepared_logits_processor,
@@ -1668,7 +1670,7 @@ class GenerationMixin:
 )
 
 # 14. run beam sample
-result = self.beam_sample(
+result = self._beam_sample(
 input_ids,
 beam_scorer,
 logits_processor=prepared_logits_processor,
@@ -1703,7 +1705,7 @@ class GenerationMixin:
 **model_kwargs,
 )
 # 13. run beam search
-result = self.group_beam_search(
+result = self._group_beam_search(
 input_ids,
 beam_scorer,
 logits_processor=prepared_logits_processor,
@@ -1777,7 +1779,7 @@ class GenerationMixin:
 **model_kwargs,
 )
 # 13. run beam search
-result = self.constrained_beam_search(
+result = self._constrained_beam_search(
 input_ids,
 constrained_beam_scorer=constrained_beam_scorer,
 logits_processor=prepared_logits_processor,
@@ -1801,8 +1803,15 @@ class GenerationMixin:
 
 return result
 
+def contrastive_search(self, *args, **kwargs):
+    logger.warning_once(
+        "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+        "custom generation loop instead.",
+    )
+    return self._contrastive_search(*args, **kwargs)
+
 @torch.no_grad()
-def contrastive_search(
+def _contrastive_search(
 self,
 input_ids: torch.LongTensor,
 top_k: Optional[int] = 1,
@@ -1828,7 +1837,7 @@ class GenerationMixin:
 
 <Tip warning={true}>
 
-In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use
+In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use
 generate() instead. For an overview of generation strategies and code examples, check the [following
 guide](../generation_strategies).
 
@@ -1902,7 +1911,7 @@ class GenerationMixin:
 >>> input_prompt = "DeepMind Company is"
 >>> input_ids = tokenizer(input_prompt, return_tensors="pt")
 >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
->>> outputs = model.contrastive_search(
+>>> outputs = model._contrastive_search(
 ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
 ... )
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
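The shims this commit adds all follow one pattern: keep the old public name, warn on first use, and forward to the new private method. A generic sketch of that pattern follows; `deprecate_public_alias` and `Toy` are invented names for illustration, and the sketch uses the stdlib `warnings` module where the library uses its own `logger.warning_once`.

```python
import warnings

def deprecate_public_alias(private_method, removed_in):
    """Build a public wrapper that warns once, then forwards to the
    private method -- the same shape as the shims in this commit."""
    warned = []

    def wrapper(self, *args, **kwargs):
        if not warned:  # warn only on the first call
            warnings.warn(
                f"Calling `{private_method.__name__.lstrip('_')}` directly is "
                f"deprecated and will be removed in {removed_in}. Use `generate` "
                "or a custom generation loop instead.",
                FutureWarning,
            )
            warned.append(True)
        return private_method(self, *args, **kwargs)

    return wrapper

class Toy:
    def _greedy_search(self, x):
        return x + 1

    # old public name kept alive as a warning shim
    greedy_search = deprecate_public_alias(_greedy_search, "v4.41")

t = Toy()
print(t.greedy_search(1))  # 2, after emitting a FutureWarning
```

Keeping the public alias for two releases gives downstream callers a migration window before the name disappears entirely.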
@@ -2243,7 +2252,14 @@ class GenerationMixin:
 else:
 return input_ids
 
-def greedy_search(
+def greedy_search(self, *args, **kwargs):
+    logger.warning_once(
+        "Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+        "custom generation loop instead.",
+    )
+    return self._greedy_search(*args, **kwargs)
+
+def _greedy_search(
 self,
 input_ids: torch.LongTensor,
 logits_processor: Optional[LogitsProcessorList] = None,
@@ -2266,7 +2282,7 @@ class GenerationMixin:
 
 <Tip warning={true}>
 
-In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
+In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate()
 instead. For an overview of generation strategies and code examples, check the [following
 guide](../generation_strategies).
 
@@ -2348,7 +2364,7 @@ class GenerationMixin:
 ... )
 >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
 
->>> outputs = model.greedy_search(
+>>> outputs = model._greedy_search(
 ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
 ... )
 
@@ -2514,7 +2530,14 @@ class GenerationMixin:
 else:
 return input_ids
 
-def sample(
+def sample(self, *args, **kwargs):
+    logger.warning_once(
+        "Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+        "custom generation loop instead.",
+    )
+    return self._sample(*args, **kwargs)
+
+def _sample(
 self,
 input_ids: torch.LongTensor,
 logits_processor: Optional[LogitsProcessorList] = None,
@@ -2538,7 +2561,7 @@ class GenerationMixin:
 
 <Tip warning={true}>
 
-In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
+In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead.
 For an overview of generation strategies and code examples, check the [following
 guide](../generation_strategies).
 
@@ -2635,7 +2658,7 @@ class GenerationMixin:
 >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
 
 >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
->>> outputs = model.sample(
+>>> outputs = model._sample(
 ... input_ids,
 ... logits_processor=logits_processor,
 ... logits_warper=logits_warper,
@@ -2832,7 +2855,14 @@ class GenerationMixin:
 past_key_values.reorder_cache(beam_idx)
 return past_key_values
 
-def beam_search(
+def beam_search(self, *args, **kwargs):
+    logger.warning_once(
+        "Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+        "custom generation loop instead.",
+    )
+    return self._beam_search(*args, **kwargs)
+
+def _beam_search(
 self,
 input_ids: torch.LongTensor,
 beam_scorer: BeamScorer,
@@ -2856,7 +2886,7 @@ class GenerationMixin:
 
 <Tip warning={true}>
 
-In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
+In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate()
 instead. For an overview of generation strategies and code examples, check the [following
 guide](../generation_strategies).
 
@@ -2958,7 +2988,7 @@ class GenerationMixin:
 ... ]
 ... )
 
->>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
+>>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
 
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['Wie alt bist du?']
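For readers less familiar with the method being renamed here: beam search keeps the `num_beams` best-scoring prefixes at each step instead of committing to the single best token. A self-contained toy (the probability table is made up to show a case where beam search beats greedy decoding; this is not the library's implementation):

```python
import math

# Conditional next-token probabilities: probs[prev][next]. Chosen so that
# the greedy first choice (token 0) is NOT on the highest-probability path.
probs = {"S": {0: 0.6, 1: 0.4}, 0: {0: 0.5, 1: 0.5}, 1: {0: 0.9, 1: 0.1}}

def beam_search(num_beams, length):
    """Minimal beam search: at every step, expand each kept prefix with all
    next tokens and keep the `num_beams` best by summed log-probability."""
    beams = [(["S"], 0.0)]  # (prefix, cumulative log-prob)
    for _ in range(length):
        candidates = [
            (seq + [tok], score + math.log(p))
            for seq, score in beams
            for tok, p in probs[seq[-1]].items()
        ]
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams[0][0][1:]  # drop the start symbol

print(beam_search(num_beams=1, length=2))  # [0, 0] -- the greedy path, p=0.3
print(beam_search(num_beams=2, length=2))  # [1, 0] -- higher total, p=0.36
```

With one beam this degenerates to greedy decoding; with two beams the initially weaker token 1 survives long enough for its high-probability continuation to win.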
@@ -3214,7 +3244,14 @@ class GenerationMixin:
 else:
 return sequence_outputs["sequences"]
 
-def beam_sample(
+def beam_sample(self, *args, **kwargs):
+    logger.warning_once(
+        "Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+        "custom generation loop instead.",
+    )
+    return self._beam_sample(*args, **kwargs)
+
+def _beam_sample(
 self,
 input_ids: torch.LongTensor,
 beam_scorer: BeamScorer,
@@ -3238,7 +3275,7 @@ class GenerationMixin:
 
 <Tip warning={true}>
 
-In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate()
+In most cases, you do not need to call [`~generation.GenerationMixin._beam_sample`] directly. Use generate()
 instead. For an overview of generation strategies and code examples, check the [following
 guide](../generation_strategies).
 
@@ -3346,7 +3383,7 @@ class GenerationMixin:
 ... ]
 ... )
 
->>> outputs = model.beam_sample(
+>>> outputs = model._beam_sample(
 ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
 ... )
 
@@ -3561,7 +3598,14 @@ class GenerationMixin:
 else:
 return sequence_outputs["sequences"]
 
-def group_beam_search(
+def group_beam_search(self, *args, **kwargs):
+    logger.warning_once(
+        "Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+        "custom generation loop instead.",
+    )
+    return self._group_beam_search(*args, **kwargs)
+
+def _group_beam_search(
 self,
 input_ids: torch.LongTensor,
 beam_scorer: BeamScorer,
@@ -3584,7 +3628,7 @@ class GenerationMixin:
 
 <Tip warning={true}>
 
-In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use
+In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use
 generate() instead. For an overview of generation strategies and code examples, check the [following
 guide](../generation_strategies).
 
@@ -3686,7 +3730,7 @@ class GenerationMixin:
 ... ]
 ... )
 
->>> outputs = model.group_beam_search(
+>>> outputs = model._group_beam_search(
 ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
 ... )
 
@@ -3958,7 +4002,14 @@ class GenerationMixin:
 else:
 return sequence_outputs["sequences"]
 
-def constrained_beam_search(
+def constrained_beam_search(self, *args, **kwargs):
+    logger.warning_once(
+        "Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+        "custom generation loop instead.",
+    )
+    return self._constrained_beam_search(*args, **kwargs)
+
+def _constrained_beam_search(
 self,
 input_ids: torch.LongTensor,
 constrained_beam_scorer: ConstrainedBeamSearchScorer,
@@ -3981,7 +4032,7 @@ class GenerationMixin:
 
 <Tip warning={true}>
 
-In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use
+In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use
 generate() instead. For an overview of generation strategies and code examples, check the [following
 guide](../generation_strategies).
 
@@ -4088,7 +4139,7 @@ class GenerationMixin:
 ... ]
 ... )
 
->>> outputs = model.constrained_beam_search(
+>>> outputs = model._constrained_beam_search(
 ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
 ... )
 
@@ -4311,7 +4362,14 @@ class GenerationMixin:
 else:
 return sequence_outputs["sequences"]
 
-def assisted_decoding(
+def assisted_decoding(self, *args, **kwargs):
+    logger.warning_once(
+        "Calling `assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+        "custom generation loop instead.",
+    )
+    return self._assisted_decoding(*args, **kwargs)
+
+def _assisted_decoding(
 self,
 input_ids: torch.LongTensor,
 candidate_generator: Optional["CandidateGenerator"] = None,
@@ -4338,7 +4396,7 @@ class GenerationMixin:
 
 <Tip warning={true}>
 
-In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
+In most cases, you do not need to call [`~generation.GenerationMixin._assisted_decoding`] directly. Use
 generate() instead. For an overview of generation strategies and code examples, check the [following
 guide](../generation_strategies).
 
@@ -4429,7 +4487,7 @@ class GenerationMixin:
 ... logits_processor=logits_processor,
 ... model_kwargs={},
 ... )
->>> outputs = model.assisted_decoding(
+>>> outputs = model._assisted_decoding(
 ... input_ids,
 ... candidate_generator=candidate_generator,
 ... logits_processor=logits_processor,
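The acceptance rule behind assisted decoding can be sketched in a few lines. This is a toy: `assisted_decode_step` and `target_next_token` are invented names, the check runs token by token, and it only covers the greedy case, whereas the real `_assisted_decoding` verifies all draft tokens with a single forward pass and also supports sampling.

```python
def assisted_decode_step(target_next_token, prefix, draft_tokens):
    """Greedy acceptance rule: accept draft tokens while they match what the
    target model would have produced; at the first mismatch, keep the
    target's own token instead. `target_next_token(seq)` stands in for a
    model forward pass returning the next token for `seq`."""
    accepted = []
    for tok in draft_tokens:
        expected = target_next_token(prefix + accepted)
        if tok != expected:
            accepted.append(expected)  # target model overrides the draft
            break
        accepted.append(tok)
    else:
        # every draft token matched: the verification pass yields one
        # extra token for free
        accepted.append(target_next_token(prefix + accepted))
    return accepted

# Stand-in "model" that always predicts the current sequence length.
print(assisted_decode_step(lambda s: len(s), [0], [1, 2, 9]))  # [1, 2, 3]
```

Note that every step still emits at least one token chosen by the target model, which is why assisted decoding changes latency but not (greedy) output.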
@@ -1336,7 +1336,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
 )
 
 # 11. run greedy search
-outputs = self.greedy_search(
+outputs = self._greedy_search(
 input_ids,
 logits_processor=logits_processor,
 stopping_criteria=stopping_criteria,
@@ -1361,7 +1361,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
 )
 
 # 12. run sample
-outputs = self.sample(
+outputs = self._sample(
 input_ids,
 logits_processor=logits_processor,
 logits_warper=logits_warper,
@@ -2402,7 +2402,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
 )
 
 # 11. run greedy search
-outputs = self.greedy_search(
+outputs = self._greedy_search(
 input_ids,
 logits_processor=logits_processor,
 stopping_criteria=stopping_criteria,
@@ -2428,7 +2428,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
 )
 
 # 12. run sample
-outputs = self.sample(
+outputs = self._sample(
 input_ids,
 logits_processor=logits_processor,
 logits_warper=logits_warper,
@@ -1539,7 +1539,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
 f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
 " greedy search."
 )
-return self.greedy_search(
+return self._greedy_search(
 input_ids,
 logits_processor=pre_processor,
 max_length=generation_config.max_length,
@@ -1559,7 +1559,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
 num_beam_hyps_to_keep=generation_config.num_return_sequences,
 max_length=generation_config.max_length,
 )
-return self.beam_search(
+return self._beam_search(
 input_ids,
 beam_scorer,
 logits_processor=pre_processor,