Fix load from PT-formatted checkpoint in composite TF models (#20661)

* Fix load from PT-formatted checkpoint in composite TF models

* Leave the from_pt part as it was
This commit is contained in:
Sylvain Gugger 2022-12-08 09:33:07 -05:00 committed by GitHub
parent 521da6518f
commit a03f7514db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -2727,14 +2727,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
return load_pytorch_checkpoint_in_tf2_model(
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
)
elif safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
state_dict = safe_load_file(resolved_archive_file)
# Load from a PyTorch checkpoint
return load_pytorch_state_dict_in_tf2_model(
model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
)
# we might need to extend the variable scope for composite models
if load_weight_prefix is not None:
@@ -2743,6 +2735,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else:
model(model.dummy_inputs) # build the network with dummy inputs
if safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
state_dict = safe_load_file(resolved_archive_file)
# Load from a PyTorch checkpoint
return load_pytorch_state_dict_in_tf2_model(
model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
)
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
try: