Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 21:00:08 +06:00)
Flax Big Bird (#11967)
* add flax bert
* bert -> bigbird
* original_full ported
* add debugger
* init block sparse
* fix copies ; gelu_fast -> gelu_new
* block sparse port
* fix block sparse
* block sparse working
* all ckpts working
* fix-copies
* make quality
* init tests
* temporary fix for FlaxBigBirdForMultipleChoice
* skip test_attention_outputs
* fix
* gelu_fast -> gelu_new ; fix multiple choice model
* remove nsp
* fix sequence classifier
* fix
* make quality
* make fix-copies
* finish
* Delete debugger.ipynb
* Update src/transformers/models/big_bird/modeling_flax_big_bird.py
* make style
* finish
* bye bye jit flax tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent a156da9a23
commit d9c0d08f9a
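Taken together, the diff below wires the new Flax BigBird models through the docs, the public API, the auto classes, and the test suite. As a quick orientation, a minimal usage sketch (mirroring the slow test added at the bottom of this diff; `from_pt=True` converts the PyTorch checkpoint on the fly):

```python
import numpy as np
from transformers import FlaxBigBirdForMaskedLM

# Load the PyTorch BigBird checkpoint into the new Flax class, then run a
# toy forward pass, exactly as the new slow test does.
model = FlaxBigBirdForMaskedLM.from_pretrained("google/bigbird-roberta-base", from_pt=True)
outputs = model(np.ones((1, 1)))
print(outputs.logits.shape)
```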
@@ -305,7 +305,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Bert Generation             | ✅             | ❌             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| BigBird                     | ✅             | ✅             | ✅              | ❌                 | ❌           |
+| BigBird                     | ✅             | ✅             | ✅              | ❌                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | BigBirdPegasus              | ❌             | ❌             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
@@ -134,3 +134,52 @@ BigBirdForQuestionAnswering

 .. autoclass:: transformers.BigBirdForQuestionAnswering
     :members: forward
+
+
+FlaxBigBirdModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxBigBirdModel
+    :members: __call__
+
+
+FlaxBigBirdForPreTraining
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxBigBirdForPreTraining
+    :members: __call__
+
+
+FlaxBigBirdForMaskedLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxBigBirdForMaskedLM
+    :members: __call__
+
+
+FlaxBigBirdForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxBigBirdForSequenceClassification
+    :members: __call__
+
+
+FlaxBigBirdForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxBigBirdForMultipleChoice
+    :members: __call__
+
+
+FlaxBigBirdForTokenClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxBigBirdForTokenClassification
+    :members: __call__
+
+
+FlaxBigBirdForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxBigBirdForQuestionAnswering
+    :members: __call__
@@ -1537,6 +1537,18 @@ if is_flax_available():
             "FlaxBertPreTrainedModel",
         ]
     )
+    _import_structure["models.big_bird"].extend(
+        [
+            "FlaxBigBirdForMaskedLM",
+            "FlaxBigBirdForMultipleChoice",
+            "FlaxBigBirdForPreTraining",
+            "FlaxBigBirdForQuestionAnswering",
+            "FlaxBigBirdForSequenceClassification",
+            "FlaxBigBirdForTokenClassification",
+            "FlaxBigBirdModel",
+            "FlaxBigBirdPreTrainedModel",
+        ]
+    )
     _import_structure["models.clip"].extend(
         [
             "FlaxCLIPModel",
@@ -2847,6 +2859,16 @@ if TYPE_CHECKING:
             FlaxBertModel,
             FlaxBertPreTrainedModel,
         )
+        from .models.big_bird import (
+            FlaxBigBirdForMaskedLM,
+            FlaxBigBirdForMultipleChoice,
+            FlaxBigBirdForPreTraining,
+            FlaxBigBirdForQuestionAnswering,
+            FlaxBigBirdForSequenceClassification,
+            FlaxBigBirdForTokenClassification,
+            FlaxBigBirdModel,
+            FlaxBigBirdPreTrainedModel,
+        )
         from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
         from .models.electra import (
             FlaxElectraForMaskedLM,
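These two `__init__.py` hunks feed transformers' lazy-import machinery: class names are registered in `_import_structure` so the heavy modules only load on first use, while the `TYPE_CHECKING` branch gives type checkers real imports. A toy sketch of the pattern (not the actual `_BaseLazyModule` implementation):

```python
import importlib

class LazyModule:
    """Toy lazy module: defers the real import until first attribute access."""

    def __init__(self, package, import_structure):
        self._package = package
        # invert {submodule: [names]} into {name: submodule}
        self._name_to_module = {
            name: module for module, names in import_structure.items() for name in names
        }

    def __getattr__(self, name):
        module = importlib.import_module("." + self._name_to_module[name], self._package)
        return getattr(module, name)
```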
@@ -34,6 +34,15 @@ from ..bert.modeling_flax_bert import (
     FlaxBertForTokenClassification,
     FlaxBertModel,
 )
+from ..big_bird.modeling_flax_big_bird import (
+    FlaxBigBirdForMaskedLM,
+    FlaxBigBirdForMultipleChoice,
+    FlaxBigBirdForPreTraining,
+    FlaxBigBirdForQuestionAnswering,
+    FlaxBigBirdForSequenceClassification,
+    FlaxBigBirdForTokenClassification,
+    FlaxBigBirdModel,
+)
 from ..clip.modeling_flax_clip import FlaxCLIPModel
 from ..electra.modeling_flax_electra import (
     FlaxElectraForMaskedLM,
@@ -55,7 +64,16 @@ from ..roberta.modeling_flax_roberta import (
 )
 from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
 from .auto_factory import auto_class_factory
-from .configuration_auto import BartConfig, BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
+from .configuration_auto import (
+    BartConfig,
+    BertConfig,
+    BigBirdConfig,
+    CLIPConfig,
+    ElectraConfig,
+    GPT2Config,
+    RobertaConfig,
+    ViTConfig,
+)


 logger = logging.get_logger(__name__)
@@ -66,6 +84,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
         # Base model mapping
         (RobertaConfig, FlaxRobertaModel),
         (BertConfig, FlaxBertModel),
+        (BigBirdConfig, FlaxBigBirdModel),
         (BartConfig, FlaxBartModel),
         (GPT2Config, FlaxGPT2Model),
         (ElectraConfig, FlaxElectraModel),
@@ -79,6 +98,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
         # Model for pre-training mapping
         (RobertaConfig, FlaxRobertaForMaskedLM),
         (BertConfig, FlaxBertForPreTraining),
+        (BigBirdConfig, FlaxBigBirdForPreTraining),
         (BartConfig, FlaxBartForConditionalGeneration),
         (ElectraConfig, FlaxElectraForPreTraining),
     ]
@@ -89,6 +109,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
         # Model for Masked LM mapping
         (RobertaConfig, FlaxRobertaForMaskedLM),
         (BertConfig, FlaxBertForMaskedLM),
+        (BigBirdConfig, FlaxBigBirdForMaskedLM),
         (BartConfig, FlaxBartForConditionalGeneration),
         (ElectraConfig, FlaxElectraForMaskedLM),
     ]
@@ -113,6 +134,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
         # Model for Sequence Classification mapping
         (RobertaConfig, FlaxRobertaForSequenceClassification),
         (BertConfig, FlaxBertForSequenceClassification),
+        (BigBirdConfig, FlaxBigBirdForSequenceClassification),
         (BartConfig, FlaxBartForSequenceClassification),
         (ElectraConfig, FlaxElectraForSequenceClassification),
     ]
@@ -123,6 +145,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
         # Model for Question Answering mapping
         (RobertaConfig, FlaxRobertaForQuestionAnswering),
         (BertConfig, FlaxBertForQuestionAnswering),
+        (BigBirdConfig, FlaxBigBirdForQuestionAnswering),
         (BartConfig, FlaxBartForQuestionAnswering),
         (ElectraConfig, FlaxElectraForQuestionAnswering),
     ]
@@ -133,6 +156,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
         # Model for Token Classification mapping
         (RobertaConfig, FlaxRobertaForTokenClassification),
         (BertConfig, FlaxBertForTokenClassification),
+        (BigBirdConfig, FlaxBigBirdForTokenClassification),
         (ElectraConfig, FlaxElectraForTokenClassification),
     ]
 )
@@ -142,6 +166,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
         # Model for Multiple Choice mapping
         (RobertaConfig, FlaxRobertaForMultipleChoice),
         (BertConfig, FlaxBertForMultipleChoice),
+        (BigBirdConfig, FlaxBigBirdForMultipleChoice),
         (ElectraConfig, FlaxElectraForMultipleChoice),
     ]
 )
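Each mapping hunk above registers `BigBirdConfig` against the matching Flax head class, which is what lets the factory-built auto classes resolve BigBird. A sketch, assuming the era's `FlaxAutoModel.from_config` API:

```python
from transformers import BigBirdConfig, FlaxAutoModel

# Deliberately tiny, and original_full attention so any toy sequence length
# works; FlaxAutoModel consults FLAX_MODEL_MAPPING and should hand back a
# FlaxBigBirdModel instance.
config = BigBirdConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    attention_type="original_full",
)
model = FlaxAutoModel.from_config(config)
print(type(model).__name__)  # FlaxBigBirdModel
```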
@@ -193,7 +193,8 @@ class FlaxBertSelfAttention(nn.Module):
     def setup(self):
         if self.config.hidden_size % self.config.num_attention_heads != 0:
             raise ValueError(
-                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
+                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
+                    : {self.config.num_attention_heads}"
             )

         self.query = nn.Dense(
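A side note on this hunk (and the matching Electra and RoBERTa hunks further down): the rewrapped message is still a plain string rather than an f-string, so the `{...}` placeholders are never interpolated. The guard itself simply demands an even split of the hidden dimension across heads:

```python
# E.g. 768 hidden units over 12 heads gives 64 dims per head, while
# 768 over 10 heads would trip the ValueError above.
hidden_size, num_attention_heads = 768, 12
assert hidden_size % num_attention_heads == 0
head_dim = hidden_size // num_attention_heads
print(head_dim)  # 64
```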
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING

 from ...file_utils import (
     _BaseLazyModule,
+    is_flax_available,
     is_sentencepiece_available,
     is_tf_available,
     is_tokenizers_available,
@@ -52,6 +53,17 @@ if is_torch_available():
         "load_tf_weights_in_big_bird",
     ]

+if is_flax_available():
+    _import_structure["modeling_flax_big_bird"] = [
+        "FlaxBigBirdForMaskedLM",
+        "FlaxBigBirdForMultipleChoice",
+        "FlaxBigBirdForPreTraining",
+        "FlaxBigBirdForQuestionAnswering",
+        "FlaxBigBirdForSequenceClassification",
+        "FlaxBigBirdForTokenClassification",
+        "FlaxBigBirdModel",
+        "FlaxBigBirdPreTrainedModel",
+    ]

 if TYPE_CHECKING:
     from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
@@ -78,6 +90,17 @@ if TYPE_CHECKING:
             load_tf_weights_in_big_bird,
         )

+    if is_flax_available():
+        from .modeling_flax_big_bird import (
+            FlaxBigBirdForMaskedLM,
+            FlaxBigBirdForMultipleChoice,
+            FlaxBigBirdForPreTraining,
+            FlaxBigBirdForQuestionAnswering,
+            FlaxBigBirdForSequenceClassification,
+            FlaxBigBirdForTokenClassification,
+            FlaxBigBirdModel,
+            FlaxBigBirdPreTrainedModel,
+        )

 else:
     import importlib
@@ -51,9 +51,9 @@ class BigBirdConfig(PretrainedConfig):
             Number of attention heads for each attention layer in the Transformer encoder.
         intermediate_size (:obj:`int`, `optional`, defaults to 3072):
             Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
-        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_fast"`):
+        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_new"`):
             The non-linear activation function (function or string) in the encoder and pooler. If string,
-            :obj:`"gelu"`, :obj:`"gelu_fast"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
+            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
         hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
             The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
         attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
@@ -107,7 +107,7 @@ class BigBirdConfig(PretrainedConfig):
         num_hidden_layers=12,
         num_attention_heads=12,
         intermediate_size=3072,
-        hidden_act="gelu_fast",
+        hidden_act="gelu_new",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
         max_position_embeddings=4096,
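The `gelu_fast -> gelu_new` switch (applied to BigBirdPegasus further down as well, and flagged twice in the commit message) changes the default activation. `gelu_new` is the tanh-approximated GELU; a reference sketch of the standard formula in plain NumPy:

```python
import math

import numpy as np

def gelu_new(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * np.power(x, 3.0))))

print(gelu_new(np.array([-1.0, 0.0, 1.0])))
```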
@@ -43,7 +43,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel, SequenceSummary, apply_chunking_to_forward
+from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
 from ...utils import logging
 from .configuration_big_bird import BigBirdConfig

@@ -2309,7 +2309,6 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
         )

         sequence_output, pooled_output = outputs[:2]
-
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

         total_loss = None
@@ -2709,7 +2708,7 @@ class BigBirdForMultipleChoice(BigBirdPreTrainedModel):
         super().__init__(config)

         self.bert = BigBirdModel(config)
-        self.sequence_summary = SequenceSummary(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)

         self.init_weights()
@@ -2767,9 +2766,9 @@ class BigBirdForMultipleChoice(BigBirdPreTrainedModel):
             return_dict=return_dict,
         )

-        sequence_output = outputs[0]
+        pooled_output = outputs[1]

-        pooled_output = self.sequence_summary(sequence_output)
+        pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, num_choices)

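The two hunks above replace the `SequenceSummary` head with the plain pooler-output path used by the other `*ForMultipleChoice` models. A shape-only sketch of the new flow, with hypothetical toy dimensions:

```python
import torch
import torch.nn as nn

batch_size, num_choices, hidden_size = 2, 4, 8

# Choices are flattened into the batch dimension before the encoder runs,
# so the pooled output arrives as (batch_size * num_choices, hidden_size).
pooled_output = torch.randn(batch_size * num_choices, hidden_size)

dropout = nn.Dropout(p=0.1)
classifier = nn.Linear(hidden_size, 1)

logits = classifier(dropout(pooled_output))     # (batch_size * num_choices, 1)
reshaped_logits = logits.view(-1, num_choices)  # (batch_size, num_choices)
print(reshaped_logits.shape)
```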
src/transformers/models/big_bird/modeling_flax_big_bird.py (new file, 2061 lines)
File diff suppressed because it is too large.
@@ -57,9 +57,9 @@ class BigBirdPegasusConfig(PretrainedConfig):
             Dimension of the "intermediate" (often named feed-forward) layer in decoder.
         encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
             Dimension of the "intermediate" (often named feed-forward) layer in decoder.
-        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_fast"`):
+        activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_new"`):
             The non-linear activation function (function or string) in the encoder and pooler. If string,
-            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"`, :obj:`"gelu_fast"` and :obj:`"gelu_new"` are supported.
+            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
         dropout (:obj:`float`, `optional`, defaults to 0.1):
             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
         attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
@@ -127,7 +127,7 @@ class BigBirdPegasusConfig(PretrainedConfig):
         decoder_layerdrop=0.0,
         use_cache=True,
         is_encoder_decoder=True,
-        activation_function="gelu_fast",
+        activation_function="gelu_new",
         d_model=1024,
         dropout=0.1,
         attention_dropout=0.0,
@@ -190,7 +190,8 @@ class FlaxElectraSelfAttention(nn.Module):
     def setup(self):
         if self.config.hidden_size % self.config.num_attention_heads != 0:
             raise ValueError(
-                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
+                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
+                    : {self.config.num_attention_heads}"
             )

         self.query = nn.Dense(
@@ -179,7 +179,8 @@ class FlaxRobertaSelfAttention(nn.Module):
     def setup(self):
         if self.config.hidden_size % self.config.num_attention_heads != 0:
             raise ValueError(
-                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
+                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
+                    : {self.config.num_attention_heads}"
             )

         self.query = nn.Dense(
@@ -258,6 +258,74 @@ class FlaxBertPreTrainedModel:
         requires_backends(cls, ["flax"])


+class FlaxBigBirdForMaskedLM:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxBigBirdForMultipleChoice:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxBigBirdForPreTraining:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxBigBirdForQuestionAnswering:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxBigBirdForSequenceClassification:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxBigBirdForTokenClassification:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxBigBirdModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxBigBirdPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxCLIPModel:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
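The dummy classes above keep `from transformers import FlaxBigBirdModel` working in environments without JAX/Flax installed; `requires_backends` only raises once the class is actually used. A sketch of the resulting behaviour:

```python
from transformers import FlaxBigBirdModel

try:
    # Resolves to the dummy class when flax is missing ...
    model = FlaxBigBirdModel()
except ImportError as err:
    # ... which raises here with installation instructions.
    print(err)
```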
tests/test_modeling_flax_big_bird.py (new file, 164 lines)
@@ -0,0 +1,164 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers import BigBirdConfig, is_flax_available
+from transformers.testing_utils import require_flax, slow
+
+from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
+
+
+if is_flax_available():
+    from transformers.models.big_bird.modeling_flax_big_bird import (
+        FlaxBigBirdForMaskedLM,
+        FlaxBigBirdForMultipleChoice,
+        FlaxBigBirdForPreTraining,
+        FlaxBigBirdForQuestionAnswering,
+        FlaxBigBirdForSequenceClassification,
+        FlaxBigBirdForTokenClassification,
+        FlaxBigBirdModel,
+    )
+
+
+class FlaxBigBirdModelTester(unittest.TestCase):
+    def __init__(
+        self,
+        parent,
+        batch_size=13,
+        seq_length=56,
+        is_training=True,
+        use_attention_mask=True,
+        use_token_type_ids=True,
+        use_labels=True,
+        vocab_size=99,
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu_new",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=16,
+        type_sequence_label_size=2,
+        initializer_range=0.02,
+        num_choices=4,
+        attention_type="block_sparse",
+        use_bias=True,
+        rescale_embeddings=False,
+        block_size=4,
+        num_random_blocks=3,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_attention_mask = use_attention_mask
+        self.use_token_type_ids = use_token_type_ids
+        self.use_labels = use_labels
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+        self.num_choices = num_choices
+
+        self.rescale_embeddings = rescale_embeddings
+        self.attention_type = attention_type
+        self.use_bias = use_bias
+        self.block_size = block_size
+        self.num_random_blocks = num_random_blocks
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+        attention_mask = None
+        if self.use_attention_mask:
+            attention_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+        token_type_ids = None
+        if self.use_token_type_ids:
+            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+        config = BigBirdConfig(
+            vocab_size=self.vocab_size,
+            hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            intermediate_size=self.intermediate_size,
+            hidden_act=self.hidden_act,
+            hidden_dropout_prob=self.hidden_dropout_prob,
+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+            max_position_embeddings=self.max_position_embeddings,
+            type_vocab_size=self.type_vocab_size,
+            is_decoder=False,
+            initializer_range=self.initializer_range,
+            attention_type=self.attention_type,
+            block_size=self.block_size,
+            num_random_blocks=self.num_random_blocks,
+            use_bias=self.use_bias,
+            rescale_embeddings=self.rescale_embeddings,
+        )
+
+        return config, input_ids, token_type_ids, attention_mask
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, token_type_ids, attention_mask = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
+        return config, inputs_dict
+
+
+@require_flax
+class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
+
+    all_model_classes = (
+        (
+            FlaxBigBirdModel,
+            FlaxBigBirdForPreTraining,
+            FlaxBigBirdForMaskedLM,
+            FlaxBigBirdForMultipleChoice,
+            FlaxBigBirdForQuestionAnswering,
+            FlaxBigBirdForSequenceClassification,
+            FlaxBigBirdForTokenClassification,
+        )
+        if is_flax_available()
+        else ()
+    )
+
+    test_attn_probs = False
+
+    def setUp(self):
+        self.model_tester = FlaxBigBirdModelTester(self)
+
+    @slow
+    def test_model_from_pretrained(self):
+        for model_class_name in self.all_model_classes:
+            model = model_class_name.from_pretrained("google/bigbird-roberta-base", from_pt=True)
+            outputs = model(np.ones((1, 1)))
+            self.assertIsNotNone(outputs)
+
+    def test_attention_outputs(self):
+        if self.test_attn_probs:
+            super().test_attention_outputs()
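One detail worth calling out in the tester above: `seq_length=56` is chosen to pair with `block_size=4`, since block-sparse attention tiles the sequence into fixed-size blocks (an assumption consistent with the divisibility checks in the PyTorch implementation). A quick sanity check:

```python
# Hypothetical numbers copied from the tester: the sequence must tile evenly
# into blocks, leaving enough blocks for the sliding window plus the
# num_random_blocks=3 random blocks per row.
seq_length, block_size = 56, 4
assert seq_length % block_size == 0
num_blocks = seq_length // block_size
print(num_blocks)  # 14 blocks of 4 tokens
```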
@@ -342,6 +342,7 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
     def test_hidden_states_output(self):
         pass

+    @slow
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -23,7 +23,7 @@ import numpy as np
 import transformers
 from transformers import is_flax_available, is_torch_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import is_pt_flax_cross_test, require_flax
+from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow


 if is_flax_available():
@@ -273,6 +273,7 @@ class FlaxModelTesterMixin:
             for output_loaded, output in zip(outputs_loaded, outputs):
                 self.assert_almost_equals(output_loaded, output, 1e-3)

+    @slow
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -179,6 +179,7 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
         self.assertListEqual(arg_names[:1], expected_arg_names)

     # We neeed to override this test because ViT expects pixel_values instead of input_ids
+    @slow
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

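The last three hunks put `@slow` on the Flax JIT-compilation tests ('bye bye jit flax tests' in the commit message), so they only run when slow tests are enabled via the `RUN_SLOW` environment flag. Roughly how the decorator behaves (simplified; the real `transformers.testing_utils.slow` parses the flag more carefully):

```python
import os
import unittest

def slow(test_case):
    """Skip the decorated test unless RUN_SLOW is set in the environment."""
    return unittest.skipUnless(os.environ.get("RUN_SLOW", False), "test is slow")(test_case)
```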