Use 5e-5 For BigBird PT/Flax equivalence tests (#17780)

* rename to check_pt_flax_outputs

* update check_pt_flax_outputs

* use 5e-5 for BigBird PT/Flax test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-06-21 17:55:26 +02:00 committed by GitHub
parent 6a5272b205
commit f47afefb21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 153 additions and 48 deletions

View File

@ -597,13 +597,14 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
# overwrite from common in order to skip the check on `attentions`
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
# also use `5e-5` to avoid flaky test failure
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
if type(names) == str and names.startswith("attentions"):
if name.startswith("outputs.attentions"):
return
else:
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
@require_torch

View File

@ -214,10 +214,11 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(jitted_output.shape, output.shape)
# overwrite from common in order to skip the check on `attentions`
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
# also use `5e-5` to avoid flaky test failure
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
# an effort was done to return `attention_probs` (yet to be verified).
if type(names) == str and names.startswith("attentions"):
if name.startswith("outputs.attentions"):
return
else:
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)

View File

@ -1648,7 +1648,7 @@ class ModelTesterMixin:
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
# convert to the case of `tuple`
# appending each key to the current (string) `names`
# appending each key to the current (string) `name`
attributes = tuple([f"{name}.{k}" for k in tf_keys])
self.check_pt_tf_outputs(
tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
@ -1664,10 +1664,10 @@ class ModelTesterMixin:
self.assertEqual(
len(attributes),
len(tf_outputs),
f"{name}: The tuple `names` should have the same length as `tf_outputs`",
f"{name}: The tuple `attributes` should have the same length as `tf_outputs`",
)
else:
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `names`
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
@ -1699,10 +1699,10 @@ class ModelTesterMixin:
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
self.assertLessEqual(max_diff, tol, f"{name}: Difference between PyTorch and TF is {max_diff} (>= {tol}).")
else:
raise ValueError(
"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
"`tf_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `tf.Tensor`. Got"
f" {type(tf_outputs)} instead."
)
@ -1838,7 +1838,7 @@ class ModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
@ -1848,24 +1848,71 @@ class ModelTesterMixin:
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
if type(fx_outputs) in [tuple, list]:
self.assertEqual(type(fx_outputs), type(pt_outputs))
self.assertEqual(len(fx_outputs), len(pt_outputs))
if type(names) == tuple:
for fo, po, name in zip(fx_outputs, pt_outputs, names):
self.check_outputs(fo, po, model_class, names=name)
elif type(names) == str:
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
if isinstance(fx_outputs, ModelOutput):
self.assertTrue(
isinstance(pt_outputs, ModelOutput),
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
# convert to the case of `tuple`
# appending each key to the current (string) `name`
attributes = tuple([f"{name}.{k}" for k in fx_keys])
self.check_pt_flax_outputs(
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
)
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
elif type(fx_outputs) in [tuple, list]:
self.assertEqual(
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
)
self.assertEqual(
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
)
if attributes is not None:
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
self.assertEqual(
len(attributes),
len(fx_outputs),
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
)
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(fx_outputs, jnp.ndarray):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
self.assertTrue(
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
)
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
self.assertEqual(
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
)
# deal with NumPy's scalars to make replacing nan values by 0 work.
if np.isscalar(fx_outputs):
fx_outputs = np.array([fx_outputs])
pt_outputs = np.array([pt_outputs])
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
@ -1874,10 +1921,14 @@ class ModelTesterMixin:
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
self.assertLessEqual(
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
)
else:
raise ValueError(
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
f" {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test
@ -1938,7 +1989,7 @@ class ModelTesterMixin:
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@ -1950,7 +2001,7 @@ class ModelTesterMixin:
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@ -2012,7 +2063,7 @@ class ModelTesterMixin:
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
@ -2029,7 +2080,7 @@ class ModelTesterMixin:
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -36,6 +36,7 @@ from transformers.testing_utils import (
torch_device,
)
from transformers.utils import logging
from transformers.utils.generic import ModelOutput
if is_flax_available():
@ -169,8 +170,8 @@ class FlaxModelTesterMixin:
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs)
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
@ -180,24 +181,71 @@ class FlaxModelTesterMixin:
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
if type(fx_outputs) in [tuple, list]:
self.assertEqual(type(fx_outputs), type(pt_outputs))
self.assertEqual(len(fx_outputs), len(pt_outputs))
if type(names) == tuple:
for fo, po, name in zip(fx_outputs, pt_outputs, names):
self.check_outputs(fo, po, model_class, names=name)
elif type(names) == str:
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
self.assertEqual(type(name), str)
if attributes is not None:
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
if isinstance(fx_outputs, ModelOutput):
self.assertTrue(
isinstance(pt_outputs, ModelOutput),
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
# convert to the case of `tuple`
# appending each key to the current (string) `name`
attributes = tuple([f"{name}.{k}" for k in fx_keys])
self.check_pt_flax_outputs(
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
)
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
elif type(fx_outputs) in [tuple, list]:
self.assertEqual(
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
)
self.assertEqual(
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
)
if attributes is not None:
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
self.assertEqual(
len(attributes),
len(fx_outputs),
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
)
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
elif isinstance(fx_outputs, jnp.ndarray):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
self.assertTrue(
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
)
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
self.assertEqual(
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
)
# deal with NumPy's scalars to make replacing nan values by 0 work.
if np.isscalar(fx_outputs):
fx_outputs = np.array([fx_outputs])
pt_outputs = np.array([pt_outputs])
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
@ -206,10 +254,14 @@ class FlaxModelTesterMixin:
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
self.assertLessEqual(
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
)
else:
raise ValueError(
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
f" {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test
@ -253,7 +305,7 @@ class FlaxModelTesterMixin:
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@ -265,7 +317,7 @@ class FlaxModelTesterMixin:
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@ -308,7 +360,7 @@ class FlaxModelTesterMixin:
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
@ -325,7 +377,7 @@ class FlaxModelTesterMixin:
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()