mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Aggressive PT/TF equivalence test on PT side (#16250)
* Aggressive PT/TF equivalence test on PT side * Ugly fix for `TFTapasForQuestionAnswering` * apply review suggestions Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
d481b6414d
commit
75c666b4a8
@ -1463,6 +1463,193 @@ class ModelTesterMixin:
|
||||
|
||||
import transformers
|
||||
|
||||
def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
|
||||
|
||||
tf_inputs_dict = {}
|
||||
for key, tensor in pt_inputs_dict.items():
|
||||
# skip key that does not exist in tf
|
||||
if type(tensor) == bool:
|
||||
tf_inputs_dict[key] = tensor
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "input_features":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
# To deal with the edge cases from `TFTapasForQuestionAnswering`.
|
||||
# PyTorch can deal with type casting automatically, but TensorFlow is more strict!
|
||||
# TODO: find a clean/better way to deal with these extra keys that are not common.
|
||||
elif key in ["float_answer", "numeric_values", "numeric_values_scale"]:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
|
||||
|
||||
return tf_inputs_dict
|
||||
|
||||
def check_outputs(tf_outputs, pt_outputs, model_class, names):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
|
||||
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
|
||||
debugging easier and faster.
|
||||
|
||||
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
|
||||
"""
|
||||
|
||||
# Some issue (`about past_key_values`) to solve (e.g. `TFPegasusForConditionalGeneration`) in a separate PR.
|
||||
if names == "past_key_values":
|
||||
return
|
||||
|
||||
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
|
||||
if type(tf_outputs) in [tuple, list]:
|
||||
self.assertEqual(type(tf_outputs), type(pt_outputs))
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs))
|
||||
if type(names) == tuple:
|
||||
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
|
||||
check_outputs(tf_output, pt_output, model_class, names=name)
|
||||
elif type(names) == str:
|
||||
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
|
||||
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
|
||||
else:
|
||||
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
|
||||
elif isinstance(tf_outputs, tf.Tensor):
|
||||
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
|
||||
|
||||
tf_outputs = tf_outputs.numpy()
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
tf_nans = np.isnan(tf_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
pt_outputs[tf_nans] = 0
|
||||
tf_outputs[tf_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
tf_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
|
||||
)
|
||||
|
||||
def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels):
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
|
||||
pt_model.eval()
|
||||
|
||||
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
tf_inputs_dict_maybe_with_labels = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict_maybe_with_labels)
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs_dict = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
|
||||
}
|
||||
pt_inputs_dict_maybe_with_labels = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in pt_inputs_dict_maybe_with_labels.items()
|
||||
}
|
||||
|
||||
# Original test: check without `labels`
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict)
|
||||
tf_outputs = tf_model(tf_inputs_dict)
|
||||
|
||||
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(tf_keys, pt_keys)
|
||||
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
|
||||
|
||||
# check the case where `labels` is passed
|
||||
has_labels = any(
|
||||
x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
|
||||
)
|
||||
if has_labels:
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict_maybe_with_labels)
|
||||
tf_outputs = tf_model(tf_inputs_dict_maybe_with_labels)
|
||||
|
||||
# Some models' output class don't have `loss` attribute despite `labels` is used.
|
||||
# TODO: identify which models
|
||||
tf_loss = getattr(tf_outputs, "loss", None)
|
||||
pt_loss = getattr(pt_outputs, "loss", None)
|
||||
|
||||
# Some PT models return loss while the corresponding TF models don't (i.e. `None` for `loss`).
|
||||
# - FlaubertWithLMHeadModel
|
||||
# - FunnelForPreTraining
|
||||
# - ElectraForPreTraining
|
||||
# - XLMWithLMHeadModel
|
||||
# TODO: Fix PT/TF diff -> remove this condition to fail the test if a diff occurs
|
||||
if not ((tf_loss is None and pt_loss is None) or (tf_loss is not None and pt_loss is not None)):
|
||||
if model_class.__name__ not in [
|
||||
"FlaubertWithLMHeadModel",
|
||||
"FunnelForPreTraining",
|
||||
"ElectraForPreTraining",
|
||||
"XLMWithLMHeadModel",
|
||||
"TransfoXLLMHeadModel",
|
||||
]:
|
||||
self.assertEqual(tf_loss is None, pt_loss is None)
|
||||
|
||||
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
# TODO: remove these 2 conditions once the above TODOs (above loss) are implemented
|
||||
# (Also, `TFTransfoXLLMHeadModel` has no `loss` while `TransfoXLLMHeadModel` return `losses`)
|
||||
if tf_keys != pt_keys:
|
||||
if model_class.__name__ not in [
|
||||
"FlaubertWithLMHeadModel",
|
||||
"FunnelForPreTraining",
|
||||
"ElectraForPreTraining",
|
||||
"XLMWithLMHeadModel",
|
||||
"TransfoXLLMHeadModel",
|
||||
]:
|
||||
self.assertEqual(tf_keys, pt_keys)
|
||||
|
||||
# Since we deliberately make some tests pass above (regarding the `loss`), let's still try to test
|
||||
# some remaining attributes in the outputs.
|
||||
# TODO: remove this block of `index` computing once the above TODOs (above loss) are implemented
|
||||
# compute the 1st `index` where `tf_keys` and `pt_keys` is different
|
||||
index = 0
|
||||
for _ in range(min(len(tf_keys), len(pt_keys))):
|
||||
if tf_keys[index] == pt_keys[index]:
|
||||
index += 1
|
||||
else:
|
||||
break
|
||||
if tf_keys[:index] != pt_keys[:index]:
|
||||
self.assertEqual(tf_keys, pt_keys)
|
||||
|
||||
# Some models require extra condition to return loss. For example, `(TF)BertForPreTraining` requires
|
||||
# both`labels` and `next_sentence_label`.
|
||||
if tf_loss is not None and pt_loss is not None:
|
||||
|
||||
# check anything else than `loss`
|
||||
keys = tuple([k for k in tf_keys])
|
||||
check_outputs(tf_outputs[1:index], pt_outputs[1:index], model_class, names=keys[1:index])
|
||||
|
||||
# check `loss`
|
||||
|
||||
# tf models returned loss is usually a tensor rather than a scalar.
|
||||
# (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
|
||||
# Change it here to a scalar to match PyTorch models' loss
|
||||
tf_loss = tf.math.reduce_mean(tf_loss).numpy()
|
||||
pt_loss = pt_loss.detach().to("cpu").numpy()
|
||||
|
||||
tf_nans = np.isnan(tf_loss)
|
||||
pt_nans = np.isnan(pt_loss)
|
||||
# the 2 losses need to be both nan or both not nan
|
||||
self.assertEqual(tf_nans, pt_nans)
|
||||
|
||||
if not tf_nans:
|
||||
max_diff = np.amax(np.abs(tf_loss - pt_loss))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@ -1472,9 +1659,30 @@ class ModelTesterMixin:
|
||||
# transformers does not have TF version yet
|
||||
return
|
||||
|
||||
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||
if self.has_attentions:
|
||||
config.output_attentions = True
|
||||
|
||||
config.output_hidden_states = True
|
||||
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
|
||||
if k in inputs_dict:
|
||||
attention_mask = inputs_dict[k]
|
||||
# make sure no all 0s attention masks - to avoid failure at this moment.
|
||||
# TODO: remove this line once the TODO below is implemented.
|
||||
attention_mask = torch.ones_like(attention_mask, dtype=torch.int32)
|
||||
# Here we make the first sequence with all 0s as attention mask.
|
||||
# Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative
|
||||
# values, like `1e-4`, `1e-9`, `1e-30` and `-inf` for attention mask across models/frameworks.
|
||||
# TODO: enable this block once the large negative values thing is cleaned up.
|
||||
# (see https://github.com/huggingface/transformers/issues/14859)
|
||||
# attention_mask = torch.cat(
|
||||
# [
|
||||
# torch.zeros_like(attention_mask[:1], dtype=torch.int32),
|
||||
# attention_mask[1:].type(dtype=torch.int32)
|
||||
# ],
|
||||
# dim=0
|
||||
# )
|
||||
inputs_dict[k] = attention_mask
|
||||
|
||||
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||
|
||||
tf_model = tf_model_class(config)
|
||||
pt_model = model_class(config)
|
||||
@ -1487,49 +1695,20 @@ class ModelTesterMixin:
|
||||
tf_input_keys.discard("cross_attn_head_mask")
|
||||
tf_input_keys.discard("decoder_head_mask")
|
||||
|
||||
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}
|
||||
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||
pt_model.eval()
|
||||
tf_inputs_dict = {}
|
||||
for key, tensor in pt_inputs.items():
|
||||
# skip key that does not exist in tf
|
||||
if type(tensor) == bool:
|
||||
tf_inputs_dict[key] = tensor
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "input_features":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
|
||||
pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}
|
||||
pt_inputs_dict_maybe_with_labels = {
|
||||
k: v for k, v in pt_inputs_dict_maybe_with_labels.items() if k in tf_input_keys
|
||||
}
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
|
||||
# Make sure PyTorch tensors are on same device as model
|
||||
pt_inputs = {k: v.to(torch_device) if torch.is_tensor(v) else v for k, v in pt_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_inputs)
|
||||
tfo = tf_model(tf_inputs_dict, training=False)
|
||||
|
||||
tf_hidden_states = tfo[0].numpy()
|
||||
pt_hidden_states = pto[0].cpu().numpy()
|
||||
|
||||
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
||||
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
||||
|
||||
pt_hidden_states[tf_nans] = 0
|
||||
tf_hidden_states[tf_nans] = 0
|
||||
pt_hidden_states[pt_nans] = 0
|
||||
tf_hidden_states[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
|
||||
self.assertLessEqual(max_diff, 4e-2)
|
||||
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@ -1542,43 +1721,7 @@ class ModelTesterMixin:
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||
pt_model = pt_model.to(torch_device)
|
||||
|
||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
||||
pt_model.eval()
|
||||
tf_inputs_dict = {}
|
||||
for key, tensor in pt_inputs.items():
|
||||
# skip key that does not exist in tf
|
||||
if type(tensor) == bool:
|
||||
tensor = np.array(tensor, dtype=bool)
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "input_features":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
|
||||
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_inputs)
|
||||
|
||||
tfo = tf_model(tf_inputs_dict)
|
||||
tfo = tfo[0].numpy()
|
||||
pto = pto[0].cpu().numpy()
|
||||
tf_nans = np.copy(np.isnan(tfo))
|
||||
pt_nans = np.copy(np.isnan(pto))
|
||||
|
||||
pto[tf_nans] = 0
|
||||
tfo[tf_nans] = 0
|
||||
pto[pt_nans] = 0
|
||||
tfo[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tfo - pto))
|
||||
self.assertLessEqual(max_diff, 4e-2)
|
||||
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels)
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
|
Loading…
Reference in New Issue
Block a user