mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Use 5e-5 For BigBird PT/Flax equivalence tests (#17780)
* rename to check_pt_flax_outputs
* update check_pt_flax_outputs
* use 5e-5 for BigBird PT/Flax test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
6a5272b205
commit
f47afefb21
@ -597,13 +597,14 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
|
||||
|
||||
# overwrite from common in order to skip the check on `attentions`
|
||||
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||
# also use `5e-5` to avoid flaky test failure
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
||||
# an effort was done to return `attention_probs` (yet to be verified).
|
||||
if type(names) == str and names.startswith("attentions"):
|
||||
if name.startswith("outputs.attentions"):
|
||||
return
|
||||
else:
|
||||
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
|
||||
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
@ -214,10 +214,11 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
# overwrite from common in order to skip the check on `attentions`
|
||||
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||
# also use `5e-5` to avoid flaky test failure
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
||||
# an effort was done to return `attention_probs` (yet to be verified).
|
||||
if type(names) == str and names.startswith("attentions"):
|
||||
if name.startswith("outputs.attentions"):
|
||||
return
|
||||
else:
|
||||
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
|
||||
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
@ -1648,7 +1648,7 @@ class ModelTesterMixin:
|
||||
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
|
||||
|
||||
# convert to the case of `tuple`
|
||||
# appending each key to the current (string) `names`
|
||||
# appending each key to the current (string) `name`
|
||||
attributes = tuple([f"{name}.{k}" for k in tf_keys])
|
||||
self.check_pt_tf_outputs(
|
||||
tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
|
||||
@ -1664,10 +1664,10 @@ class ModelTesterMixin:
|
||||
self.assertEqual(
|
||||
len(attributes),
|
||||
len(tf_outputs),
|
||||
f"{name}: The tuple `names` should have the same length as `tf_outputs`",
|
||||
f"{name}: The tuple `attributes` should have the same length as `tf_outputs`",
|
||||
)
|
||||
else:
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `names`
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
|
||||
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
|
||||
|
||||
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
|
||||
@ -1699,10 +1699,10 @@ class ModelTesterMixin:
|
||||
tf_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
|
||||
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
|
||||
self.assertLessEqual(max_diff, tol, f"{name}: Difference between PyTorch and TF is {max_diff} (>= {tol}).")
|
||||
else:
|
||||
raise ValueError(
|
||||
"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got"
|
||||
"`tf_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `tf.Tensor`. Got"
|
||||
f" {type(tf_outputs)} instead."
|
||||
)
|
||||
|
||||
@ -1838,7 +1838,7 @@ class ModelTesterMixin:
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, ..., etc.
|
||||
@ -1848,24 +1848,71 @@ class ModelTesterMixin:
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
|
||||
"""
|
||||
if type(fx_outputs) in [tuple, list]:
|
||||
self.assertEqual(type(fx_outputs), type(pt_outputs))
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs))
|
||||
if type(names) == tuple:
|
||||
for fo, po, name in zip(fx_outputs, pt_outputs, names):
|
||||
self.check_outputs(fo, po, model_class, names=name)
|
||||
elif type(names) == str:
|
||||
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
|
||||
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
|
||||
|
||||
self.assertEqual(type(name), str)
|
||||
if attributes is not None:
|
||||
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
|
||||
|
||||
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
|
||||
if isinstance(fx_outputs, ModelOutput):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, ModelOutput),
|
||||
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
|
||||
)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
|
||||
|
||||
# convert to the case of `tuple`
|
||||
# appending each key to the current (string) `name`
|
||||
attributes = tuple([f"{name}.{k}" for k in fx_keys])
|
||||
self.check_pt_flax_outputs(
|
||||
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
|
||||
)
|
||||
|
||||
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
|
||||
elif type(fx_outputs) in [tuple, list]:
|
||||
self.assertEqual(
|
||||
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
|
||||
)
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
if attributes is not None:
|
||||
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
|
||||
self.assertEqual(
|
||||
len(attributes),
|
||||
len(fx_outputs),
|
||||
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
|
||||
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
|
||||
|
||||
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
|
||||
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
|
||||
|
||||
elif isinstance(fx_outputs, jnp.ndarray):
|
||||
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
|
||||
)
|
||||
|
||||
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
|
||||
fx_outputs = np.array(fx_outputs)
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
self.assertEqual(
|
||||
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
# deal with NumPy's scalars to make replacing nan values by 0 work.
|
||||
if np.isscalar(fx_outputs):
|
||||
fx_outputs = np.array([fx_outputs])
|
||||
pt_outputs = np.array([pt_outputs])
|
||||
|
||||
fx_nans = np.isnan(fx_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
@ -1874,10 +1921,14 @@ class ModelTesterMixin:
|
||||
pt_outputs[pt_nans] = 0
|
||||
fx_outputs[pt_nans] = 0
|
||||
|
||||
self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
|
||||
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
|
||||
self.assertLessEqual(
|
||||
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
|
||||
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
|
||||
f" {type(fx_outputs)} instead."
|
||||
)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
@ -1938,7 +1989,7 @@ class ModelTesterMixin:
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
@ -1950,7 +2001,7 @@ class ModelTesterMixin:
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
@ -2012,7 +2063,7 @@ class ModelTesterMixin:
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
@ -2029,7 +2080,7 @@ class ModelTesterMixin:
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
@ -36,6 +36,7 @@ from transformers.testing_utils import (
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@ -169,8 +170,8 @@ class FlaxModelTesterMixin:
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
|
||||
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||
# (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs)
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, ..., etc.
|
||||
@ -180,24 +181,71 @@ class FlaxModelTesterMixin:
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
|
||||
"""
|
||||
if type(fx_outputs) in [tuple, list]:
|
||||
self.assertEqual(type(fx_outputs), type(pt_outputs))
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs))
|
||||
if type(names) == tuple:
|
||||
for fo, po, name in zip(fx_outputs, pt_outputs, names):
|
||||
self.check_outputs(fo, po, model_class, names=name)
|
||||
elif type(names) == str:
|
||||
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
|
||||
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
|
||||
|
||||
self.assertEqual(type(name), str)
|
||||
if attributes is not None:
|
||||
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
|
||||
|
||||
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
|
||||
if isinstance(fx_outputs, ModelOutput):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, ModelOutput),
|
||||
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `fx_outputs` is",
|
||||
)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys, f"{name}: Output keys differ between Flax and PyTorch")
|
||||
|
||||
# convert to the case of `tuple`
|
||||
# appending each key to the current (string) `name`
|
||||
attributes = tuple([f"{name}.{k}" for k in fx_keys])
|
||||
self.check_pt_flax_outputs(
|
||||
fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
|
||||
)
|
||||
|
||||
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
|
||||
elif type(fx_outputs) in [tuple, list]:
|
||||
self.assertEqual(
|
||||
type(fx_outputs), type(pt_outputs), f"{name}: Output types differ between Flax and PyTorch"
|
||||
)
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs), f"{name}: Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
if attributes is not None:
|
||||
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
|
||||
self.assertEqual(
|
||||
len(attributes),
|
||||
len(fx_outputs),
|
||||
f"{name}: The tuple `attributes` should have the same length as `fx_outputs`",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `name`
|
||||
attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))])
|
||||
|
||||
for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes):
|
||||
self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr)
|
||||
|
||||
elif isinstance(fx_outputs, jnp.ndarray):
|
||||
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `fx_outputs` is"
|
||||
)
|
||||
|
||||
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
|
||||
fx_outputs = np.array(fx_outputs)
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
self.assertEqual(
|
||||
fx_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between Flax and PyTorch"
|
||||
)
|
||||
|
||||
# deal with NumPy's scalars to make replacing nan values by 0 work.
|
||||
if np.isscalar(fx_outputs):
|
||||
fx_outputs = np.array([fx_outputs])
|
||||
pt_outputs = np.array([pt_outputs])
|
||||
|
||||
fx_nans = np.isnan(fx_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
@ -206,10 +254,14 @@ class FlaxModelTesterMixin:
|
||||
pt_outputs[pt_nans] = 0
|
||||
fx_outputs[pt_nans] = 0
|
||||
|
||||
self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
|
||||
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
|
||||
self.assertLessEqual(
|
||||
max_diff, tol, f"{name}: Difference between PyTorch and Flax is {max_diff} (>= {tol})."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
|
||||
"`fx_outputs` should be an instance of `ModelOutput`, a `tuple`, or an instance of `jnp.ndarray`. Got"
|
||||
f" {type(fx_outputs)} instead."
|
||||
)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
@ -253,7 +305,7 @@ class FlaxModelTesterMixin:
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
@ -265,7 +317,7 @@ class FlaxModelTesterMixin:
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
@ -308,7 +360,7 @@ class FlaxModelTesterMixin:
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
@ -325,7 +377,7 @@ class FlaxModelTesterMixin:
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
|
||||
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
Loading…
Reference in New Issue
Block a user