[Flax] FlaxAutoModelForSeq2SeqLM (#12228)

* add FlaxAutoModelForSeq2SeqLM
This commit is contained in:
Suraj Patil 2021-06-18 13:20:09 +05:30 committed by GitHub
parent e43e11260f
commit f74655cd9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 0 deletions

View File

@ -226,6 +226,13 @@ FlaxAutoModelForMaskedLM
:members:
FlaxAutoModelForSeq2SeqLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxAutoModelForSeq2SeqLM
    :members:

FlaxAutoModelForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1514,6 +1514,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING",
@ -1524,6 +1525,7 @@ if is_flax_available():
"FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
]
@ -2851,6 +2853,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING,
@ -2861,6 +2864,7 @@ if TYPE_CHECKING:
FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
)

View File

@ -92,6 +92,7 @@ if is_flax_available():
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING",
@ -103,6 +104,7 @@ if is_flax_available():
"FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSeq2SeqLM",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
]
@ -178,6 +180,7 @@ if TYPE_CHECKING:
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING,
@ -189,6 +192,7 @@ if TYPE_CHECKING:
FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSeq2SeqLM,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
)

View File

@ -129,6 +129,13 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
]
)
# Ordered (config class, Flax model class) pairs used to build FlaxAutoModelForSeq2SeqLM.
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        (BartConfig, FlaxBartForConditionalGeneration),
    ]
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
# Model for Sequence Classification mapping
@ -197,6 +204,13 @@ FlaxAutoModelForMaskedLM = auto_class_factory(
"FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
)
# Auto class produced by auto_class_factory from the seq2seq LM mapping;
# head_doc supplies the task wording passed to the factory.
FlaxAutoModelForSeq2SeqLM = auto_class_factory(
"FlaxAutoModelForSeq2SeqLM",
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
head_doc="sequence-to-sequence language modeling",
)
FlaxAutoModelForSequenceClassification = auto_class_factory(
"FlaxAutoModelForSequenceClassification",
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,

View File

@ -94,6 +94,9 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
@ -166,6 +169,15 @@ class FlaxAutoModelForQuestionAnswering:
requires_backends(cls, ["flax"])
class FlaxAutoModelForSeq2SeqLM:
    """Stand-in for ``FlaxAutoModelForSeq2SeqLM`` that defers to ``requires_backends``.

    Both entry points accept any arguments and only call
    ``requires_backends(..., ["flax"])``.
    NOTE(review): presumably ``requires_backends`` raises when flax is not
    installed -- confirm against its definition.
    """

    def __init__(self, *args, **kwargs):
        """Check the "flax" backend requirement for this instance."""
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Check the "flax" backend requirement for this class."""
        requires_backends(cls, ["flax"])
class FlaxAutoModelForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])