mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Make dummy inputs a property of TFPreTrainedModel.
This commit is contained in:
parent
8df7dfd2a7
commit
124409d075
@ -52,6 +52,15 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
pretrained_model_archive_map = {}
|
pretrained_model_archive_map = {}
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self):
|
||||||
|
""" Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tf.Tensor with dummy inputs
|
||||||
|
"""
|
||||||
|
return tf.constant(DUMMY_INPUTS)
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
@ -265,15 +274,14 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file)
|
||||||
|
|
||||||
dummy_inputs = tf.constant(DUMMY_INPUTS) # dummy inputs to build the network
|
ret = model(model.dummy_inputs, training=False) # build the network with dummy inputs
|
||||||
ret = model(dummy_inputs, training=False) # build the network with dummy inputs
|
|
||||||
|
|
||||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||||
model.load_weights(resolved_archive_file, by_name=True)
|
model.load_weights(resolved_archive_file, by_name=True)
|
||||||
|
|
||||||
ret = model(dummy_inputs, training=False) # Make sure restore ops are run
|
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user