mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Make dummy inputs a property of TFPreTrainedModel.
This commit is contained in:
parent
8df7dfd2a7
commit
124409d075
@ -52,6 +52,15 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
pretrained_model_archive_map = {}
|
||||
base_model_prefix = ""
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
""" Dummy inputs to build the network.
|
||||
|
||||
Returns:
|
||||
tf.Tensor with dummy inputs
|
||||
"""
|
||||
return tf.constant(DUMMY_INPUTS)
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
@ -265,15 +274,14 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
# Load from a PyTorch checkpoint
|
||||
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
||||
|
||||
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
||||
ret = model(dummy_inputs, training=False) # build the network with dummy inputs
|
||||
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
|
||||
|
||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||
# 'by_name' allows us to do transfer learning by skipping/adding layers
|
||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||
model.load_weights(resolved_archive_file, by_name=True)
|
||||
|
||||
ret = model(dummy_inputs, training=False) # Make sure restore ops are run
|
||||
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||
|
||||
return model
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user