Merge pull request #1509 from julian-pani/patch-3

remove leftover usage of DUMMY_INPUTS
This commit is contained in:
Thomas Wolf 2019-10-15 10:24:13 +02:00 committed by GitHub
commit e703e4dfe1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 2 deletions

View File

@ -198,7 +198,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
tf_model = tf_model_class(pt_model.config)
if tf_inputs is None:
tf_inputs = tf.constant(DUMMY_INPUTS)
tf_inputs = tf_model.dummy_inputs
if tf_inputs is not None:
tfo = tf_model(tf_inputs, training=False) # Make sure model is built

View File

@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import absolute_import, division, print_function
import os
import copy
import json
import logging
@ -118,7 +119,7 @@ class TFCommonTestCases:
tf_model = model_class(config)
pt_model = pt_model_class(config)
# Check we can load pt model in tf and vice-versa (architecture similar)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
@ -132,6 +133,26 @@ class TFCommonTestCases:
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, 'pt_model.bin')
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, 'tf_model.h5')
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long))
for name, key in inputs_dict.items())
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(inputs_dict)
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()