Absolute definitive HeisenDistilBug solve

cc @julien-c @thomwolf
This commit is contained in:
Lysandre 2020-01-27 21:57:23 -05:00
parent 5c3d441ee1
commit ea2600bd5f

View File

@ -113,10 +113,13 @@ class TFModelTesterMixin:
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].numpy()
pt_hidden_states[np.isnan(tf_hidden_states)] = 0
tf_hidden_states[np.isnan(tf_hidden_states)] = 0
pt_hidden_states[np.isnan(pt_hidden_states)] = 0
tf_hidden_states[np.isnan(pt_hidden_states)] = 0
tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))
pt_hidden_states[tf_nans] = 0
tf_hidden_states[tf_nans] = 0
pt_hidden_states[pt_nans] = 0
tf_hidden_states[pt_nans] = 0
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
# Debug info (remove when fixed)
@ -148,8 +151,14 @@ class TFModelTesterMixin:
tfo = tf_model(inputs_dict)
tfo = tfo[0].numpy()
pto = pto[0].numpy()
tfo[np.isnan(tfo)] = 0
pto[np.isnan(pto)] = 0
tf_nans = np.copy(np.isnan(tfo))
pt_nans = np.copy(np.isnan(pto))
pto[tf_nans] = 0
tfo[tf_nans] = 0
pto[pt_nans] = 0
tfo[pt_nans] = 0
max_diff = np.amax(np.abs(tfo - pto))
self.assertLessEqual(max_diff, 2e-2)