Flax Big Bird (#11967)

* add flax bert

* bert -> bigbird

* original_full ported

* add debugger

* init block sparse

* fix copies ; gelu_fast -> gelu_new

* block sparse port

* fix block sparse

* block sparse working

* all ckpts working

* fix-copies

* make quality

* init tests

* temporary fix for FlaxBigBirdForMultipleChoice

* skip test_attention_outputs

* fix

* gelu_fast -> gelu_new ; fix multiple choice model

* remove nsp

* fix sequence classifier

* fix

* make quality

* make fix-copies

* finish

* Delete debugger.ipynb

* Update src/transformers/models/big_bird/modeling_flax_big_bird.py

* make style

* finish

* bye bye jit flax tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Vasudev Gupta 2021-06-15 00:31:03 +05:30 committed by GitHub
parent a156da9a23
commit d9c0d08f9a
17 changed files with 2434 additions and 17 deletions


@ -305,7 +305,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Bert Generation | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BigBird | ✅ | ✅ | ✅ | ❌ | ❌ |
| BigBird | ✅ | ✅ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BigBirdPegasus | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+


@ -134,3 +134,52 @@ BigBirdForQuestionAnswering
.. autoclass:: transformers.BigBirdForQuestionAnswering
:members: forward
FlaxBigBirdModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBigBirdModel
:members: __call__
FlaxBigBirdForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBigBirdForPreTraining
:members: __call__
FlaxBigBirdForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBigBirdForMaskedLM
:members: __call__
FlaxBigBirdForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBigBirdForSequenceClassification
:members: __call__
FlaxBigBirdForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBigBirdForMultipleChoice
:members: __call__
FlaxBigBirdForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBigBirdForTokenClassification
:members: __call__
FlaxBigBirdForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBigBirdForQuestionAnswering
:members: __call__
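
A minimal usage sketch of the new classes (assumptions: the ``google/bigbird-roberta-base`` checkpoint, ``from_pt=True`` because only PyTorch weights may be published for it, and ``attention_type="original_full"`` to sidestep the block-sparse sequence-length requirements for a short input):

.. code-block:: python

    from transformers import BigBirdTokenizer, FlaxBigBirdModel

    tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
    # from_pt=True converts the PyTorch checkpoint if no Flax weights are available
    model = FlaxBigBirdModel.from_pretrained(
        "google/bigbird-roberta-base", attention_type="original_full", from_pt=True
    )

    # Flax models take NumPy arrays; the tokenizer can return them directly
    inputs = tokenizer("BigBird now runs on JAX/Flax.", return_tensors="np")
    outputs = model(**inputs)
    last_hidden_state = outputs.last_hidden_state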


@ -1537,6 +1537,18 @@ if is_flax_available():
"FlaxBertPreTrainedModel",
]
)
_import_structure["models.big_bird"].extend(
[
"FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining",
"FlaxBigBirdForQuestionAnswering",
"FlaxBigBirdForSequenceClassification",
"FlaxBigBirdForTokenClassification",
"FlaxBigBirdModel",
"FlaxBigBirdPreTrainedModel",
]
)
_import_structure["models.clip"].extend(
[
"FlaxCLIPModel",
@ -2847,6 +2859,16 @@ if TYPE_CHECKING:
FlaxBertModel,
FlaxBertPreTrainedModel,
)
from .models.big_bird import (
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
FlaxBigBirdForQuestionAnswering,
FlaxBigBirdForSequenceClassification,
FlaxBigBirdForTokenClassification,
FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel,
)
from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .models.electra import (
FlaxElectraForMaskedLM,
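
The _import_structure entries and the mirrored TYPE_CHECKING imports above feed the library's lazy module: the BigBird Flax classes are only actually imported the first time one of those names is accessed. A rough, self-contained sketch of that idea (not the real _BaseLazyModule code):

    import importlib

    _import_structure = {"models.big_bird": ["FlaxBigBirdModel", "FlaxBigBirdForMaskedLM"]}

    def __getattr__(name):
        # PEP 562 module-level __getattr__: resolve the owning submodule lazily on first use
        for submodule, names in _import_structure.items():
            if name in names:
                return getattr(importlib.import_module(f"transformers.{submodule}"), name)
        raise AttributeError(f"module 'transformers' has no attribute {name!r}")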


@ -34,6 +34,15 @@ from ..bert.modeling_flax_bert import (
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..big_bird.modeling_flax_big_bird import (
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
FlaxBigBirdForQuestionAnswering,
FlaxBigBirdForSequenceClassification,
FlaxBigBirdForTokenClassification,
FlaxBigBirdModel,
)
from ..clip.modeling_flax_clip import FlaxCLIPModel
from ..electra.modeling_flax_electra import (
FlaxElectraForMaskedLM,
@ -55,7 +64,16 @@ from ..roberta.modeling_flax_roberta import (
)
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from .auto_factory import auto_class_factory
from .configuration_auto import BartConfig, BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
from .configuration_auto import (
BartConfig,
BertConfig,
BigBirdConfig,
CLIPConfig,
ElectraConfig,
GPT2Config,
RobertaConfig,
ViTConfig,
)
logger = logging.get_logger(__name__)
@ -66,6 +84,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
# Base model mapping
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
(BigBirdConfig, FlaxBigBirdModel),
(BartConfig, FlaxBartModel),
(GPT2Config, FlaxGPT2Model),
(ElectraConfig, FlaxElectraModel),
@ -79,6 +98,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
# Model for pre-training mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForPreTraining),
(BigBirdConfig, FlaxBigBirdForPreTraining),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForPreTraining),
]
@ -89,6 +109,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
# Model for Masked LM mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForMaskedLM),
(BigBirdConfig, FlaxBigBirdForMaskedLM),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForMaskedLM),
]
@ -113,6 +134,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
# Model for Sequence Classification mapping
(RobertaConfig, FlaxRobertaForSequenceClassification),
(BertConfig, FlaxBertForSequenceClassification),
(BigBirdConfig, FlaxBigBirdForSequenceClassification),
(BartConfig, FlaxBartForSequenceClassification),
(ElectraConfig, FlaxElectraForSequenceClassification),
]
@ -123,6 +145,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
# Model for Question Answering mapping
(RobertaConfig, FlaxRobertaForQuestionAnswering),
(BertConfig, FlaxBertForQuestionAnswering),
(BigBirdConfig, FlaxBigBirdForQuestionAnswering),
(BartConfig, FlaxBartForQuestionAnswering),
(ElectraConfig, FlaxElectraForQuestionAnswering),
]
@ -133,6 +156,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
# Model for Token Classification mapping
(RobertaConfig, FlaxRobertaForTokenClassification),
(BertConfig, FlaxBertForTokenClassification),
(BigBirdConfig, FlaxBigBirdForTokenClassification),
(ElectraConfig, FlaxElectraForTokenClassification),
]
)
@ -142,6 +166,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
# Model for Multiple Choice mapping
(RobertaConfig, FlaxRobertaForMultipleChoice),
(BertConfig, FlaxBertForMultipleChoice),
(BigBirdConfig, FlaxBigBirdForMultipleChoice),
(ElectraConfig, FlaxElectraForMultipleChoice),
]
)
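
Each of these OrderedDicts maps a config class to the Flax model class the corresponding auto class should build, so registering BigBirdConfig here is what routes BigBird checkpoints to the new classes. A small sketch of the lookup (illustrative only; the mapping lives in the auto modeling module shown in this diff):

    from transformers import BigBirdConfig
    from transformers.models.auto.modeling_flax_auto import FLAX_MODEL_MAPPING

    # the auto machinery keys these mappings by config class
    model_class = FLAX_MODEL_MAPPING[BigBirdConfig]
    print(model_class.__name__)  # FlaxBigBirdModel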


@ -193,7 +193,8 @@ class FlaxBertSelfAttention(nn.Module):
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
: {self.config.num_attention_heads}"
)
self.query = nn.Dense(


@ -19,6 +19,7 @@ from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
@ -52,6 +53,17 @@ if is_torch_available():
"load_tf_weights_in_big_bird",
]
if is_flax_available():
_import_structure["modeling_flax_big_bird"] = [
"FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining",
"FlaxBigBirdForQuestionAnswering",
"FlaxBigBirdForSequenceClassification",
"FlaxBigBirdForTokenClassification",
"FlaxBigBirdModel",
"FlaxBigBirdPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig
@ -78,6 +90,17 @@ if TYPE_CHECKING:
load_tf_weights_in_big_bird,
)
if is_flax_available():
from .modeling_flax_big_bird import (
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
FlaxBigBirdForQuestionAnswering,
FlaxBigBirdForSequenceClassification,
FlaxBigBirdForTokenClassification,
FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel,
)
else:
import importlib


@ -51,9 +51,9 @@ class BigBirdConfig(PretrainedConfig):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_fast"`):
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_new"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"gelu_fast"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
@ -107,7 +107,7 @@ class BigBirdConfig(PretrainedConfig):
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu_fast",
hidden_act="gelu_new",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=4096,
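
With the default switched to "gelu_new" (presumably because the shared Flax activation map covers "gelu_new" but not "gelu_fast"), a freshly created config picks up the new activation automatically; a quick check, assuming nothing overrides the default:

    from transformers import BigBirdConfig

    config = BigBirdConfig()
    print(config.hidden_act)  # "gelu_new"

    # an explicit value still wins over the default
    config = BigBirdConfig(hidden_act="relu")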


@ -43,7 +43,7 @@ from ...modeling_outputs import (
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, SequenceSummary, apply_chunking_to_forward
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...utils import logging
from .configuration_big_bird import BigBirdConfig
@ -2309,7 +2309,6 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
@ -2709,7 +2708,7 @@ class BigBirdForMultipleChoice(BigBirdPreTrainedModel):
super().__init__(config)
self.bert = BigBirdModel(config)
self.sequence_summary = SequenceSummary(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.init_weights()
@ -2767,9 +2766,9 @@ class BigBirdForMultipleChoice(BigBirdPreTrainedModel):
return_dict=return_dict,
)
sequence_output = outputs[0]
pooled_output = outputs[1]
pooled_output = self.sequence_summary(sequence_output)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
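
The multiple-choice head now follows the common BERT-style recipe: inputs of shape (batch, num_choices, seq_len) are flattened before the encoder, the pooled output is dropped out and projected to one logit per choice, and the logits are folded back to (batch, num_choices). A shape-only sketch of that round trip (tensor names and sizes illustrative):

    import torch

    batch, num_choices, seq_len, hidden = 2, 4, 16, 32
    input_ids = torch.zeros(batch, num_choices, seq_len, dtype=torch.long)

    flat_ids = input_ids.view(-1, seq_len)             # (batch * num_choices, seq_len)
    pooled = torch.randn(batch * num_choices, hidden)   # stand-in for outputs[1]
    logits = torch.nn.Linear(hidden, 1)(pooled)         # (batch * num_choices, 1)
    reshaped_logits = logits.view(-1, num_choices)      # (batch, num_choices)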

File diff suppressed because it is too large


@ -57,9 +57,9 @@ class BigBirdPegasusConfig(PretrainedConfig):
Dimension of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
Dimension of the "intermediate" (often named feed-forward) layer in encoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_fast"`):
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_new"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"`, :obj:`"gelu_fast"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
@ -127,7 +127,7 @@ class BigBirdPegasusConfig(PretrainedConfig):
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="gelu_fast",
activation_function="gelu_new",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,


@ -190,7 +190,8 @@ class FlaxElectraSelfAttention(nn.Module):
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
: {self.config.num_attention_heads}"
)
self.query = nn.Dense(


@ -179,7 +179,8 @@ class FlaxRobertaSelfAttention(nn.Module):
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
: {self.config.num_attention_heads}"
)
self.query = nn.Dense(


@ -258,6 +258,74 @@ class FlaxBertPreTrainedModel:
requires_backends(cls, ["flax"])
class FlaxBigBirdForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBigBirdForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBigBirdForPreTraining:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBigBirdForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBigBirdForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBigBirdForTokenClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBigBirdModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBigBirdPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxCLIPModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
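
These dummies keep "from transformers import FlaxBigBirdModel" importable in an environment without JAX/Flax; only actually using the class fails, with an error pointing at the missing backend. Roughly (error text paraphrased):

    # in an environment without jax/flax installed, the dummy class is imported instead
    from transformers import FlaxBigBirdModel

    model = FlaxBigBirdModel()  # raises: FlaxBigBirdModel requires the "flax" backend to be installed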


@ -0,0 +1,164 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import BigBirdConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
if is_flax_available():
from transformers.models.big_bird.modeling_flax_big_bird import (
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
FlaxBigBirdForQuestionAnswering,
FlaxBigBirdForSequenceClassification,
FlaxBigBirdForTokenClassification,
FlaxBigBirdModel,
)
class FlaxBigBirdModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=13,
seq_length=56,
is_training=True,
use_attention_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu_new",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
attention_type="block_sparse",
use_bias=True,
rescale_embeddings=False,
block_size=4,
num_random_blocks=3,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
self.rescale_embeddings = rescale_embeddings
self.attention_type = attention_type
self.use_bias = use_bias
self.block_size = block_size
self.num_random_blocks = num_random_blocks
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
config = BigBirdConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
attention_type=self.attention_type,
block_size=self.block_size,
num_random_blocks=self.num_random_blocks,
use_bias=self.use_bias,
rescale_embeddings=self.rescale_embeddings,
)
return config, input_ids, token_type_ids, attention_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, token_type_ids, attention_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict
@require_flax
class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxBigBirdModel,
FlaxBigBirdForPreTraining,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForQuestionAnswering,
FlaxBigBirdForSequenceClassification,
FlaxBigBirdForTokenClassification,
)
if is_flax_available()
else ()
)
test_attn_probs = False
def setUp(self):
self.model_tester = FlaxBigBirdModelTester(self)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("google/bigbird-roberta-base", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
def test_attention_outputs(self):
if self.test_attn_probs:
super().test_attention_outputs()
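
Assuming the suite lands at tests/test_modeling_flax_big_bird.py like the other Flax test files, it can be run directly; the checkpoint round-trip test is gated behind the usual slow flag:

    RUN_SLOW=1 python -m pytest -sv tests/test_modeling_flax_big_bird.py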


@ -342,6 +342,7 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_hidden_states_output(self):
pass
@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()


@ -23,7 +23,7 @@ import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
if is_flax_available():
@ -273,6 +273,7 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
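
Marking the jit tests slow matches how they behave: jax.jit has to trace the model and compile it with XLA on the first call, which dominates test time. A minimal sketch of what the test exercises (model and inputs carried over from the earlier usage sketch, illustrative only):

    import jax

    @jax.jit
    def encode(input_ids, attention_mask):
        return model(input_ids, attention_mask=attention_mask).last_hidden_state

    # first call: trace + XLA compile (slow); later calls with the same shapes reuse it
    hidden = encode(inputs["input_ids"], inputs["attention_mask"])
    hidden.block_until_ready()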


@ -179,6 +179,7 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertListEqual(arg_names[:1], expected_arg_names)
# We need to override this test because ViT expects pixel_values instead of input_ids
@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()