Make dummy inputs a property of TFPreTrainedModel.

This commit is contained in:
Filip Povolny 2019-11-05 11:48:45 +01:00
parent 8df7dfd2a7
commit 124409d075

View File

@ -52,6 +52,15 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_archive_map = {}
base_model_prefix = ""
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return tf.constant(DUMMY_INPUTS)
def __init__(self, config, *inputs, **kwargs):
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
if not isinstance(config, PretrainedConfig):
@ -265,15 +274,14 @@ class TFPreTrainedModel(tf.keras.Model):
# Load from a PyTorch checkpoint
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
ret = model(dummy_inputs, training=False) # build the network with dummy inputs
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
model.load_weights(resolved_archive_file, by_name=True)
ret = model(dummy_inputs, training=False) # Make sure restore ops are run
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
return model