Generate: add missing logits processors docs (#25653)

This commit is contained in:
Joao Gante 2023-08-25 11:56:17 +01:00 committed by GitHub
parent cb8e3ee25f
commit 85cf90a1c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 355 additions and 65 deletions

View File

@ -75,39 +75,104 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`.
We document here all output types.
### GreedySearchOutput
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
### PyTorch
[[autodoc]] generation.GreedySearchEncoderDecoderOutput
[[autodoc]] generation.FlaxGreedySearchOutput
### SampleOutput
[[autodoc]] generation.SampleDecoderOnlyOutput
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
[[autodoc]] generation.SampleEncoderDecoderOutput
[[autodoc]] generation.FlaxSampleOutput
### BeamSearchOutput
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
[[autodoc]] generation.SampleDecoderOnlyOutput
[[autodoc]] generation.BeamSearchEncoderDecoderOutput
### BeamSampleOutput
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
[[autodoc]] generation.BeamSampleDecoderOnlyOutput
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
### TensorFlow
[[autodoc]] generation.TFGreedySearchEncoderDecoderOutput
[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput
[[autodoc]] generation.TFSampleEncoderDecoderOutput
[[autodoc]] generation.TFSampleDecoderOnlyOutput
[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput
[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput
[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput
[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput
[[autodoc]] generation.TFContrastiveSearchEncoderDecoderOutput
[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput
### FLAX
[[autodoc]] generation.FlaxSampleOutput
[[autodoc]] generation.FlaxGreedySearchOutput
[[autodoc]] generation.FlaxBeamSearchOutput
## LogitsProcessor
A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for
generation.
### PyTorch
[[autodoc]] AlternatingCodebooksLogitsProcessor
- __call__
[[autodoc]] ClassifierFreeGuidanceLogitsProcessor
- __call__
[[autodoc]] EncoderNoRepeatNGramLogitsProcessor
- __call__
[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor
- __call__
[[autodoc]] EpsilonLogitsWarper
- __call__
[[autodoc]] EtaLogitsWarper
- __call__
[[autodoc]] ExponentialDecayLengthPenalty
- __call__
[[autodoc]] ForcedBOSTokenLogitsProcessor
- __call__
[[autodoc]] ForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] ForceTokensLogitsProcessor
- __call__
[[autodoc]] HammingDiversityLogitsProcessor
- __call__
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__
[[autodoc]] LogitNormalization
- __call__
[[autodoc]] LogitsProcessor
- __call__
@ -123,43 +188,54 @@ generation.
[[autodoc]] MinNewTokensLengthLogitsProcessor
- __call__
[[autodoc]] TemperatureLogitsWarper
- __call__
[[autodoc]] RepetitionPenaltyLogitsProcessor
- __call__
[[autodoc]] TopPLogitsWarper
- __call__
[[autodoc]] TopKLogitsWarper
- __call__
[[autodoc]] TypicalLogitsWarper
[[autodoc]] NoBadWordsLogitsProcessor
- __call__
[[autodoc]] NoRepeatNGramLogitsProcessor
- __call__
[[autodoc]] SequenceBiasLogitsProcessor
- __call__
[[autodoc]] NoBadWordsLogitsProcessor
- __call__
[[autodoc]] PrefixConstrainedLogitsProcessor
- __call__
[[autodoc]] HammingDiversityLogitsProcessor
[[autodoc]] RepetitionPenaltyLogitsProcessor
- __call__
[[autodoc]] ForcedBOSTokenLogitsProcessor
[[autodoc]] SequenceBiasLogitsProcessor
- __call__
[[autodoc]] ForcedEOSTokenLogitsProcessor
[[autodoc]] SuppressTokensAtBeginLogitsProcessor
- __call__
[[autodoc]] InfNanRemoveLogitsProcessor
[[autodoc]] SuppressTokensLogitsProcessor
- __call__
[[autodoc]] TemperatureLogitsWarper
- __call__
[[autodoc]] TopKLogitsWarper
- __call__
[[autodoc]] TopPLogitsWarper
- __call__
[[autodoc]] TypicalLogitsWarper
- __call__
[[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor
- __call__
[[autodoc]] WhisperTimeStampLogitsProcessor
- __call__
### TensorFlow
[[autodoc]] TFForcedBOSTokenLogitsProcessor
- __call__
[[autodoc]] TFForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] TFForceTokensLogitsProcessor
- __call__
[[autodoc]] TFLogitsProcessor
@ -171,15 +247,6 @@ generation.
[[autodoc]] TFLogitsWarper
- __call__
[[autodoc]] TFTemperatureLogitsWarper
- __call__
[[autodoc]] TFTopPLogitsWarper
- __call__
[[autodoc]] TFTopKLogitsWarper
- __call__
[[autodoc]] TFMinLengthLogitsProcessor
- __call__
@ -192,10 +259,30 @@ generation.
[[autodoc]] TFRepetitionPenaltyLogitsProcessor
- __call__
[[autodoc]] TFForcedBOSTokenLogitsProcessor
[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor
- __call__
[[autodoc]] TFForcedEOSTokenLogitsProcessor
[[autodoc]] TFSuppressTokensLogitsProcessor
- __call__
[[autodoc]] TFTemperatureLogitsWarper
- __call__
[[autodoc]] TFTopKLogitsWarper
- __call__
[[autodoc]] TFTopPLogitsWarper
- __call__
### FLAX
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
- __call__
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] FlaxForceTokensLogitsProcessor
- __call__
[[autodoc]] FlaxLogitsProcessor
@ -207,27 +294,30 @@ generation.
[[autodoc]] FlaxLogitsWarper
- __call__
[[autodoc]] FlaxTemperatureLogitsWarper
[[autodoc]] FlaxMinLengthLogitsProcessor
- __call__
[[autodoc]] FlaxTopPLogitsWarper
[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor
- __call__
[[autodoc]] FlaxSuppressTokensLogitsProcessor
- __call__
[[autodoc]] FlaxTemperatureLogitsWarper
- __call__
[[autodoc]] FlaxTopKLogitsWarper
- __call__
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
[[autodoc]] FlaxTopPLogitsWarper
- __call__
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] FlaxMinLengthLogitsProcessor
[[autodoc]] FlaxWhisperTimeStampLogitsProcessor
- __call__
## StoppingCriteria
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token).
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). Please note that this is exclusivelly available to our PyTorch implementations.
[[autodoc]] StoppingCriteria
- __call__
@ -243,7 +333,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
## Constraints
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusivelly available to our PyTorch implementations.
[[autodoc]] Constraint

View File

@ -1005,17 +1005,26 @@ else:
_import_structure["deepspeed"] = []
_import_structure["generation"].extend(
[
"AlternatingCodebooksLogitsProcessor",
"BeamScorer",
"BeamSearchScorer",
"ClassifierFreeGuidanceLogitsProcessor",
"ConstrainedBeamSearchScorer",
"Constraint",
"ConstraintListState",
"DisjunctiveConstraint",
"EncoderNoRepeatNGramLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"EpsilonLogitsWarper",
"EtaLogitsWarper",
"ExponentialDecayLengthPenalty",
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
"ForceTokensLogitsProcessor",
"GenerationMixin",
"HammingDiversityLogitsProcessor",
"InfNanRemoveLogitsProcessor",
"LogitNormalization",
"LogitsProcessor",
"LogitsProcessorList",
"LogitsWarper",
@ -1031,10 +1040,14 @@ else:
"SequenceBiasLogitsProcessor",
"StoppingCriteria",
"StoppingCriteriaList",
"SuppressTokensAtBeginLogitsProcessor",
"SuppressTokensLogitsProcessor",
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
"top_k_top_p_filtering",
]
)
@ -3115,6 +3128,7 @@ else:
[
"TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor",
"TFForceTokensLogitsProcessor",
"TFGenerationMixin",
"TFLogitsProcessor",
"TFLogitsProcessorList",
@ -3123,6 +3137,8 @@ else:
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
"TFTemperatureLogitsWarper",
"TFTopKLogitsWarper",
"TFTopPLogitsWarper",
@ -3836,14 +3852,18 @@ else:
[
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
"FlaxForceTokensLogitsProcessor",
"FlaxGenerationMixin",
"FlaxLogitsProcessor",
"FlaxLogitsProcessorList",
"FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor",
"FlaxTemperatureLogitsWarper",
"FlaxSuppressTokensAtBeginLogitsProcessor",
"FlaxSuppressTokensLogitsProcessor",
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
"FlaxWhisperTimeStampLogitsProcessor",
]
)
_import_structure["generation_flax_utils"] = []
@ -4983,17 +5003,26 @@ if TYPE_CHECKING:
TextDatasetForNextSentencePrediction,
)
from .generation import (
AlternatingCodebooksLogitsProcessor,
BeamScorer,
BeamSearchScorer,
ClassifierFreeGuidanceLogitsProcessor,
ConstrainedBeamSearchScorer,
Constraint,
ConstraintListState,
DisjunctiveConstraint,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
GenerationMixin,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessor,
LogitsProcessorList,
LogitsWarper,
@ -5009,10 +5038,14 @@ if TYPE_CHECKING:
SequenceBiasLogitsProcessor,
StoppingCriteria,
StoppingCriteriaList,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WhisperTimeStampLogitsProcessor,
top_k_top_p_filtering,
)
from .modeling_utils import PreTrainedModel
@ -6712,6 +6745,7 @@ if TYPE_CHECKING:
from .generation import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
TFForceTokensLogitsProcessor,
TFGenerationMixin,
TFLogitsProcessor,
TFLogitsProcessorList,
@ -6720,6 +6754,8 @@ if TYPE_CHECKING:
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
TFSuppressTokensAtBeginLogitsProcessor,
TFSuppressTokensLogitsProcessor,
TFTemperatureLogitsWarper,
TFTopKLogitsWarper,
TFTopPLogitsWarper,
@ -7285,14 +7321,18 @@ if TYPE_CHECKING:
from .generation import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxGenerationMixin,
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
FlaxWhisperTimeStampLogitsProcessor,
)
from .modeling_flax_utils import FlaxPreTrainedModel

View File

@ -41,12 +41,19 @@ else:
"ConstrainedBeamSearchScorer",
]
_import_structure["logits_process"] = [
"AlternatingCodebooksLogitsProcessor",
"ClassifierFreeGuidanceLogitsProcessor",
"EncoderNoRepeatNGramLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"EpsilonLogitsWarper",
"EtaLogitsWarper",
"ExponentialDecayLengthPenalty",
"ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor",
"ForceTokensLogitsProcessor",
"HammingDiversityLogitsProcessor",
"InfNanRemoveLogitsProcessor",
"LogitNormalization",
"LogitsProcessor",
"LogitsProcessorList",
"LogitsWarper",
@ -57,15 +64,14 @@ else:
"PrefixConstrainedLogitsProcessor",
"RepetitionPenaltyLogitsProcessor",
"SequenceBiasLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"SuppressTokensLogitsProcessor",
"SuppressTokensAtBeginLogitsProcessor",
"TemperatureLogitsWarper",
"TopKLogitsWarper",
"TopPLogitsWarper",
"TypicalLogitsWarper",
"EncoderNoRepeatNGramLogitsProcessor",
"ExponentialDecayLengthPenalty",
"LogitNormalization",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
@ -99,6 +105,7 @@ else:
_import_structure["tf_logits_process"] = [
"TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor",
"TFForceTokensLogitsProcessor",
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFLogitsWarper",
@ -106,12 +113,11 @@ else:
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
"TFTemperatureLogitsWarper",
"TFTopKLogitsWarper",
"TFTopPLogitsWarper",
"TFForceTokensLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
]
_import_structure["tf_utils"] = [
"TFGenerationMixin",
@ -137,13 +143,17 @@ else:
_import_structure["flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor",
"FlaxForceTokensLogitsProcessor",
"FlaxLogitsProcessor",
"FlaxLogitsProcessorList",
"FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor",
"FlaxSuppressTokensAtBeginLogitsProcessor",
"FlaxSuppressTokensLogitsProcessor",
"FlaxTemperatureLogitsWarper",
"FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper",
"FlaxWhisperTimeStampLogitsProcessor",
]
_import_structure["flax_utils"] = [
"FlaxGenerationMixin",
@ -165,6 +175,8 @@ if TYPE_CHECKING:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
@ -172,6 +184,7 @@ if TYPE_CHECKING:
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitNormalization,
@ -185,11 +198,14 @@ if TYPE_CHECKING:
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
@ -261,13 +277,17 @@ if TYPE_CHECKING:
from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxLogitsProcessor,
FlaxLogitsProcessorList,
FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
FlaxWhisperTimeStampLogitsProcessor,
)
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
else:

View File

@ -16,6 +16,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxForceTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxGenerationMixin(metaclass=DummyObject):
_backends = ["flax"]
@ -51,6 +58,20 @@ class FlaxMinLengthLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxTemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["flax"]
@ -72,6 +93,13 @@ class FlaxTopPLogitsWarper(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]

View File

@ -79,6 +79,13 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
requires_backends(self, ["torch"])
class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeamScorer(metaclass=DummyObject):
_backends = ["torch"]
@ -93,6 +100,13 @@ class BeamSearchScorer(metaclass=DummyObject):
requires_backends(self, ["torch"])
class ClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
_backends = ["torch"]
@ -121,6 +135,41 @@ class DisjunctiveConstraint(metaclass=DummyObject):
requires_backends(self, ["torch"])
class EncoderNoRepeatNGramLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EpsilonLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EtaLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ExponentialDecayLengthPenalty(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
@ -135,6 +184,13 @@ class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class ForceTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GenerationMixin(metaclass=DummyObject):
_backends = ["torch"]
@ -156,6 +212,13 @@ class InfNanRemoveLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"])
class LogitNormalization(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
@ -261,6 +324,20 @@ class StoppingCriteriaList(metaclass=DummyObject):
requires_backends(self, ["torch"])
class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SuppressTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
@ -289,6 +366,20 @@ class TypicalLogitsWarper(metaclass=DummyObject):
requires_backends(self, ["torch"])
class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
def top_k_top_p_filtering(*args, **kwargs):
requires_backends(top_k_top_p_filtering, ["torch"])

View File

@ -30,6 +30,13 @@ class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFForceTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGenerationMixin(metaclass=DummyObject):
_backends = ["tf"]
@ -86,6 +93,20 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSuppressTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFTemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["tf"]