mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Absolute definitive HeisenDistilBug solve
cc @julien-c @thomwolf
This commit is contained in:
parent
5c3d441ee1
commit
ea2600bd5f
@ -113,10 +113,13 @@ class TFModelTesterMixin:
|
||||
tf_hidden_states = tfo[0].numpy()
|
||||
pt_hidden_states = pto[0].numpy()
|
||||
|
||||
pt_hidden_states[np.isnan(tf_hidden_states)] = 0
|
||||
tf_hidden_states[np.isnan(tf_hidden_states)] = 0
|
||||
pt_hidden_states[np.isnan(pt_hidden_states)] = 0
|
||||
tf_hidden_states[np.isnan(pt_hidden_states)] = 0
|
||||
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
||||
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
||||
|
||||
pt_hidden_states[tf_nans] = 0
|
||||
tf_hidden_states[tf_nans] = 0
|
||||
pt_hidden_states[pt_nans] = 0
|
||||
tf_hidden_states[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
|
||||
# Debug info (remove when fixed)
|
||||
@ -148,8 +151,14 @@ class TFModelTesterMixin:
|
||||
tfo = tf_model(inputs_dict)
|
||||
tfo = tfo[0].numpy()
|
||||
pto = pto[0].numpy()
|
||||
tfo[np.isnan(tfo)] = 0
|
||||
pto[np.isnan(pto)] = 0
|
||||
tf_nans = np.copy(np.isnan(tfo))
|
||||
pt_nans = np.copy(np.isnan(pto))
|
||||
|
||||
pto[tf_nans] = 0
|
||||
tfo[tf_nans] = 0
|
||||
pto[pt_nans] = 0
|
||||
tfo[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tfo - pto))
|
||||
self.assertLessEqual(max_diff, 2e-2)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user