diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md
index f5c882bf128..906ee4ea620 100644
--- a/docs/source/en/internal/generation_utils.md
+++ b/docs/source/en/internal/generation_utils.md
@@ -75,39 +75,104 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`.
 
 We document here all output types.
 
-### GreedySearchOutput
-
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
+### PyTorch
 
 [[autodoc]] generation.GreedySearchEncoderDecoderOutput
 
-[[autodoc]] generation.FlaxGreedySearchOutput
-
-### SampleOutput
-
-[[autodoc]] generation.SampleDecoderOnlyOutput
+[[autodoc]] generation.GreedySearchDecoderOnlyOutput
 
 [[autodoc]] generation.SampleEncoderDecoderOutput
 
-[[autodoc]] generation.FlaxSampleOutput
-
-### BeamSearchOutput
-
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
+[[autodoc]] generation.SampleDecoderOnlyOutput
 
 [[autodoc]] generation.BeamSearchEncoderDecoderOutput
 
-### BeamSampleOutput
+[[autodoc]] generation.BeamSearchDecoderOnlyOutput
+
+[[autodoc]] generation.BeamSampleEncoderDecoderOutput
 
 [[autodoc]] generation.BeamSampleDecoderOnlyOutput
 
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
+[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
+
+[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+
+### TensorFlow
+
+[[autodoc]] generation.TFGreedySearchEncoderDecoderOutput
+
+[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput
+
+[[autodoc]] generation.TFSampleEncoderDecoderOutput
+
+[[autodoc]] generation.TFSampleDecoderOnlyOutput
+
+[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput
+
+[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput
+
+[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput
+
+[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput
+
+[[autodoc]] generation.TFContrastiveSearchEncoderDecoderOutput
+
+[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput
+
+### Flax
+
+[[autodoc]] generation.FlaxSampleOutput
+
+[[autodoc]] generation.FlaxGreedySearchOutput
+
+[[autodoc]] generation.FlaxBeamSearchOutput
 
 ## LogitsProcessor
 
 A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for
 generation.
 
+### PyTorch
+
+[[autodoc]] AlternatingCodebooksLogitsProcessor
+    - __call__
+
+[[autodoc]] ClassifierFreeGuidanceLogitsProcessor
+    - __call__
+
+[[autodoc]] EncoderNoRepeatNGramLogitsProcessor
+    - __call__
+
+[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor
+    - __call__
+
+[[autodoc]] EpsilonLogitsWarper
+    - __call__
+
+[[autodoc]] EtaLogitsWarper
+    - __call__
+
+[[autodoc]] ExponentialDecayLengthPenalty
+    - __call__
+
+[[autodoc]] ForcedBOSTokenLogitsProcessor
+    - __call__
+
+[[autodoc]] ForcedEOSTokenLogitsProcessor
+    - __call__
+
+[[autodoc]] ForceTokensLogitsProcessor
+    - __call__
+
+[[autodoc]] HammingDiversityLogitsProcessor
+    - __call__
+
+[[autodoc]] InfNanRemoveLogitsProcessor
+    - __call__
+
+[[autodoc]] LogitNormalization
+    - __call__
+
 [[autodoc]] LogitsProcessor
     - __call__
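All of the processors listed in this hunk implement the same `__call__(input_ids, scores)` hook, so custom ones can be chained with the built-ins through [`LogitsProcessorList`]. A minimal, illustrative sketch of that interface (the `gpt2` checkpoint and the toy processor are placeholder examples, not part of the library):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList


class EvenTokenBoost(LogitsProcessor):
    """Toy example: add a constant bonus to the scores of even-indexed vocab entries."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, ::2] += 1.0  # called once per generation step, on the next-token scores
        return scores


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

# generate() applies every processor in the list, in order, at each step
outputs = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([EvenTokenBoost()]),
    max_new_tokens=10,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```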
@@ -123,43 +188,54 @@ generation.
 [[autodoc]] MinNewTokensLengthLogitsProcessor
     - __call__
 
-[[autodoc]] TemperatureLogitsWarper
-    - __call__
-
-[[autodoc]] RepetitionPenaltyLogitsProcessor
-    - __call__
-
-[[autodoc]] TopPLogitsWarper
-    - __call__
-
-[[autodoc]] TopKLogitsWarper
-    - __call__
-
-[[autodoc]] TypicalLogitsWarper
+[[autodoc]] NoBadWordsLogitsProcessor
     - __call__
 
 [[autodoc]] NoRepeatNGramLogitsProcessor
     - __call__
 
-[[autodoc]] SequenceBiasLogitsProcessor
-    - __call__
-
-[[autodoc]] NoBadWordsLogitsProcessor
-    - __call__
-
 [[autodoc]] PrefixConstrainedLogitsProcessor
     - __call__
 
-[[autodoc]] HammingDiversityLogitsProcessor
+[[autodoc]] RepetitionPenaltyLogitsProcessor
     - __call__
 
-[[autodoc]] ForcedBOSTokenLogitsProcessor
+[[autodoc]] SequenceBiasLogitsProcessor
     - __call__
 
-[[autodoc]] ForcedEOSTokenLogitsProcessor
+[[autodoc]] SuppressTokensAtBeginLogitsProcessor
     - __call__
 
-[[autodoc]] InfNanRemoveLogitsProcessor
+[[autodoc]] SuppressTokensLogitsProcessor
+    - __call__
+
+[[autodoc]] TemperatureLogitsWarper
+    - __call__
+
+[[autodoc]] TopKLogitsWarper
+    - __call__
+
+[[autodoc]] TopPLogitsWarper
+    - __call__
+
+[[autodoc]] TypicalLogitsWarper
+    - __call__
+
+[[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor
+    - __call__
+
+[[autodoc]] WhisperTimeStampLogitsProcessor
+    - __call__
+
+### TensorFlow
+
+[[autodoc]] TFForcedBOSTokenLogitsProcessor
+    - __call__
+
+[[autodoc]] TFForcedEOSTokenLogitsProcessor
+    - __call__
+
+[[autodoc]] TFForceTokensLogitsProcessor
     - __call__
 
 [[autodoc]] TFLogitsProcessor
     - __call__
 
@@ -171,15 +247,6 @@ generation.
 [[autodoc]] TFLogitsWarper
     - __call__
 
-[[autodoc]] TFTemperatureLogitsWarper
-    - __call__
-
-[[autodoc]] TFTopPLogitsWarper
-    - __call__
-
-[[autodoc]] TFTopKLogitsWarper
-    - __call__
-
 [[autodoc]] TFMinLengthLogitsProcessor
     - __call__
 
@@ -192,10 +259,30 @@ generation.
 [[autodoc]] TFRepetitionPenaltyLogitsProcessor
     - __call__
 
-[[autodoc]] TFForcedBOSTokenLogitsProcessor
+[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor
     - __call__
 
-[[autodoc]] TFForcedEOSTokenLogitsProcessor
+[[autodoc]] TFSuppressTokensLogitsProcessor
+    - __call__
+
+[[autodoc]] TFTemperatureLogitsWarper
+    - __call__
+
+[[autodoc]] TFTopKLogitsWarper
+    - __call__
+
+[[autodoc]] TFTopPLogitsWarper
+    - __call__
+
+### Flax
+
+[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
+    - __call__
+
+[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
+    - __call__
+
+[[autodoc]] FlaxForceTokensLogitsProcessor
     - __call__
 
 [[autodoc]] FlaxLogitsProcessor
     - __call__
 
@@ -207,27 +294,30 @@ generation.
 [[autodoc]] FlaxLogitsWarper
     - __call__
 
-[[autodoc]] FlaxTemperatureLogitsWarper
+[[autodoc]] FlaxMinLengthLogitsProcessor
     - __call__
 
-[[autodoc]] FlaxTopPLogitsWarper
+[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor
+    - __call__
+
+[[autodoc]] FlaxSuppressTokensLogitsProcessor
+    - __call__
+
+[[autodoc]] FlaxTemperatureLogitsWarper
     - __call__
 
 [[autodoc]] FlaxTopKLogitsWarper
     - __call__
 
-[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
+[[autodoc]] FlaxTopPLogitsWarper
     - __call__
 
-[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
-    - __call__
-
-[[autodoc]] FlaxMinLengthLogitsProcessor
+[[autodoc]] FlaxWhisperTimeStampLogitsProcessor
     - __call__
 
 ## StoppingCriteria
 
-A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token).
+A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). Please note that this is exclusively available to our PyTorch implementations.
 
 [[autodoc]] StoppingCriteria
     - __call__
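As with logits processors, custom criteria subclass [`StoppingCriteria`] and are passed to `generate` through a [`StoppingCriteriaList`]. A hedged sketch of the interface (the `gpt2` checkpoint and the newline rule below are placeholders chosen for illustration):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList


class StopOnNewline(StoppingCriteria):
    """Toy criterion: stop once the most recently generated token contains a newline."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids holds the prompt plus everything generated so far
        return "\n" in self.tokenizer.decode(input_ids[0, -1:])


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("1. Eggs\n2. Milk\n3.", return_tensors="pt")

outputs = model.generate(
    **inputs,
    stopping_criteria=StoppingCriteriaList([StopOnNewline(tokenizer)]),
    max_new_tokens=50,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```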
@@ -243,7 +333,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
 
 ## Constraints
 
-A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
+A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusively available to our PyTorch implementations.
 
 [[autodoc]] Constraint
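Constrained decoding runs on top of beam search, so `constraints` must be combined with `num_beams`. A small sketch using [`PhrasalConstraint`] (the `t5-small` checkpoint and the forced phrase are only examples):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PhrasalConstraint

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")

# Force the token sequence for "alt" to appear somewhere in the output
force_ids = tokenizer("alt", add_special_tokens=False).input_ids
outputs = model.generate(
    **inputs,
    constraints=[PhrasalConstraint(force_ids)],
    num_beams=5,  # constraints require beam search
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```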
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 9024481269a..1b41335c11f 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1005,17 +1005,26 @@ else:
     _import_structure["deepspeed"] = []
     _import_structure["generation"].extend(
         [
+            "AlternatingCodebooksLogitsProcessor",
             "BeamScorer",
             "BeamSearchScorer",
+            "ClassifierFreeGuidanceLogitsProcessor",
             "ConstrainedBeamSearchScorer",
             "Constraint",
             "ConstraintListState",
             "DisjunctiveConstraint",
+            "EncoderNoRepeatNGramLogitsProcessor",
+            "EncoderRepetitionPenaltyLogitsProcessor",
+            "EpsilonLogitsWarper",
+            "EtaLogitsWarper",
+            "ExponentialDecayLengthPenalty",
             "ForcedBOSTokenLogitsProcessor",
             "ForcedEOSTokenLogitsProcessor",
+            "ForceTokensLogitsProcessor",
             "GenerationMixin",
             "HammingDiversityLogitsProcessor",
             "InfNanRemoveLogitsProcessor",
+            "LogitNormalization",
             "LogitsProcessor",
             "LogitsProcessorList",
             "LogitsWarper",
@@ -1031,10 +1040,14 @@ else:
             "SequenceBiasLogitsProcessor",
             "StoppingCriteria",
             "StoppingCriteriaList",
+            "SuppressTokensAtBeginLogitsProcessor",
+            "SuppressTokensLogitsProcessor",
             "TemperatureLogitsWarper",
             "TopKLogitsWarper",
             "TopPLogitsWarper",
             "TypicalLogitsWarper",
+            "UnbatchedClassifierFreeGuidanceLogitsProcessor",
+            "WhisperTimeStampLogitsProcessor",
             "top_k_top_p_filtering",
         ]
     )
@@ -3115,6 +3128,7 @@ else:
         [
             "TFForcedBOSTokenLogitsProcessor",
             "TFForcedEOSTokenLogitsProcessor",
+            "TFForceTokensLogitsProcessor",
             "TFGenerationMixin",
             "TFLogitsProcessor",
             "TFLogitsProcessorList",
@@ -3123,6 +3137,8 @@ else:
             "TFNoBadWordsLogitsProcessor",
             "TFNoRepeatNGramLogitsProcessor",
             "TFRepetitionPenaltyLogitsProcessor",
+            "TFSuppressTokensAtBeginLogitsProcessor",
+            "TFSuppressTokensLogitsProcessor",
             "TFTemperatureLogitsWarper",
             "TFTopKLogitsWarper",
             "TFTopPLogitsWarper",
@@ -3836,14 +3852,18 @@ else:
         [
             "FlaxForcedBOSTokenLogitsProcessor",
             "FlaxForcedEOSTokenLogitsProcessor",
+            "FlaxForceTokensLogitsProcessor",
             "FlaxGenerationMixin",
             "FlaxLogitsProcessor",
             "FlaxLogitsProcessorList",
             "FlaxLogitsWarper",
             "FlaxMinLengthLogitsProcessor",
             "FlaxTemperatureLogitsWarper",
+            "FlaxSuppressTokensAtBeginLogitsProcessor",
+            "FlaxSuppressTokensLogitsProcessor",
             "FlaxTopKLogitsWarper",
             "FlaxTopPLogitsWarper",
+            "FlaxWhisperTimeStampLogitsProcessor",
         ]
     )
     _import_structure["generation_flax_utils"] = []
@@ -4983,17 +5003,26 @@ if TYPE_CHECKING:
         TextDatasetForNextSentencePrediction,
     )
     from .generation import (
+        AlternatingCodebooksLogitsProcessor,
         BeamScorer,
         BeamSearchScorer,
+        ClassifierFreeGuidanceLogitsProcessor,
         ConstrainedBeamSearchScorer,
         Constraint,
        ConstraintListState,
         DisjunctiveConstraint,
+        EncoderNoRepeatNGramLogitsProcessor,
+        EncoderRepetitionPenaltyLogitsProcessor,
+        EpsilonLogitsWarper,
+        EtaLogitsWarper,
+        ExponentialDecayLengthPenalty,
         ForcedBOSTokenLogitsProcessor,
         ForcedEOSTokenLogitsProcessor,
+        ForceTokensLogitsProcessor,
         GenerationMixin,
         HammingDiversityLogitsProcessor,
         InfNanRemoveLogitsProcessor,
+        LogitNormalization,
         LogitsProcessor,
         LogitsProcessorList,
         LogitsWarper,
@@ -5009,10 +5038,14 @@ if TYPE_CHECKING:
         SequenceBiasLogitsProcessor,
         StoppingCriteria,
         StoppingCriteriaList,
+        SuppressTokensAtBeginLogitsProcessor,
+        SuppressTokensLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
         TypicalLogitsWarper,
+        UnbatchedClassifierFreeGuidanceLogitsProcessor,
+        WhisperTimeStampLogitsProcessor,
         top_k_top_p_filtering,
     )
     from .modeling_utils import PreTrainedModel
@@ -6712,6 +6745,7 @@ if TYPE_CHECKING:
     from .generation import (
         TFForcedBOSTokenLogitsProcessor,
         TFForcedEOSTokenLogitsProcessor,
+        TFForceTokensLogitsProcessor,
         TFGenerationMixin,
         TFLogitsProcessor,
         TFLogitsProcessorList,
@@ -6720,6 +6754,8 @@ if TYPE_CHECKING:
         TFNoBadWordsLogitsProcessor,
         TFNoRepeatNGramLogitsProcessor,
         TFRepetitionPenaltyLogitsProcessor,
+        TFSuppressTokensAtBeginLogitsProcessor,
+        TFSuppressTokensLogitsProcessor,
         TFTemperatureLogitsWarper,
         TFTopKLogitsWarper,
         TFTopPLogitsWarper,
@@ -7285,14 +7321,18 @@ if TYPE_CHECKING:
     from .generation import (
         FlaxForcedBOSTokenLogitsProcessor,
         FlaxForcedEOSTokenLogitsProcessor,
+        FlaxForceTokensLogitsProcessor,
         FlaxGenerationMixin,
         FlaxLogitsProcessor,
         FlaxLogitsProcessorList,
         FlaxLogitsWarper,
         FlaxMinLengthLogitsProcessor,
+        FlaxSuppressTokensAtBeginLogitsProcessor,
+        FlaxSuppressTokensLogitsProcessor,
         FlaxTemperatureLogitsWarper,
         FlaxTopKLogitsWarper,
         FlaxTopPLogitsWarper,
+        FlaxWhisperTimeStampLogitsProcessor,
     )
     from .modeling_flax_utils import FlaxPreTrainedModel
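Extending `_import_structure` and the `TYPE_CHECKING` branch in tandem is what makes each newly listed name importable from the package root, lazily at runtime and visibly to static type checkers. For instance, assuming a torch install, the newly exported processors can then be assembled by hand (the token id and epsilon value below are arbitrary illustrations):

```python
from transformers import (
    EpsilonLogitsWarper,
    LogitsProcessorList,
    SuppressTokensLogitsProcessor,
)

# Illustrative only: suppress token id 42 entirely, then apply an
# epsilon-sampling cutoff, as a manually built processing pipeline.
processors = LogitsProcessorList(
    [
        SuppressTokensLogitsProcessor([42]),
        EpsilonLogitsWarper(epsilon=3e-4),
    ]
)
```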
diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index f0da9f514e7..a46cb4fa910 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -41,12 +41,19 @@ else:
         "ConstrainedBeamSearchScorer",
     ]
     _import_structure["logits_process"] = [
+        "AlternatingCodebooksLogitsProcessor",
+        "ClassifierFreeGuidanceLogitsProcessor",
+        "EncoderNoRepeatNGramLogitsProcessor",
+        "EncoderRepetitionPenaltyLogitsProcessor",
         "EpsilonLogitsWarper",
         "EtaLogitsWarper",
+        "ExponentialDecayLengthPenalty",
         "ForcedBOSTokenLogitsProcessor",
         "ForcedEOSTokenLogitsProcessor",
+        "ForceTokensLogitsProcessor",
         "HammingDiversityLogitsProcessor",
         "InfNanRemoveLogitsProcessor",
+        "LogitNormalization",
         "LogitsProcessor",
         "LogitsProcessorList",
         "LogitsWarper",
@@ -57,15 +64,14 @@ else:
         "PrefixConstrainedLogitsProcessor",
         "RepetitionPenaltyLogitsProcessor",
         "SequenceBiasLogitsProcessor",
-        "EncoderRepetitionPenaltyLogitsProcessor",
+        "SuppressTokensLogitsProcessor",
+        "SuppressTokensAtBeginLogitsProcessor",
         "TemperatureLogitsWarper",
         "TopKLogitsWarper",
         "TopPLogitsWarper",
         "TypicalLogitsWarper",
-        "EncoderNoRepeatNGramLogitsProcessor",
-        "ExponentialDecayLengthPenalty",
-        "LogitNormalization",
         "UnbatchedClassifierFreeGuidanceLogitsProcessor",
+        "WhisperTimeStampLogitsProcessor",
     ]
     _import_structure["stopping_criteria"] = [
         "MaxNewTokensCriteria",
@@ -99,6 +105,7 @@ else:
     _import_structure["tf_logits_process"] = [
         "TFForcedBOSTokenLogitsProcessor",
         "TFForcedEOSTokenLogitsProcessor",
+        "TFForceTokensLogitsProcessor",
         "TFLogitsProcessor",
         "TFLogitsProcessorList",
         "TFLogitsWarper",
@@ -106,12 +113,11 @@ else:
         "TFNoBadWordsLogitsProcessor",
         "TFNoRepeatNGramLogitsProcessor",
         "TFRepetitionPenaltyLogitsProcessor",
+        "TFSuppressTokensAtBeginLogitsProcessor",
+        "TFSuppressTokensLogitsProcessor",
         "TFTemperatureLogitsWarper",
         "TFTopKLogitsWarper",
         "TFTopPLogitsWarper",
-        "TFForceTokensLogitsProcessor",
-        "TFSuppressTokensAtBeginLogitsProcessor",
-        "TFSuppressTokensLogitsProcessor",
     ]
     _import_structure["tf_utils"] = [
         "TFGenerationMixin",
@@ -137,13 +143,17 @@ else:
     _import_structure["flax_logits_process"] = [
         "FlaxForcedBOSTokenLogitsProcessor",
         "FlaxForcedEOSTokenLogitsProcessor",
+        "FlaxForceTokensLogitsProcessor",
         "FlaxLogitsProcessor",
         "FlaxLogitsProcessorList",
         "FlaxLogitsWarper",
         "FlaxMinLengthLogitsProcessor",
+        "FlaxSuppressTokensAtBeginLogitsProcessor",
+        "FlaxSuppressTokensLogitsProcessor",
         "FlaxTemperatureLogitsWarper",
         "FlaxTopKLogitsWarper",
         "FlaxTopPLogitsWarper",
+        "FlaxWhisperTimeStampLogitsProcessor",
     ]
     _import_structure["flax_utils"] = [
         "FlaxGenerationMixin",
@@ -165,6 +175,8 @@ if TYPE_CHECKING:
     from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
     from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
     from .logits_process import (
+        AlternatingCodebooksLogitsProcessor,
+        ClassifierFreeGuidanceLogitsProcessor,
         EncoderNoRepeatNGramLogitsProcessor,
         EncoderRepetitionPenaltyLogitsProcessor,
         EpsilonLogitsWarper,
@@ -172,6 +184,7 @@ if TYPE_CHECKING:
         ExponentialDecayLengthPenalty,
         ForcedBOSTokenLogitsProcessor,
         ForcedEOSTokenLogitsProcessor,
+        ForceTokensLogitsProcessor,
         HammingDiversityLogitsProcessor,
         InfNanRemoveLogitsProcessor,
         LogitNormalization,
@@ -185,11 +198,14 @@ if TYPE_CHECKING:
         PrefixConstrainedLogitsProcessor,
         RepetitionPenaltyLogitsProcessor,
         SequenceBiasLogitsProcessor,
+        SuppressTokensAtBeginLogitsProcessor,
+        SuppressTokensLogitsProcessor,
         TemperatureLogitsWarper,
         TopKLogitsWarper,
         TopPLogitsWarper,
         TypicalLogitsWarper,
         UnbatchedClassifierFreeGuidanceLogitsProcessor,
+        WhisperTimeStampLogitsProcessor,
     )
     from .stopping_criteria import (
         MaxLengthCriteria,
@@ -261,13 +277,17 @@ if TYPE_CHECKING:
     from .flax_logits_process import (
         FlaxForcedBOSTokenLogitsProcessor,
         FlaxForcedEOSTokenLogitsProcessor,
+        FlaxForceTokensLogitsProcessor,
         FlaxLogitsProcessor,
         FlaxLogitsProcessorList,
         FlaxLogitsWarper,
         FlaxMinLengthLogitsProcessor,
+        FlaxSuppressTokensAtBeginLogitsProcessor,
+        FlaxSuppressTokensLogitsProcessor,
         FlaxTemperatureLogitsWarper,
         FlaxTopKLogitsWarper,
         FlaxTopPLogitsWarper,
+        FlaxWhisperTimeStampLogitsProcessor,
     )
     from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
 else:
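For readers unfamiliar with this file's layout: `_import_structure` plus the `TYPE_CHECKING` block implement the lazy-module pattern that the hunks above extend on both sides. A simplified sketch of the idea, using PEP 562 module `__getattr__` (the real `_LazyModule` in transformers has considerably more machinery):

```python
import importlib
from typing import TYPE_CHECKING

_import_structure = {"logits_process": ["EpsilonLogitsWarper", "EtaLogitsWarper"]}

if TYPE_CHECKING:
    # Type checkers resolve the names eagerly through ordinary imports...
    from .logits_process import EpsilonLogitsWarper, EtaLogitsWarper
else:
    # ...while at runtime the submodule is only imported on first attribute access.
    def __getattr__(name):
        for submodule, exported in _import_structure.items():
            if name in exported:
                module = importlib.import_module(f".{submodule}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```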
diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py
index 7e5b78d3e6f..4090e4ff513 100644
--- a/src/transformers/utils/dummy_flax_objects.py
+++ b/src/transformers/utils/dummy_flax_objects.py
@@ -16,6 +16,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["flax"])
 
 
+class FlaxForceTokensLogitsProcessor(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxGenerationMixin(metaclass=DummyObject):
     _backends = ["flax"]
 
@@ -51,6 +58,20 @@ class FlaxMinLengthLogitsProcessor(metaclass=DummyObject):
         requires_backends(self, ["flax"])
 
 
+class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxTemperatureLogitsWarper(metaclass=DummyObject):
     _backends = ["flax"]
 
@@ -72,6 +93,13 @@ class FlaxTopPLogitsWarper(metaclass=DummyObject):
         requires_backends(self, ["flax"])
 
 
+class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxPreTrainedModel(metaclass=DummyObject):
     _backends = ["flax"]
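The dummy modules patched here all follow the same pattern: when a backend is missing, the real class is replaced by a stub that fails with an actionable error on use rather than with an opaque `ImportError` at import time. A rough sketch of the mechanics, simplified relative to the actual helpers in `transformers.utils`:

```python
import importlib.util


def requires_backends(obj, backends):
    # Raise a readable error if any required backend is not installed
    name = obj if isinstance(obj, str) else getattr(obj, "__name__", obj.__class__.__name__)
    missing = [backend for backend in backends if importlib.util.find_spec(backend) is None]
    if missing:
        raise ImportError(f"{name} requires the {missing} backend(s), which are not installed.")


class DummyObject(type):
    # Metaclass: any public attribute access on the stub class re-raises the
    # same backend error, so classmethod-style access fails clearly too.
    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)


class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])
```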
["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class WhisperTimeStampLogitsProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + def top_k_top_p_filtering(*args, **kwargs): requires_backends(top_k_top_p_filtering, ["torch"]) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 46cde8ffbef..9b1aae44932 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -30,6 +30,13 @@ class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFForceTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFGenerationMixin(metaclass=DummyObject): _backends = ["tf"] @@ -86,6 +93,20 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFSuppressTokensLogitsProcessor(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFTemperatureLogitsWarper(metaclass=DummyObject): _backends = ["tf"]