diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 5e3c492909f..132513e0786 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -2727,14 +2727,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             return load_pytorch_checkpoint_in_tf2_model(
                 model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
             )
-        elif safetensors_from_pt:
-            from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
-
-            state_dict = safe_load_file(resolved_archive_file)
-            # Load from a PyTorch checkpoint
-            return load_pytorch_state_dict_in_tf2_model(
-                model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
-            )
 
         # we might need to extend the variable scope for composite models
         if load_weight_prefix is not None:
@@ -2743,6 +2735,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         else:
             model(model.dummy_inputs)  # build the network with dummy inputs
 
+        if safetensors_from_pt:
+            from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
+
+            state_dict = safe_load_file(resolved_archive_file)
+            # Load from a PyTorch checkpoint
+            return load_pytorch_state_dict_in_tf2_model(
+                model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
+            )
+
         # 'by_name' allow us to do transfer learning by skipping/adding layers
         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
         try:
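The likely motivation for this reorder (not stated in the diff itself): subclassed Keras models create their variables lazily, on their first call, so cross-loading a PyTorch safetensors state dict before `model(model.dummy_inputs)` has run would find no TF variables to populate. Below is a minimal sketch of that build-before-load behavior; `TinyModel` is hypothetical and not part of the patch.

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, inputs):
        return self.dense(inputs)

model = TinyModel()
print(len(model.weights))  # 0 -- no variables exist before the first call

model(tf.zeros((1, 8)))    # build the network with dummy inputs
print(len(model.weights))  # 2 -- kernel and bias now exist and can be loaded into
```

Placing the `safetensors_from_pt` branch after the dummy-input build (and after the optional `load_weight_prefix` variable scope) should therefore guarantee that `load_pytorch_state_dict_in_tf2_model` operates on a fully built network rather than an empty one.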