diff --git a/tests/test_modeling_flax_big_bird.py b/tests/test_modeling_flax_big_bird.py
index d95a2df278d..c8afce26858 100644
--- a/tests/test_modeling_flax_big_bird.py
+++ b/tests/test_modeling_flax_big_bird.py
@@ -23,6 +23,7 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
 
 
 if is_flax_available():
+    import jax
     from transformers.models.big_bird.modeling_flax_big_bird import (
         FlaxBigBirdForMaskedLM,
         FlaxBigBirdForMultipleChoice,
@@ -162,3 +163,29 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
     def test_attention_outputs(self):
         if self.test_attn_probs:
             super().test_attention_outputs()
+
+    @slow
+    # copied from `test_modeling_flax_common` because it takes much longer than other models
+    def test_jit_compilation(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+                model = model_class(config)
+
+                @jax.jit
+                def model_jitted(input_ids, attention_mask=None, **kwargs):
+                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
+
+                with self.subTest("JIT Enabled"):
+                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
+
+                with self.subTest("JIT Disabled"):
+                    with jax.disable_jit():
+                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()
+
+                self.assertEqual(len(outputs), len(jitted_outputs))
+                for jitted_output, output in zip(jitted_outputs, outputs):
+
+                    self.assertEqual(jitted_output.shape, output.shape)
diff --git a/tests/test_modeling_flax_clip.py b/tests/test_modeling_flax_clip.py
index 80363505308..cab4b7b53d1 100644
--- a/tests/test_modeling_flax_clip.py
+++ b/tests/test_modeling_flax_clip.py
@@ -378,7 +378,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
     def test_hidden_states_output(self):
         pass
 
-    @slow
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 03139ac73a7..612cb5d2aa9 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -34,7 +34,6 @@ from transformers.testing_utils import (
     is_pt_flax_cross_test,
     is_staging_test,
     require_flax,
-    slow,
 )
 from transformers.utils import logging
 
@@ -391,7 +390,6 @@ class FlaxModelTesterMixin:
             max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
             self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
-    @slow
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
diff --git a/tests/test_modeling_flax_vit.py b/tests/test_modeling_flax_vit.py
index f745d7c7ffa..276777e0009 100644
--- a/tests/test_modeling_flax_vit.py
+++ b/tests/test_modeling_flax_vit.py
@@ -179,7 +179,6 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
         self.assertListEqual(arg_names[:1], expected_arg_names)
 
     # We neeed to override this test because ViT expects pixel_values instead of input_ids
-    @slow
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
diff --git a/tests/test_modeling_flax_wav2vec2.py b/tests/test_modeling_flax_wav2vec2.py
index 66dd1e0611d..ce83d77f9f3 100644
--- a/tests/test_modeling_flax_wav2vec2.py
+++ b/tests/test_modeling_flax_wav2vec2.py
@@ -187,7 +187,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
         expected_arg_names = ["input_values", "attention_mask"]
         self.assertListEqual(arg_names[:2], expected_arg_names)
 
-    @slow
     # overwrite because of `input_values`
     def test_jit_compilation(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()