mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 11:41:51 +06:00
Temporary tolerance fix for flaky whipser PT-TF equiv. test (#23257)
* Temp tol fix for flaky whipser test * Add equivalent update to TF tests
This commit is contained in:
parent
ca26699f37
commit
f82ee109e6
@ -400,6 +400,10 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as test recently became flaky
|
||||
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
@ -824,6 +824,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as test recently became flaky
|
||||
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
Loading…
Reference in New Issue
Block a user