Add Flax RoFormer (#15005)

* Add FlaxRoFormer

* Clean code + make quality

* Fix output pooling for FlaxRoFormerForMultipleChoiceModule

* Apply suggestions from code review

* add flax model to repos

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Daniel Stancl authored on 2022-01-04 13:23:10 +01:00, committed by GitHub
parent 9e1775dd23
commit 21aecc0971
8 changed files with 1417 additions and 2 deletions
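
Taken together, the changes below wire the new Flax implementation into the documentation index, the public API, the auto-class mappings, and the test suite. As a hedged sketch of what the addition enables (checkpoint name taken from the tests in this commit; assumes `flax` and the tokenizer's extra dependencies are installed):

from transformers import FlaxRoFormerModel, RoFormerTokenizer

# Load the converted weights and run the encoder once.
tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
model = FlaxRoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
inputs = tokenizer("今天天气非常好。", return_tensors="np")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, seq_len, hidden_size)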


@@ -246,7 +246,7 @@ Flax), PyTorch, and/or TensorFlow.
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ❌ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |


@@ -123,3 +123,33 @@ This model was contributed by [junnyu](https://huggingface.co/junnyu). The origi
[[autodoc]] TFRoFormerForQuestionAnswering
    - call

## FlaxRoFormerModel

[[autodoc]] FlaxRoFormerModel
    - __call__

## FlaxRoFormerForMaskedLM

[[autodoc]] FlaxRoFormerForMaskedLM
    - __call__

## FlaxRoFormerForSequenceClassification

[[autodoc]] FlaxRoFormerForSequenceClassification
    - __call__

## FlaxRoFormerForMultipleChoice

[[autodoc]] FlaxRoFormerForMultipleChoice
    - __call__

## FlaxRoFormerForTokenClassification

[[autodoc]] FlaxRoFormerForTokenClassification
    - __call__

## FlaxRoFormerForQuestionAnswering

[[autodoc]] FlaxRoFormerForQuestionAnswering
    - __call__
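
Note that the Flax entries document `__call__` (Flax modules are invoked directly), whereas the TensorFlow entries above use `call`. A small hedged illustration for the masked-LM head listed here, under the same checkpoint assumption as elsewhere in this commit:

import numpy as np
from transformers import FlaxRoFormerForMaskedLM, RoFormerTokenizer

tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
inputs = tokenizer("今天天气非常[MASK]。", return_tensors="np")
logits = model(**inputs).logits  # (batch_size, seq_len, vocab_size)

# Pick the highest-scoring token at the first [MASK] position.
mask_pos = int(np.argmax(inputs["input_ids"][0] == tokenizer.mask_token_id))
print(tokenizer.decode([int(logits[0, mask_pos].argmax())]))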


@@ -2084,6 +2084,17 @@ if is_flax_available():
            "FlaxRobertaPreTrainedModel",
        ]
    )
    _import_structure["models.roformer"].extend(
        [
            "FlaxRoFormerForMaskedLM",
            "FlaxRoFormerForMultipleChoice",
            "FlaxRoFormerForQuestionAnswering",
            "FlaxRoFormerForSequenceClassification",
            "FlaxRoFormerForTokenClassification",
            "FlaxRoFormerModel",
            "FlaxRoFormerPreTrainedModel",
        ]
    )
    _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
    _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
    _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
@@ -3819,6 +3830,15 @@ if TYPE_CHECKING:
            FlaxRobertaModel,
            FlaxRobertaPreTrainedModel,
        )
        from .models.roformer import (
            FlaxRoFormerForMaskedLM,
            FlaxRoFormerForMultipleChoice,
            FlaxRoFormerForQuestionAnswering,
            FlaxRoFormerForSequenceClassification,
            FlaxRoFormerForTokenClassification,
            FlaxRoFormerModel,
            FlaxRoFormerPreTrainedModel,
        )
        from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
        from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
        from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
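
The two hunks mirror each other deliberately: `_import_structure` feeds the lazy loader used at runtime, while the `TYPE_CHECKING` block re-declares the same names so static analyzers and IDEs can resolve them without importing Flax. A simplified sketch of the underlying pattern, using a module-level `__getattr__` (PEP 562) rather than the library's actual `_LazyModule`:

import importlib
from typing import TYPE_CHECKING

_import_structure = {"models.roformer": ["FlaxRoFormerModel", "FlaxRoFormerForMaskedLM"]}

if TYPE_CHECKING:
    # Only type checkers evaluate this import; it never runs at runtime.
    from .models.roformer import FlaxRoFormerForMaskedLM, FlaxRoFormerModel
else:
    def __getattr__(name):
        # Import the owning submodule the first time one of its names is requested.
        for submodule, names in _import_structure.items():
            if name in names:
                return getattr(importlib.import_module(f".{submodule}", __name__), name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")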


@@ -50,6 +50,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
        ("wav2vec2", "FlaxWav2Vec2Model"),
        ("marian", "FlaxMarianModel"),
        ("blenderbot", "FlaxBlenderbotModel"),
        ("roformer", "FlaxRoFormerModel"),
    ]
)
@@ -66,6 +67,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
        ("t5", "FlaxT5ForConditionalGeneration"),
        ("mt5", "FlaxMT5ForConditionalGeneration"),
        ("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
    ]
)
@@ -80,6 +82,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
        ("bart", "FlaxBartForConditionalGeneration"),
        ("electra", "FlaxElectraForMaskedLM"),
        ("mbart", "FlaxMBartForConditionalGeneration"),
        ("roformer", "FlaxRoFormerForMaskedLM"),
    ]
)
@@ -132,6 +135,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("bart", "FlaxBartForSequenceClassification"),
        ("electra", "FlaxElectraForSequenceClassification"),
        ("mbart", "FlaxMBartForSequenceClassification"),
        ("roformer", "FlaxRoFormerForSequenceClassification"),
    ]
)
@@ -146,6 +150,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
        ("bart", "FlaxBartForQuestionAnswering"),
        ("electra", "FlaxElectraForQuestionAnswering"),
        ("mbart", "FlaxMBartForQuestionAnswering"),
        ("roformer", "FlaxRoFormerForQuestionAnswering"),
    ]
)
@@ -158,6 +163,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
        ("bert", "FlaxBertForTokenClassification"),
        ("big_bird", "FlaxBigBirdForTokenClassification"),
        ("electra", "FlaxElectraForTokenClassification"),
        ("roformer", "FlaxRoFormerForTokenClassification"),
    ]
)
@@ -170,6 +176,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
        ("bert", "FlaxBertForMultipleChoice"),
        ("big_bird", "FlaxBigBirdForMultipleChoice"),
        ("electra", "FlaxElectraForMultipleChoice"),
        ("roformer", "FlaxRoFormerForMultipleChoice"),
    ]
)
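
With the `("roformer", ...)` entries registered, checkpoints whose config declares the `roformer` model type now resolve through the Flax auto classes as well. A hedged example:

from transformers import FlaxAutoModel, FlaxAutoModelForMaskedLM

# Both dispatch on the checkpoint's config.model_type == "roformer".
model = FlaxAutoModel.from_pretrained("junnyu/roformer_chinese_base")
mlm_model = FlaxAutoModelForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")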


@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available

_import_structure = {
@@ -59,6 +59,19 @@ if is_tf_available():
    ]

if is_flax_available():
    _import_structure["modeling_flax_roformer"] = [
        "FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
        "FlaxRoFormerForMaskedLM",
        "FlaxRoFormerForMultipleChoice",
        "FlaxRoFormerForQuestionAnswering",
        "FlaxRoFormerForSequenceClassification",
        "FlaxRoFormerForTokenClassification",
        "FlaxRoFormerModel",
        "FlaxRoFormerPreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
    from .tokenization_roformer import RoFormerTokenizer
@@ -95,6 +108,18 @@ if TYPE_CHECKING:
        TFRoFormerPreTrainedModel,
    )

    if is_flax_available():
        from .modeling_flax_roformer import (
            FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            FlaxRoFormerForMaskedLM,
            FlaxRoFormerForMultipleChoice,
            FlaxRoFormerForQuestionAnswering,
            FlaxRoFormerForSequenceClassification,
            FlaxRoFormerForTokenClassification,
            FlaxRoFormerModel,
            FlaxRoFormerPreTrainedModel,
        )

else:
    import sys
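
The `is_flax_available()` guard keeps this subpackage importable on installations without Flax. A simplified stand-in for such an availability check (the real helper lives in `file_utils` and does more bookkeeping) could look like:

import importlib.util

def _flax_is_installed() -> bool:
    # Hypothetical helper mirroring is_flax_available(): probe for the
    # package without actually importing it.
    return importlib.util.find_spec("flax") is not None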

File diff suppressed because it is too large


@@ -1316,6 +1316,90 @@ class FlaxRobertaPreTrainedModel:
        requires_backends(self, ["flax"])


class FlaxRoFormerForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    def __call__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxRoFormerForMultipleChoice:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    def __call__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxRoFormerForQuestionAnswering:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    def __call__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxRoFormerForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    def __call__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxRoFormerForTokenClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    def __call__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxRoFormerModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    def __call__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxRoFormerPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    def __call__(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxT5ForConditionalGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])
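
These dummy classes keep `from transformers import FlaxRoFormerModel` from failing on an installation without Flax; the error is deferred to first use, with a message naming the missing backend. A simplified sketch of what `requires_backends` does (the real helper in `file_utils` consults registered backend checks rather than `find_spec` directly):

import importlib.util

def requires_backends(obj, backends):
    # Simplified stand-in: raise a readable ImportError listing missing backends.
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the {', '.join(missing)} backend(s); try `pip install flax`.")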


@@ -0,0 +1,163 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

from transformers import RoFormerConfig, is_flax_available
from transformers.testing_utils import require_flax, slow

from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask


if is_flax_available():
    import jax.numpy as jnp

    from transformers.models.roformer.modeling_flax_roformer import (
        FlaxRoFormerForMaskedLM,
        FlaxRoFormerForMultipleChoice,
        FlaxRoFormerForQuestionAnswering,
        FlaxRoFormerForSequenceClassification,
        FlaxRoFormerForTokenClassification,
        FlaxRoFormerModel,
    )
class FlaxRoFormerModelTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_attention_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_choices=4,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_choices = num_choices

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        config = RoFormerConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

        return config, input_ids, token_type_ids, attention_mask

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, token_type_ids, attention_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
        return config, inputs_dict
@require_flax
class FlaxRoFormerModelTest(FlaxModelTesterMixin, unittest.TestCase):
    test_head_masking = True

    all_model_classes = (
        (
            FlaxRoFormerModel,
            FlaxRoFormerForMaskedLM,
            FlaxRoFormerForSequenceClassification,
            FlaxRoFormerForTokenClassification,
            FlaxRoFormerForMultipleChoice,
            FlaxRoFormerForQuestionAnswering,
        )
        if is_flax_available()
        else ()
    )

    def setUp(self):
        self.model_tester = FlaxRoFormerModelTester(self)

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("junnyu/roformer_chinese_small", from_pt=True)
            outputs = model(np.ones((1, 1)))
            self.assertIsNotNone(outputs)
@require_flax
class FlaxRoFormerModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_masked_lm(self):
        model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
        input_ids = jnp.array([[0, 1, 2, 3, 4, 5]])
        output = model(input_ids)[0]

        vocab_size = 50000
        expected_shape = (1, 6, vocab_size)
        self.assertEqual(output.shape, expected_shape)

        expected_slice = jnp.array(
            [[[-0.1205, -1.0265, 0.2922], [-1.5134, 0.1974, 0.1519], [-5.0135, -3.9003, -0.8404]]]
        )
        self.assertTrue(jnp.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
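
For readers unfamiliar with the model, the defining feature being ported here is the rotary position embedding: each pair of feature dimensions in the queries and keys is rotated by a position-dependent angle, so attention scores depend only on relative position. A minimal numpy sketch of the idea (not the exact layout the library's implementation uses):

import numpy as np

def rotary_embed(x, base=10000.0):
    # x: (seq_len, dim) with dim even; rotate each (even, odd) feature pair
    # by angle position * base**(-2i/dim).
    seq_len, dim = x.shape
    positions = np.arange(seq_len)[:, None]            # (seq_len, 1)
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)   # (dim/2,)
    angles = positions * inv_freq[None, :]             # (seq_len, dim/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x, dtype=float)
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out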