mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix load from PT-formatted checkpoint in composite TF models (#20661)
* Fix load from PT-formatted checkpoint in composite TF models * Leave the from_pt part as it was
This commit is contained in:
parent
521da6518f
commit
a03f7514db
@ -2727,14 +2727,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
return load_pytorch_checkpoint_in_tf2_model(
|
||||
model, resolved_archive_file, allow_missing_keys=True, output_loading_info=output_loading_info
|
||||
)
|
||||
elif safetensors_from_pt:
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
|
||||
|
||||
state_dict = safe_load_file(resolved_archive_file)
|
||||
# Load from a PyTorch checkpoint
|
||||
return load_pytorch_state_dict_in_tf2_model(
|
||||
model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
|
||||
)
|
||||
|
||||
# we might need to extend the variable scope for composite models
|
||||
if load_weight_prefix is not None:
|
||||
@ -2743,6 +2735,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
else:
|
||||
model(model.dummy_inputs) # build the network with dummy inputs
|
||||
|
||||
if safetensors_from_pt:
|
||||
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
|
||||
|
||||
state_dict = safe_load_file(resolved_archive_file)
|
||||
# Load from a PyTorch checkpoint
|
||||
return load_pytorch_state_dict_in_tf2_model(
|
||||
model, state_dict, allow_missing_keys=True, output_loading_info=output_loading_info
|
||||
)
|
||||
|
||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user