[Flax/JAX] Run jitted tests at every commit (#13090)

* up

* up

* up
This commit is contained in:
Patrick von Platen 2021-08-12 14:49:46 +02:00 committed by GitHub
parent 773d386041
commit 6900dded49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 27 additions and 5 deletions

View File

@ -23,6 +23,7 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if is_flax_available(): if is_flax_available():
import jax
from transformers.models.big_bird.modeling_flax_big_bird import ( from transformers.models.big_bird.modeling_flax_big_bird import (
FlaxBigBirdForMaskedLM, FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice, FlaxBigBirdForMultipleChoice,
@ -162,3 +163,29 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_attention_outputs(self): def test_attention_outputs(self):
if self.test_attn_probs: if self.test_attn_probs:
super().test_attention_outputs() super().test_attention_outputs()
@slow
# copied from `test_modeling_flax_common` because it takes much longer than other models
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(input_ids, attention_mask=None, **kwargs):
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)

View File

@ -378,7 +378,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_hidden_states_output(self): def test_hidden_states_output(self):
pass pass
@slow
def test_jit_compilation(self): def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -34,7 +34,6 @@ from transformers.testing_utils import (
is_pt_flax_cross_test, is_pt_flax_cross_test,
is_staging_test, is_staging_test,
require_flax, require_flax,
slow,
) )
from transformers.utils import logging from transformers.utils import logging
@ -391,7 +390,6 @@ class FlaxModelTesterMixin:
max_diff = (base_params[key] - base_params_from_head[key]).sum().item() max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@slow
def test_jit_compilation(self): def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -179,7 +179,6 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertListEqual(arg_names[:1], expected_arg_names) self.assertListEqual(arg_names[:1], expected_arg_names)
# We neeed to override this test because ViT expects pixel_values instead of input_ids # We neeed to override this test because ViT expects pixel_values instead of input_ids
@slow
def test_jit_compilation(self): def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -187,7 +187,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
expected_arg_names = ["input_values", "attention_mask"] expected_arg_names = ["input_values", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names) self.assertListEqual(arg_names[:2], expected_arg_names)
@slow
# overwrite because of `input_values` # overwrite because of `input_values`
def test_jit_compilation(self): def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()