mirror of https://github.com/huggingface/transformers.git
parent 773d386041
commit 6900dded49
@@ -23,6 +23,7 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if is_flax_available():
    import jax
    from transformers.models.big_bird.modeling_flax_big_bird import (
        FlaxBigBirdForMaskedLM,
        FlaxBigBirdForMultipleChoice,
@@ -162,3 +163,29 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
    def test_attention_outputs(self):
        if self.test_attn_probs:
            super().test_attention_outputs()

    @slow
    # copied from `test_modeling_flax_common` because it takes much longer than other models
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_ids, attention_mask=None, **kwargs):
                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)
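
For context, the check this test performs is a generic JAX pattern: call the same traced function once with JIT enabled and once under jax.disable_jit(), then compare the structure and shapes of the outputs. Below is a minimal, self-contained sketch of that pattern; toy_model is a hypothetical stand-in for the Flax model call, not code from this commit.

import jax
import jax.numpy as jnp

@jax.jit
def toy_model(x):
    # two outputs, mimicking a model that returns a tuple
    return x * 2.0, x.sum(axis=-1)

x = jnp.ones((2, 4))
jitted_outputs = toy_model(x)        # compiled path
with jax.disable_jit():
    eager_outputs = toy_model(x)     # same Python code, no compilation

assert len(eager_outputs) == len(jitted_outputs)
for jitted, eager in zip(jitted_outputs, eager_outputs):
    assert jitted.shape == eager.shape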
@@ -378,7 +378,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
    def test_hidden_states_output(self):
        pass

    @slow
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -34,7 +34,6 @@ from transformers.testing_utils import (
    is_pt_flax_cross_test,
    is_staging_test,
    require_flax,
    slow,
)
from transformers.utils import logging
@@ -391,7 +390,6 @@ class FlaxModelTesterMixin:
        max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
        self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    @slow
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
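
The @slow marker shown here comes from transformers.testing_utils (imported in the hunk above) and gates a test behind an environment variable so it only runs in the slow/nightly CI job. A minimal sketch of how such a marker is commonly built with unittest; the exact parsing logic in transformers may differ:

import os
import unittest

# treat RUN_SLOW=1 (or "true"/"yes") as opting in to slow tests
_run_slow_tests = os.environ.get("RUN_SLOW", "0").lower() in ("1", "true", "yes")
slow = unittest.skipUnless(_run_slow_tests, "test is slow")

With a gate like this, RUN_SLOW=1 python -m pytest tests/test_modeling_flax_common.py runs the jit-compilation tests, while a plain pytest invocation skips them.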
@@ -179,7 +179,6 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
        self.assertListEqual(arg_names[:1], expected_arg_names)

    # We need to override this test because ViT expects pixel_values instead of input_ids
    @slow
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
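
Because the common test's jitted wrapper is written around input_ids, models with a different leading input override it. A hedged sketch of the only part that changes for ViT; model stands in for the instantiated Flax model from the surrounding test body, exactly as in the BigBird hunk above:

                @jax.jit
                def model_jitted(pixel_values, **kwargs):
                    # same body as the common test, but keyed on pixel_values
                    return model(pixel_values=pixel_values, **kwargs)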
@@ -187,7 +187,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
        expected_arg_names = ["input_values", "attention_mask"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    @slow
    # overwrite because of `input_values`
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
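
The arg-name assertion above is how these tests pin down a model's expected forward signature. A hedged sketch of how leading argument names can be read off with inspect, assuming the model exposes its signature on __call__ as Flax models typically do:

import inspect

def leading_arg_names(fn, n):
    # names of the first n parameters, in declaration order
    return list(inspect.signature(fn).parameters)[:n]

# e.g., for a Flax Wav2Vec2 model:
# leading_arg_names(model.__call__, 2) == ["input_values", "attention_mask"]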