mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
parent
773d386041
commit
6900dded49
@ -23,6 +23,7 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
|
|||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
|
import jax
|
||||||
from transformers.models.big_bird.modeling_flax_big_bird import (
|
from transformers.models.big_bird.modeling_flax_big_bird import (
|
||||||
FlaxBigBirdForMaskedLM,
|
FlaxBigBirdForMaskedLM,
|
||||||
FlaxBigBirdForMultipleChoice,
|
FlaxBigBirdForMultipleChoice,
|
||||||
@ -162,3 +163,29 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
def test_attention_outputs(self):
|
def test_attention_outputs(self):
|
||||||
if self.test_attn_probs:
|
if self.test_attn_probs:
|
||||||
super().test_attention_outputs()
|
super().test_attention_outputs()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
# copied from `test_modeling_flax_common` because it takes much longer than other models
|
||||||
|
def test_jit_compilation(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def model_jitted(input_ids, attention_mask=None, **kwargs):
|
||||||
|
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||||
|
|
||||||
|
with self.subTest("JIT Enabled"):
|
||||||
|
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
|
with self.subTest("JIT Disabled"):
|
||||||
|
with jax.disable_jit():
|
||||||
|
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
|
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||||
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
|
|
||||||
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
@ -378,7 +378,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
@ -34,7 +34,6 @@ from transformers.testing_utils import (
|
|||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_flax,
|
require_flax,
|
||||||
slow,
|
|
||||||
)
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
@ -391,7 +390,6 @@ class FlaxModelTesterMixin:
|
|||||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
@ -179,7 +179,6 @@ class FlaxViTModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
# We neeed to override this test because ViT expects pixel_values instead of input_ids
|
# We neeed to override this test because ViT expects pixel_values instead of input_ids
|
||||||
@slow
|
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
@ -187,7 +187,6 @@ class FlaxWav2Vec2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
expected_arg_names = ["input_values", "attention_mask"]
|
expected_arg_names = ["input_values", "attention_mask"]
|
||||||
self.assertListEqual(arg_names[:2], expected_arg_names)
|
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||||
|
|
||||||
@slow
|
|
||||||
# overwrite because of `input_values`
|
# overwrite because of `input_values`
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
Loading…
Reference in New Issue
Block a user