mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Flax] FlaxAutoModelForSeq2SeqLM (#12228)
* add FlaxAutoModelForSeq2SeqLM
This commit is contained in:
parent
e43e11260f
commit
f74655cd9b
@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForSeq2SeqLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.FlaxAutoModelForSeq2SeqLM
|
||||
:members:
|
||||
|
||||
|
||||
FlaxAutoModelForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -1514,6 +1514,7 @@ if is_flax_available():
|
||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
||||
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_MAPPING",
|
||||
@ -1524,6 +1525,7 @@ if is_flax_available():
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
"FlaxAutoModelForPreTraining",
|
||||
"FlaxAutoModelForQuestionAnswering",
|
||||
"FlaxAutoModelForSeq2SeqLM",
|
||||
"FlaxAutoModelForSequenceClassification",
|
||||
"FlaxAutoModelForTokenClassification",
|
||||
]
|
||||
@ -2851,6 +2853,7 @@ if TYPE_CHECKING:
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
@ -2861,6 +2864,7 @@ if TYPE_CHECKING:
|
||||
FlaxAutoModelForNextSentencePrediction,
|
||||
FlaxAutoModelForPreTraining,
|
||||
FlaxAutoModelForQuestionAnswering,
|
||||
FlaxAutoModelForSeq2SeqLM,
|
||||
FlaxAutoModelForSequenceClassification,
|
||||
FlaxAutoModelForTokenClassification,
|
||||
)
|
||||
|
@ -92,6 +92,7 @@ if is_flax_available():
|
||||
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
|
||||
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"FLAX_MODEL_MAPPING",
|
||||
@ -103,6 +104,7 @@ if is_flax_available():
|
||||
"FlaxAutoModelForNextSentencePrediction",
|
||||
"FlaxAutoModelForPreTraining",
|
||||
"FlaxAutoModelForQuestionAnswering",
|
||||
"FlaxAutoModelForSeq2SeqLM",
|
||||
"FlaxAutoModelForSequenceClassification",
|
||||
"FlaxAutoModelForTokenClassification",
|
||||
]
|
||||
@ -178,6 +180,7 @@ if TYPE_CHECKING:
|
||||
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
FLAX_MODEL_MAPPING,
|
||||
@ -189,6 +192,7 @@ if TYPE_CHECKING:
|
||||
FlaxAutoModelForNextSentencePrediction,
|
||||
FlaxAutoModelForPreTraining,
|
||||
FlaxAutoModelForQuestionAnswering,
|
||||
FlaxAutoModelForSeq2SeqLM,
|
||||
FlaxAutoModelForSequenceClassification,
|
||||
FlaxAutoModelForTokenClassification,
|
||||
)
|
||||
|
@ -129,6 +129,13 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Seq2Seq Causal LM mapping
|
||||
(BartConfig, FlaxBartForConditionalGeneration)
|
||||
]
|
||||
)
|
||||
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
@ -197,6 +204,13 @@ FlaxAutoModelForMaskedLM = auto_class_factory(
|
||||
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
|
||||
)
|
||||
|
||||
|
||||
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
|
||||
"FlaxAutoModelForSeq2SeqLM",
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
head_doc="sequence-to-sequence language modeling",
|
||||
)
|
||||
|
||||
FlaxAutoModelForSequenceClassification = auto_class_factory(
|
||||
"FlaxAutoModelForSequenceClassification",
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
|
@ -94,6 +94,9 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
|
||||
|
||||
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
@ -166,6 +169,15 @@ class FlaxAutoModelForQuestionAnswering:
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxAutoModelForSeq2SeqLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
|
||||
class FlaxAutoModelForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
Loading…
Reference in New Issue
Block a user