diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index d40df383f96..6f1dbedd2fb 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -181,7 +181,7 @@ class FlaxModelTesterMixin: fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) @@ -192,10 +192,7 @@ class FlaxModelTesterMixin: len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" ) for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - if not isinstance( - fx_output_loaded, tuple - ): # TODO(Patrick, Daniel) - let's discard use_cache for now - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) @is_pt_flax_cross_test def test_equivalence_flax_to_pt(self): @@ -229,7 +226,7 @@ class FlaxModelTesterMixin: self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) @@ -242,8 +239,7 @@ class FlaxModelTesterMixin: len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" ) for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): - if not isinstance(fx_output, tuple): # TODO(Patrick, Daniel) - let's discard use_cache for now - self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) def test_from_pretrained_save_pretrained(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()