Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
TF safetensors reduced mem usage (#24404)
* Slight comment cleanup
* Reduce peak mem usage when loading TF-format safetensor weights
* Tweak the PyTorch loading code to support lazy loading from safetensors
* Pass safe_open objects to the PyTorch loading function
* Do GPU transposes for speed
* One more tweak to reduce peak usage further
* One-line hasattr
* Fix bug when there's a shape mismatch
* Rename state_dict in the loading code to be clearer
* Use TF format everywhere for consistency
This commit is contained in:
parent
7e03e46934
commit
22fe73c378
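The core idea of the commit, as a minimal standalone sketch before the diff itself (the checkpoint path is illustrative, not a file shipped with this commit): safetensors files can be read eagerly with load_file(), which materializes every tensor at once, or lazily with safe_open(), which returns a handle that fetches one tensor at a time. The loading code below is reworked around the lazy form to cut peak memory.

# Hedged sketch of eager vs. lazy safetensors reading.
from safetensors import safe_open
from safetensors.tensorflow import load_file

# Eager: the full checkpoint is resident in memory at once.
state_dict = load_file("model.safetensors")

# Lazy: only the tensor currently being requested is materialized.
with safe_open("model.safetensors", framework="tf") as archive:
    for name in archive.keys():
        tensor = archive.get_tensor(name)  # read on demand
        del tensor  # release before fetching the next one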
src/transformers/modeling_tf_pytorch_utils.py

@@ -248,7 +248,8 @@ def load_pytorch_state_dict_in_tf2_model(
     tf_to_pt_weight_rename=None,
     ignore_mismatched_sizes=False,
 ):
-    """Load a pytorch state_dict in a TF 2.0 model."""
+    """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
+    safetensors archive created with the safe_open() function."""
     import tensorflow as tf
     from packaging.version import parse

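A sketch of how one code path can serve both input types named in the new docstring; get_array and weights are illustrative names, not part of the commit. The diff below applies the same hasattr() test (the "One-line hasattr" bullet in the commit message) to pick the access path.

# A plain dict supports indexing; a safe_open() handle exposes get_tensor().
def get_array(weights, key):
    if hasattr(weights, "get_tensor"):
        return weights.get_tensor(key)  # lazy fetch from a safetensors archive
    return weights[key]  # ordinary dict lookup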
@@ -262,13 +263,11 @@ def load_pytorch_state_dict_in_tf2_model(

     if _prefix is None:
         _prefix = ""
-    if tf_inputs is not None:
+    if tf_inputs:
         with tf.name_scope(_prefix):
             tf_model(tf_inputs, training=False)  # Make sure model is built
-    # Adapt state dict - TODO remove this and update the AWS weights files instead
     # Convert old format to new format if needed from a PyTorch state_dict
-    old_keys = []
-    new_keys = []
+    tf_keys_to_pt_keys = {}
     for key in pt_state_dict.keys():
         new_key = None
         if "gamma" in key:
@@ -279,26 +278,24 @@ def load_pytorch_state_dict_in_tf2_model(
             new_key = key.replace("running_var", "moving_variance")
         if "running_mean" in key:
             new_key = key.replace("running_mean", "moving_mean")
-        if new_key:
-            old_keys.append(key)
-            new_keys.append(new_key)
-    for old_key, new_key in zip(old_keys, new_keys):
-        pt_state_dict[new_key] = pt_state_dict.pop(old_key)
+        if new_key is None:
+            new_key = key
+        tf_keys_to_pt_keys[new_key] = key

     # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
-    # In PT, the derived models (with heads) use the base model class as the stem instead, and the base model
-    # just contains the stem itself, and there is no MainLayer class. This means that TF base classes have one
+    # In PT, the derived models (with heads) use the base model class as the stem instead,
+    # and there is no MainLayer class. This means that TF base classes have one
     # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
     start_prefix_to_remove = ""
-    if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
+    if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys.keys()):
         start_prefix_to_remove = tf_model.base_model_prefix + "."

     symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
     tf_loaded_numel = 0
-    weight_value_tuples = []
-    all_pytorch_weights = set(pt_state_dict.keys())
+    all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
     missing_keys = []
     mismatched_keys = []
+    is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
     for symbolic_weight in symbolic_weights:
         sw_name = symbolic_weight.name
         name, transpose = convert_tf_weight_name_to_pt_weight_name(
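The old code renamed keys by popping them out of the state dict and reinserting them, which requires a mutable mapping with every tensor already loaded. The replacement builds a lightweight TF-key-to-PT-key index instead, roughly as in this standalone sketch (build_key_map is an illustrative helper; only the running-stat renames visible in the hunk are included):

# Map renamed (TF-style) keys back to the original PyTorch keys, without
# touching the underlying (possibly read-only) archive.
def build_key_map(pt_keys):
    tf_keys_to_pt_keys = {}
    for key in pt_keys:
        new_key = None
        if "running_var" in key:
            new_key = key.replace("running_var", "moving_variance")
        if "running_mean" in key:
            new_key = key.replace("running_mean", "moving_mean")
        if new_key is None:
            new_key = key
        tf_keys_to_pt_keys[new_key] = key
    return tf_keys_to_pt_keys

# e.g. {"bn.moving_mean": "bn.running_mean", "fc.weight": "fc.weight"}
print(build_key_map(["bn.running_mean", "fc.weight"]))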
@@ -311,7 +308,7 @@ def load_pytorch_state_dict_in_tf2_model(
             name = tf_to_pt_weight_rename(name)

         # Find associated numpy array in pytorch model state dict
-        if name not in pt_state_dict:
+        if name not in tf_keys_to_pt_keys:
             if allow_missing_keys:
                 missing_keys.append(name)
                 continue
@@ -320,9 +317,13 @@ def load_pytorch_state_dict_in_tf2_model(
             if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                 continue
             raise AttributeError(f"{name} not found in PyTorch model")

+        state_dict_name = tf_keys_to_pt_keys[name]
+        if is_safetensor_archive:
+            array = pt_state_dict.get_tensor(state_dict_name)
+        else:
+            array = pt_state_dict[state_dict_name]
         try:
-            array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
+            array = apply_transpose(transpose, array, symbolic_weight.shape)
         except tf.errors.InvalidArgumentError as e:
             if not ignore_mismatched_sizes:
                 error_msg = str(e)
@@ -331,16 +332,15 @@ def load_pytorch_state_dict_in_tf2_model(
                 )
                 raise tf.errors.InvalidArgumentError(error_msg)
             else:
-                mismatched_keys.append((name, pt_state_dict[name].shape, symbolic_weight.shape))
+                mismatched_keys.append((name, array.shape, symbolic_weight.shape))
                 continue

         tf_loaded_numel += tensor_size(array)

-        weight_value_tuples.append((symbolic_weight, array))
+        K.set_value(symbolic_weight, array)
+        del array  # Immediately free memory to keep peak usage as low as possible
         all_pytorch_weights.discard(name)

-    K.batch_set_value(weight_value_tuples)
-
     logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")

     unexpected_keys = list(all_pytorch_weights)
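Why the switch from one K.batch_set_value call to per-weight K.set_value lowers peak usage, in a toy standalone example (shapes and values are arbitrary): batching requires every checkpoint array to be alive at the same moment, while assigning one weight at a time keeps only a single extra tensor in memory.

import numpy as np
import tensorflow as tf

K = tf.keras.backend
variables = [tf.Variable(np.zeros((256, 256), dtype=np.float32)) for _ in range(4)]

for variable in variables:
    # Stand-in for one tensor fetched lazily from the checkpoint.
    array = np.ones(tuple(variable.shape), dtype=np.float32)
    K.set_value(variable, array)
    del array  # free it before loading the next tensor, as the diff does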
src/transformers/modeling_tf_utils.py

@@ -87,7 +87,6 @@ else:

 if is_safetensors_available():
     from safetensors import safe_open
-    from safetensors.tensorflow import load_file as safe_load_file
     from safetensors.tensorflow import save_file as safe_save_file

 if TYPE_CHECKING:
@@ -1000,42 +999,33 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size

 def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     # Read the safetensors file
-    state_dict = safe_load_file(resolved_archive_file)
-
-    weight_value_tuples = []
-    mismatched_layers = []
-
-    weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
-    loaded_weight_names = list(state_dict.keys())
-
-    # Find the missing layers from the high level list of layers
-    missing_layers = list(set(weight_names) - set(loaded_weight_names))
-    # Find the unexpected layers from the high level list of layers
-    unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
-
-    weight_value_tuples = []
-    for weight in model.weights:
-        weight_name = format_weight_name(weight.name, _prefix=_prefix)
-        if weight_name in state_dict:
-            weight_value = state_dict[weight_name]
-            # Check if the shape of the current weight and the one from the H5 file are different
-            if K.int_shape(weight) != weight_value.shape:
-                # If yes we reshape the weight from the H5 file accordingly to the current weight
-                # If the two shapes are not compatible we raise an issue
-                try:
-                    weight_value = tf.reshape(weight_value, K.int_shape(weight))
-                except ValueError as e:
-                    if ignore_mismatched_sizes:
-                        mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
-                        continue
-                    else:
-                        raise e
-
-            weight_value_tuples.append((weight, weight_value))
-
-    # Load all the weights
-    K.batch_set_value(weight_value_tuples)
+    with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
+        mismatched_layers = []
+        weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
+        loaded_weight_names = list(safetensors_archive.keys())
+        # Find the missing layers from the high level list of layers
+        missing_layers = list(set(weight_names) - set(loaded_weight_names))
+        # Find the unexpected layers from the high level list of layers
+        unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
+
+        for weight in model.weights:
+            weight_name = format_weight_name(weight.name, _prefix=_prefix)
+            if weight_name in loaded_weight_names:
+                weight_value = safetensors_archive.get_tensor(weight_name)
+                # Check if the shape of the current weight and the one from the H5 file are different
+                if K.int_shape(weight) != weight_value.shape:
+                    # If yes we reshape the weight from the H5 file accordingly to the current weight
+                    # If the two shapes are not compatible we raise an issue
+                    try:
+                        weight_value = tf.reshape(weight_value, K.int_shape(weight))
+                    except ValueError as e:
+                        if ignore_mismatched_sizes:
+                            mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
+                            continue
+                        else:
+                            raise e
+
+                K.set_value(weight, weight_value)  # weight.assign() might break if weight is a DTensor
     return missing_layers, unexpected_layers, mismatched_layers
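A quick probe of what framework="tf" buys in the rewritten loader: get_tensor() hands back TensorFlow tensors directly, so the reshape and assignment above stay in TensorFlow rather than round-tripping through a separate numpy state dict. A hedged sketch (the file path is illustrative):

import tensorflow as tf
from safetensors import safe_open

# Inspect a TF-format safetensors file lazily; nothing is materialized
# until get_tensor() is called for a specific name.
with safe_open("tf_model.safetensors", framework="tf") as archive:
    for name in archive.keys():
        value = archive.get_tensor(name)
        print(name, value.dtype, tuple(value.shape))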
@@ -2921,16 +2911,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if safetensors_from_pt:
             from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model

-            state_dict = safe_load_file(resolved_archive_file)
-            # Load from a PyTorch checkpoint
-            return load_pytorch_state_dict_in_tf2_model(
-                model,
-                state_dict,
-                allow_missing_keys=True,
-                output_loading_info=output_loading_info,
-                _prefix=load_weight_prefix,
-                ignore_mismatched_sizes=ignore_mismatched_sizes,
-            )
+            with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
+                # Load from a PyTorch checkpoint
+                # We load in TF format here because PT weights often need to be transposed, and this is much
+                # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
+                return load_pytorch_state_dict_in_tf2_model(
+                    model,
+                    safetensors_archive,
+                    tf_inputs=False,  # No need to build the model again
+                    allow_missing_keys=True,
+                    output_loading_info=output_loading_info,
+                    _prefix=load_weight_prefix,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                )

         # 'by_name' allow us to do transfer learning by skipping/adding layers
         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
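End to end, the rewritten branch runs whenever from_pretrained resolves a safetensors checkpoint whose metadata marks it as PyTorch-format. A hedged usage sketch (the model id is a placeholder for any repo that only ships PT safetensors weights):

from transformers import TFAutoModel

# Loading PT-format safetensors into a TF model now streams tensors through
# safe_open() and transposes on GPU instead of first materializing a full
# numpy state dict.
model = TFAutoModel.from_pretrained("some-org/pt-safetensors-checkpoint")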