Make TF pt-tf equivalence test more aggressive (#15839)

* Make TF pt-tf equivalence test more aggressive

* Fix for TFConvNextModelTest and TFTransfoXLModelTest

* fix kwargs for outputs

* clean-up

* Add docstring for check_outputs()

* remove: need to rename encoder-decoder

* clean-up

* send PyTorch things to the correct device

* Add back the accidentally removed test case in test_pt_tf_model_equivalence()

* Fix: change to tuple before calling check_outputs()

* Fix: tfo could be a list

* use to_tuple()

* allow tfo only to be tuple or tensor

* allow tfo to be list or tuple for now + style change

* minor fix

* remove np.copy and update comments

* tfo -> tf_output, same for pt

* Add more detailed comment

* remove the incorrect comment

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
commit 923c35b5c5 (parent 9e9f6b8a45)
Yih-Dar, 2022-03-14 13:31:32 +01:00, committed by GitHub

@@ -39,10 +39,14 @@ from transformers.testing_utils import (
require_tf,
require_tf2onnx,
slow,
torch_device,
)
from transformers.utils import logging
logger = logging.get_logger(__name__)
if is_tf_available():
import numpy as np
import tensorflow as tf
@@ -348,27 +352,10 @@ class TFModelTesterMixin:
import transformers
-config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
-for model_class in self.all_model_classes:
-pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
-pt_model_class = getattr(transformers, pt_model_class_name)
-config.output_hidden_states = True
-tf_model = model_class(config)
-pt_model = pt_model_class(config)
-# Check we can load pt model in tf and vice-versa with model => model functions
-tf_model = transformers.load_pytorch_model_in_tf2_model(
-tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
-)
-pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
-# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
-pt_model.eval()
pt_inputs_dict = {}
-for name, key in self._prepare_for_class(inputs_dict, model_class).items():
for name, key in tf_inputs_dict.items():
if type(key) == bool:
pt_inputs_dict[name] = key
elif name == "input_values":
@@ -380,23 +367,217 @@ class TFModelTesterMixin:
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
return pt_inputs_dict
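# A minimal usage sketch of the helper above (commented out, with made-up values):
# boolean flags pass through unchanged, floating inputs stay `float32`, and
# index-like inputs become `torch.long` tensors.
# pt_inputs = prepare_pt_inputs_from_tf_inputs(
# {"input_ids": tf.constant([[0, 1, 2]]), "return_dict": True}
# )
# # -> {"input_ids": tensor of dtype torch.long, "return_dict": True}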
def check_outputs(tf_outputs, pt_outputs, model_class, names):
"""
Args:
model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
`TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
debugging easier and faster.
names: A string, or a tuple of strings. These specify what `tf_outputs`/`pt_outputs` represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
"""
# There is an issue about `past_key_values` to solve (e.g. for `TFPegasusForConditionalGeneration`); it will be handled in a separate PR.
if names == "past_key_values":
return
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
if type(tf_outputs) in [tuple, list]:
self.assertEqual(type(tf_outputs), type(pt_outputs))
self.assertEqual(len(tf_outputs), len(pt_outputs))
if type(names) == tuple:
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
check_outputs(tf_output, pt_output, model_class, names=name)
elif type(names) == str:
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(tf_outputs, tf.Tensor):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
tf_outputs = tf_outputs.numpy()
pt_outputs = pt_outputs.detach().to("cpu").numpy()
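# Zero out positions where either framework produced NaN, so a NaN on one side
# does not dominate the max-diff comparison below.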
tf_nans = np.isnan(tf_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[tf_nans] = 0
tf_outputs[tf_nans] = 0
pt_outputs[pt_nans] = 0
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, 1e-5)
else:
raise ValueError(
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
)
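# Commented-out sketch of the recursion above (hypothetical tensors): a tuple of
# names fans out pairwise, while a single name gets an index suffix as we descend
# into nested tuples/lists.
# check_outputs(
# (tf_logits, (tf_h0, tf_h1)), (pt_logits, (pt_h0, pt_h1)),
# model_class, names=("logits", "hidden_states"),
# )
# # compares the logits under "logits", then the nested pair under
# # "hidden_states_0" and "hidden_states_1".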
def check_pt_tf_models(tf_model, pt_model):
# send pytorch model to the correct device
pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
pt_inputs_dict_maybe_with_labels = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict_maybe_with_labels)
# send pytorch inputs to the correct device
pt_inputs_dict = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
}
pt_inputs_dict_maybe_with_labels = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v
for k, v in pt_inputs_dict_maybe_with_labels.items()
}
# Original test: check without `labels`
with torch.no_grad():
-pto = pt_model(**pt_inputs_dict)
-tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
pt_outputs = pt_model(**pt_inputs_dict)
tf_outputs = tf_model(tf_inputs_dict)
-tf_hidden_states = tfo[0].numpy()
-pt_hidden_states = pto[0].numpy()
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
-tf_nans = np.copy(np.isnan(tf_hidden_states))
-pt_nans = np.copy(np.isnan(pt_hidden_states))
self.assertEqual(tf_keys, pt_keys)
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
-pt_hidden_states[tf_nans] = 0
-tf_hidden_states[tf_nans] = 0
-pt_hidden_states[pt_nans] = 0
-tf_hidden_states[pt_nans] = 0
# check the case where `labels` is passed
has_labels = any(
x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
)
if has_labels:
-max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
-self.assertLessEqual(max_diff, 4e-2)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs_dict_maybe_with_labels)
tf_outputs = tf_model(tf_inputs_dict_maybe_with_labels)
# Some models' output classes don't have a `loss` attribute even though `labels` is passed.
# TODO: identify which models
tf_loss = getattr(tf_outputs, "loss", None)
pt_loss = getattr(pt_outputs, "loss", None)
# Some PT models return loss while the corresponding TF models don't (i.e. `None` for `loss`).
# - TFFlaubertWithLMHeadModel
# - TFFunnelForPreTraining
# - TFElectraForPreTraining
# - TFXLMWithLMHeadModel
# TODO: Fix PT/TF diff -> remove this condition to fail the test if a diff occurs
if not ((tf_loss is None and pt_loss is None) or (tf_loss is not None and pt_loss is not None)):
if model_class.__name__ not in [
"TFFlaubertWithLMHeadModel",
"TFFunnelForPreTraining",
"TFElectraForPreTraining",
"TFXLMWithLMHeadModel",
]:
self.assertEqual(tf_loss is None, pt_loss is None)
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
# TODO: remove these 2 conditions once the above TODOs (about the loss) are implemented
# (Also, `TFTransfoXLLMHeadModel` has no `loss` while `TransfoXLLMHeadModel` returns `losses`)
if tf_keys != pt_keys:
if model_class.__name__ not in [
"TFFlaubertWithLMHeadModel",
"TFFunnelForPreTraining",
"TFElectraForPreTraining",
"TFXLMWithLMHeadModel",
] + ["TFTransfoXLLMHeadModel"]:
self.assertEqual(tf_keys, pt_keys)
# Since we deliberately let some checks pass above (regarding the `loss`), let's still try to test
# some remaining attributes in the outputs.
# TODO: remove this block of `index` computing once the above TODOs (about the loss) are implemented
# compute the first `index` at which `tf_keys` and `pt_keys` differ
index = 0
for _ in range(min(len(tf_keys), len(pt_keys))):
if tf_keys[index] == pt_keys[index]:
index += 1
else:
break
if tf_keys[:index] != pt_keys[:index]:
self.assertEqual(tf_keys, pt_keys)
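# Worked example for the `index` computation above (hypothetical keys): if both
# sides return ("loss", "logits", "hidden_states"), the loop runs to index = 3,
# so `tf_outputs[1:index]` below compares everything after `loss`, while `loss`
# itself gets the separate scalar comparison.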
# Some models require extra conditions to return a loss. For example, `(TF)BertForPreTraining` requires
# both `labels` and `next_sentence_label`.
if tf_loss is not None and pt_loss is not None:
# check everything other than `loss`
keys = tuple([k for k in tf_keys])
check_outputs(tf_outputs[1:index], pt_outputs[1:index], model_class, names=keys[1:index])
# check `loss`
# The loss returned by TF models is usually a tensor rather than a scalar
# (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`).
# Change it here to a scalar to match PyTorch models' loss.
tf_loss = tf.math.reduce_mean(tf_loss).numpy()
pt_loss = pt_loss.detach().to("cpu").numpy()
tf_nans = np.isnan(tf_loss)
pt_nans = np.isnan(pt_loss)
# the two losses must be both NaN or both non-NaN
self.assertEqual(tf_nans, pt_nans)
if not tf_nans:
max_diff = np.amax(np.abs(tf_loss - pt_loss))
self.assertLessEqual(max_diff, 1e-5)
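# Commented-out sketch of the loss alignment above (made-up numbers): a TF head
# using `tf.keras.losses.Reduction.NONE` returns a per-sample loss, so it is
# averaged before being compared to the scalar PT loss.
# tf_loss = tf.math.reduce_mean(tf.constant([0.7, 0.9])).numpy() # -> 0.8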
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Output all for aggressive testing
config.output_hidden_states = True
# Pure convolutional models have no attention
# TODO: use a better, more general criterion
if "TFConvNext" not in model_class.__name__:
config.output_attentions = True
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
if k in inputs_dict:
attention_mask = inputs_dict[k]
# make sure there are no all-zero attention masks, to avoid failures for now.
# TODO: remove this line once the TODO below is implemented.
attention_mask = tf.ones_like(attention_mask, dtype=tf.int32)
# The commented-out block below would make the first sequence's attention mask all zeros.
# Currently, this fails for `TFWav2Vec2Model`. It is caused by the different large negative
# mask values, like `-1e4`, `-1e9`, `-1e30` and `-inf`, used across models/frameworks.
# TODO: enable this block once the large negative values thing is cleaned up.
# (see https://github.com/huggingface/transformers/issues/14859)
# attention_mask = tf.concat(
# [
# tf.zeros_like(attention_mask[:1], dtype=tf.int32),
# tf.cast(attention_mask[1:], dtype=tf.int32)
# ],
# axis=0
# )
inputs_dict[k] = attention_mask
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
config.output_hidden_states = True
tf_model = model_class(config)
pt_model = pt_model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
tf_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
check_pt_tf_models(tf_model, pt_model)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -408,37 +589,7 @@ class TFModelTesterMixin:
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
-# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
-pt_model.eval()
-pt_inputs_dict = {}
-for name, key in self._prepare_for_class(inputs_dict, model_class).items():
-if type(key) == bool:
-key = np.array(key, dtype=bool)
-pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
-elif name == "input_values":
-pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-elif name == "pixel_values":
-pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-elif name == "input_features":
-pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-else:
-pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
-with torch.no_grad():
-pto = pt_model(**pt_inputs_dict)
-tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
-tfo = tfo[0].numpy()
-pto = pto[0].numpy()
-tf_nans = np.copy(np.isnan(tfo))
-pt_nans = np.copy(np.isnan(pto))
-pto[tf_nans] = 0
-tfo[tf_nans] = 0
-pto[pt_nans] = 0
-tfo[pt_nans] = 0
-max_diff = np.amax(np.abs(tfo - pto))
-self.assertLessEqual(max_diff, 4e-2)
check_pt_tf_models(tf_model, pt_model)
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()