mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: add missing logits processors docs (#25653)
This commit is contained in:
parent
cb8e3ee25f
commit
85cf90a1c9
@ -75,39 +75,104 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`.
|
||||
We document here all output types.
|
||||
|
||||
|
||||
### GreedySearchOutput
|
||||
|
||||
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
|
||||
### PyTorch
|
||||
|
||||
[[autodoc]] generation.GreedySearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.FlaxGreedySearchOutput
|
||||
|
||||
### SampleOutput
|
||||
|
||||
[[autodoc]] generation.SampleDecoderOnlyOutput
|
||||
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.SampleEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.FlaxSampleOutput
|
||||
|
||||
### BeamSearchOutput
|
||||
|
||||
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
|
||||
[[autodoc]] generation.SampleDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.BeamSearchEncoderDecoderOutput
|
||||
|
||||
### BeamSampleOutput
|
||||
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.BeamSampleDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
|
||||
[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
|
||||
|
||||
### TensorFlow
|
||||
|
||||
[[autodoc]] generation.TFGreedySearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.TFSampleEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFSampleDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.TFContrastiveSearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput
|
||||
|
||||
### FLAX
|
||||
|
||||
[[autodoc]] generation.FlaxSampleOutput
|
||||
|
||||
[[autodoc]] generation.FlaxGreedySearchOutput
|
||||
|
||||
[[autodoc]] generation.FlaxBeamSearchOutput
|
||||
|
||||
## LogitsProcessor
|
||||
|
||||
A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for
|
||||
generation.
|
||||
|
||||
### PyTorch
|
||||
|
||||
[[autodoc]] AlternatingCodebooksLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ClassifierFreeGuidanceLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] EncoderNoRepeatNGramLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] EpsilonLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] EtaLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ExponentialDecayLengthPenalty
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForcedBOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForcedEOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForceTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] HammingDiversityLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] InfNanRemoveLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] LogitNormalization
|
||||
- __call__
|
||||
|
||||
[[autodoc]] LogitsProcessor
|
||||
- __call__
|
||||
|
||||
@ -123,43 +188,54 @@ generation.
|
||||
[[autodoc]] MinNewTokensLengthLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] RepetitionPenaltyLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TypicalLogitsWarper
|
||||
[[autodoc]] NoBadWordsLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] NoRepeatNGramLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] SequenceBiasLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] NoBadWordsLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] PrefixConstrainedLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] HammingDiversityLogitsProcessor
|
||||
[[autodoc]] RepetitionPenaltyLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForcedBOSTokenLogitsProcessor
|
||||
[[autodoc]] SequenceBiasLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForcedEOSTokenLogitsProcessor
|
||||
[[autodoc]] SuppressTokensAtBeginLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] InfNanRemoveLogitsProcessor
|
||||
[[autodoc]] SuppressTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TypicalLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] WhisperTimeStampLogitsProcessor
|
||||
- __call__
|
||||
|
||||
### TensorFlow
|
||||
|
||||
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFForcedEOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFForceTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFLogitsProcessor
|
||||
@ -171,15 +247,6 @@ generation.
|
||||
[[autodoc]] TFLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFMinLengthLogitsProcessor
|
||||
- __call__
|
||||
|
||||
@ -192,10 +259,30 @@ generation.
|
||||
[[autodoc]] TFRepetitionPenaltyLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
||||
[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFForcedEOSTokenLogitsProcessor
|
||||
[[autodoc]] TFSuppressTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
### FLAX
|
||||
|
||||
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxForceTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxLogitsProcessor
|
||||
@ -207,27 +294,30 @@ generation.
|
||||
[[autodoc]] FlaxLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxTemperatureLogitsWarper
|
||||
[[autodoc]] FlaxMinLengthLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxTopPLogitsWarper
|
||||
[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxSuppressTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxTemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxTopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
|
||||
[[autodoc]] FlaxTopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxMinLengthLogitsProcessor
|
||||
[[autodoc]] FlaxWhisperTimeStampLogitsProcessor
|
||||
- __call__
|
||||
|
||||
## StoppingCriteria
|
||||
|
||||
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token).
|
||||
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). Please note that this is exclusivelly available to our PyTorch implementations.
|
||||
|
||||
[[autodoc]] StoppingCriteria
|
||||
- __call__
|
||||
@ -243,7 +333,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
|
||||
|
||||
## Constraints
|
||||
|
||||
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
|
||||
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusivelly available to our PyTorch implementations.
|
||||
|
||||
[[autodoc]] Constraint
|
||||
|
||||
|
@ -1005,17 +1005,26 @@ else:
|
||||
_import_structure["deepspeed"] = []
|
||||
_import_structure["generation"].extend(
|
||||
[
|
||||
"AlternatingCodebooksLogitsProcessor",
|
||||
"BeamScorer",
|
||||
"BeamSearchScorer",
|
||||
"ClassifierFreeGuidanceLogitsProcessor",
|
||||
"ConstrainedBeamSearchScorer",
|
||||
"Constraint",
|
||||
"ConstraintListState",
|
||||
"DisjunctiveConstraint",
|
||||
"EncoderNoRepeatNGramLogitsProcessor",
|
||||
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||
"EpsilonLogitsWarper",
|
||||
"EtaLogitsWarper",
|
||||
"ExponentialDecayLengthPenalty",
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
"ForceTokensLogitsProcessor",
|
||||
"GenerationMixin",
|
||||
"HammingDiversityLogitsProcessor",
|
||||
"InfNanRemoveLogitsProcessor",
|
||||
"LogitNormalization",
|
||||
"LogitsProcessor",
|
||||
"LogitsProcessorList",
|
||||
"LogitsWarper",
|
||||
@ -1031,10 +1040,14 @@ else:
|
||||
"SequenceBiasLogitsProcessor",
|
||||
"StoppingCriteria",
|
||||
"StoppingCriteriaList",
|
||||
"SuppressTokensAtBeginLogitsProcessor",
|
||||
"SuppressTokensLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
"TypicalLogitsWarper",
|
||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||
"WhisperTimeStampLogitsProcessor",
|
||||
"top_k_top_p_filtering",
|
||||
]
|
||||
)
|
||||
@ -3115,6 +3128,7 @@ else:
|
||||
[
|
||||
"TFForcedBOSTokenLogitsProcessor",
|
||||
"TFForcedEOSTokenLogitsProcessor",
|
||||
"TFForceTokensLogitsProcessor",
|
||||
"TFGenerationMixin",
|
||||
"TFLogitsProcessor",
|
||||
"TFLogitsProcessorList",
|
||||
@ -3123,6 +3137,8 @@ else:
|
||||
"TFNoBadWordsLogitsProcessor",
|
||||
"TFNoRepeatNGramLogitsProcessor",
|
||||
"TFRepetitionPenaltyLogitsProcessor",
|
||||
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||
"TFSuppressTokensLogitsProcessor",
|
||||
"TFTemperatureLogitsWarper",
|
||||
"TFTopKLogitsWarper",
|
||||
"TFTopPLogitsWarper",
|
||||
@ -3836,14 +3852,18 @@ else:
|
||||
[
|
||||
"FlaxForcedBOSTokenLogitsProcessor",
|
||||
"FlaxForcedEOSTokenLogitsProcessor",
|
||||
"FlaxForceTokensLogitsProcessor",
|
||||
"FlaxGenerationMixin",
|
||||
"FlaxLogitsProcessor",
|
||||
"FlaxLogitsProcessorList",
|
||||
"FlaxLogitsWarper",
|
||||
"FlaxMinLengthLogitsProcessor",
|
||||
"FlaxTemperatureLogitsWarper",
|
||||
"FlaxSuppressTokensAtBeginLogitsProcessor",
|
||||
"FlaxSuppressTokensLogitsProcessor",
|
||||
"FlaxTopKLogitsWarper",
|
||||
"FlaxTopPLogitsWarper",
|
||||
"FlaxWhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
)
|
||||
_import_structure["generation_flax_utils"] = []
|
||||
@ -4983,17 +5003,26 @@ if TYPE_CHECKING:
|
||||
TextDatasetForNextSentencePrediction,
|
||||
)
|
||||
from .generation import (
|
||||
AlternatingCodebooksLogitsProcessor,
|
||||
BeamScorer,
|
||||
BeamSearchScorer,
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
ConstrainedBeamSearchScorer,
|
||||
Constraint,
|
||||
ConstraintListState,
|
||||
DisjunctiveConstraint,
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
EtaLogitsWarper,
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
ForceTokensLogitsProcessor,
|
||||
GenerationMixin,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
LogitsWarper,
|
||||
@ -5009,10 +5038,14 @@ if TYPE_CHECKING:
|
||||
SequenceBiasLogitsProcessor,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
SuppressTokensAtBeginLogitsProcessor,
|
||||
SuppressTokensLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from .modeling_utils import PreTrainedModel
|
||||
@ -6712,6 +6745,7 @@ if TYPE_CHECKING:
|
||||
from .generation import (
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFForceTokensLogitsProcessor,
|
||||
TFGenerationMixin,
|
||||
TFLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
@ -6720,6 +6754,8 @@ if TYPE_CHECKING:
|
||||
TFNoBadWordsLogitsProcessor,
|
||||
TFNoRepeatNGramLogitsProcessor,
|
||||
TFRepetitionPenaltyLogitsProcessor,
|
||||
TFSuppressTokensAtBeginLogitsProcessor,
|
||||
TFSuppressTokensLogitsProcessor,
|
||||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
@ -7285,14 +7321,18 @@ if TYPE_CHECKING:
|
||||
from .generation import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxForceTokensLogitsProcessor,
|
||||
FlaxGenerationMixin,
|
||||
FlaxLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxLogitsWarper,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxSuppressTokensAtBeginLogitsProcessor,
|
||||
FlaxSuppressTokensLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
FlaxWhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||
|
||||
|
@ -41,12 +41,19 @@ else:
|
||||
"ConstrainedBeamSearchScorer",
|
||||
]
|
||||
_import_structure["logits_process"] = [
|
||||
"AlternatingCodebooksLogitsProcessor",
|
||||
"ClassifierFreeGuidanceLogitsProcessor",
|
||||
"EncoderNoRepeatNGramLogitsProcessor",
|
||||
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||
"EpsilonLogitsWarper",
|
||||
"EtaLogitsWarper",
|
||||
"ExponentialDecayLengthPenalty",
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
"ForceTokensLogitsProcessor",
|
||||
"HammingDiversityLogitsProcessor",
|
||||
"InfNanRemoveLogitsProcessor",
|
||||
"LogitNormalization",
|
||||
"LogitsProcessor",
|
||||
"LogitsProcessorList",
|
||||
"LogitsWarper",
|
||||
@ -57,15 +64,14 @@ else:
|
||||
"PrefixConstrainedLogitsProcessor",
|
||||
"RepetitionPenaltyLogitsProcessor",
|
||||
"SequenceBiasLogitsProcessor",
|
||||
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||
"SuppressTokensLogitsProcessor",
|
||||
"SuppressTokensAtBeginLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
"TypicalLogitsWarper",
|
||||
"EncoderNoRepeatNGramLogitsProcessor",
|
||||
"ExponentialDecayLengthPenalty",
|
||||
"LogitNormalization",
|
||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||
"WhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
_import_structure["stopping_criteria"] = [
|
||||
"MaxNewTokensCriteria",
|
||||
@ -99,6 +105,7 @@ else:
|
||||
_import_structure["tf_logits_process"] = [
|
||||
"TFForcedBOSTokenLogitsProcessor",
|
||||
"TFForcedEOSTokenLogitsProcessor",
|
||||
"TFForceTokensLogitsProcessor",
|
||||
"TFLogitsProcessor",
|
||||
"TFLogitsProcessorList",
|
||||
"TFLogitsWarper",
|
||||
@ -106,12 +113,11 @@ else:
|
||||
"TFNoBadWordsLogitsProcessor",
|
||||
"TFNoRepeatNGramLogitsProcessor",
|
||||
"TFRepetitionPenaltyLogitsProcessor",
|
||||
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||
"TFSuppressTokensLogitsProcessor",
|
||||
"TFTemperatureLogitsWarper",
|
||||
"TFTopKLogitsWarper",
|
||||
"TFTopPLogitsWarper",
|
||||
"TFForceTokensLogitsProcessor",
|
||||
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||
"TFSuppressTokensLogitsProcessor",
|
||||
]
|
||||
_import_structure["tf_utils"] = [
|
||||
"TFGenerationMixin",
|
||||
@ -137,13 +143,17 @@ else:
|
||||
_import_structure["flax_logits_process"] = [
|
||||
"FlaxForcedBOSTokenLogitsProcessor",
|
||||
"FlaxForcedEOSTokenLogitsProcessor",
|
||||
"FlaxForceTokensLogitsProcessor",
|
||||
"FlaxLogitsProcessor",
|
||||
"FlaxLogitsProcessorList",
|
||||
"FlaxLogitsWarper",
|
||||
"FlaxMinLengthLogitsProcessor",
|
||||
"FlaxSuppressTokensAtBeginLogitsProcessor",
|
||||
"FlaxSuppressTokensLogitsProcessor",
|
||||
"FlaxTemperatureLogitsWarper",
|
||||
"FlaxTopKLogitsWarper",
|
||||
"FlaxTopPLogitsWarper",
|
||||
"FlaxWhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
_import_structure["flax_utils"] = [
|
||||
"FlaxGenerationMixin",
|
||||
@ -165,6 +175,8 @@ if TYPE_CHECKING:
|
||||
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .logits_process import (
|
||||
AlternatingCodebooksLogitsProcessor,
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
@ -172,6 +184,7 @@ if TYPE_CHECKING:
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
ForceTokensLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
@ -185,11 +198,14 @@ if TYPE_CHECKING:
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
SequenceBiasLogitsProcessor,
|
||||
SuppressTokensAtBeginLogitsProcessor,
|
||||
SuppressTokensLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
MaxLengthCriteria,
|
||||
@ -261,13 +277,17 @@ if TYPE_CHECKING:
|
||||
from .flax_logits_process import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxForceTokensLogitsProcessor,
|
||||
FlaxLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxLogitsWarper,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxSuppressTokensAtBeginLogitsProcessor,
|
||||
FlaxSuppressTokensLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
FlaxWhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
|
||||
else:
|
||||
|
@ -16,6 +16,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxGenerationMixin(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@ -51,6 +58,20 @@ class FlaxMinLengthLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxTemperatureLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@ -72,6 +93,13 @@ class FlaxTopPLogitsWarper(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
@ -79,6 +79,13 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BeamScorer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -93,6 +100,13 @@ class BeamSearchScorer(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -121,6 +135,41 @@ class DisjunctiveConstraint(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EncoderNoRepeatNGramLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EpsilonLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EtaLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ExponentialDecayLengthPenalty(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -135,6 +184,13 @@ class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GenerationMixin(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -156,6 +212,13 @@ class InfNanRemoveLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LogitNormalization(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -261,6 +324,20 @@ class StoppingCriteriaList(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class TemperatureLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@ -289,6 +366,20 @@ class TypicalLogitsWarper(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
def top_k_top_p_filtering(*args, **kwargs):
|
||||
requires_backends(top_k_top_p_filtering, ["torch"])
|
||||
|
||||
|
@ -30,6 +30,13 @@ class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFGenerationMixin(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
@ -86,6 +93,20 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFTemperatureLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user