Use higher value for hidden_size in Flax BigBird test (#17822)

* Use higher value for hidden_size in Flax BigBird test

* Remove the relaxed `5e-5` tolerance in `check_pt_flax_outputs` (back to `tol=1e-5`)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-06-24 19:31:30 +02:00 committed by GitHub
parent 2ef94ee039
commit 0e0f1f4692
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 5 deletions

View File

@ -597,8 +597,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
# overwrite from common in order to skip the check on `attentions`
# also use `5e-5` to avoid flaky test failure
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in the PyTorch version,
# an effort was made to return `attention_probs` (yet to be verified).
if name.startswith("outputs.attentions"):

View File

@ -47,7 +47,7 @@ class FlaxBigBirdModelTester(unittest.TestCase):
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=4,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=7,
@ -214,8 +214,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(jitted_output.shape, output.shape)
# overwrite from common in order to skip the check on `attentions`
# also use `5e-5` to avoid flaky test failure
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in the PyTorch version,
# an effort was made to return `attention_probs` (yet to be verified).
if name.startswith("outputs.attentions"):