TF safetensors reduced mem usage (#24404)

* Slight comment cleanup

* Reduce peak mem usage when loading TF-format safetensor weights

* Tweak the PyTorch loading code to support lazy loading from safetensors

* Pass safe_open objects to the PyTorch loading function

* Do GPU transposes for speed

* One more tweak to reduce peak usage further

* One-line hasattr

* Fix bug when there's a shape mismatch

* Rename state_dict in the loading code to be clearer

* Use TF format everywhere for consistency
Matt 2023-06-22 14:06:16 +01:00 committed by GitHub
parent 7e03e46934
commit 22fe73c378
2 changed files with 60 additions and 67 deletions
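
The gist of the change, as a standalone sketch (illustrative only — the toy model, file name, and lookup dict below are made up, not transformers code): instead of reading the whole safetensors checkpoint into a dict up front, open it lazily with safe_open() and assign one tensor at a time, so peak memory stays close to the model itself plus a single weight.

```python
import tensorflow as tf
from safetensors import safe_open
from safetensors.tensorflow import save_file

# Toy model and checkpoint so the sketch is self-contained (names here are made up).
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,), name="dense")])
save_file({w.name: tf.convert_to_tensor(w) for w in model.weights}, "toy.safetensors")

# Old approach (eager): safetensors.tensorflow.load_file() returns the whole checkpoint
# as a dict, so every tensor is resident in memory at once.

# New approach (lazy): open the archive and pull one tensor at a time, assigning it to
# the matching variable immediately so only one extra weight is ever materialized.
weights_by_name = {w.name: w for w in model.weights}
with safe_open("toy.safetensors", framework="tf") as archive:
    for name in archive.keys():
        value = archive.get_tensor(name)                       # loads just this tensor
        tf.keras.backend.set_value(weights_by_name[name], value)
        del value                                              # free it before the next one
```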

src/transformers/modeling_tf_pytorch_utils.py

@@ -248,7 +248,8 @@ def load_pytorch_state_dict_in_tf2_model(
     tf_to_pt_weight_rename=None,
     ignore_mismatched_sizes=False,
 ):
-    """Load a pytorch state_dict in a TF 2.0 model."""
+    """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
+    safetensors archive created with the safe_open() function."""
 
     import tensorflow as tf
     from packaging.version import parse
@@ -262,13 +263,11 @@ def load_pytorch_state_dict_in_tf2_model(
     if _prefix is None:
         _prefix = ""
-    if tf_inputs is not None:
+    if tf_inputs:
         with tf.name_scope(_prefix):
             tf_model(tf_inputs, training=False)  # Make sure model is built
 
-    # Adapt state dict - TODO remove this and update the AWS weights files instead
     # Convert old format to new format if needed from a PyTorch state_dict
-    old_keys = []
-    new_keys = []
+    tf_keys_to_pt_keys = {}
     for key in pt_state_dict.keys():
         new_key = None
         if "gamma" in key:
@@ -279,26 +278,24 @@ def load_pytorch_state_dict_in_tf2_model(
             new_key = key.replace("running_var", "moving_variance")
         if "running_mean" in key:
             new_key = key.replace("running_mean", "moving_mean")
-        if new_key:
-            old_keys.append(key)
-            new_keys.append(new_key)
-    for old_key, new_key in zip(old_keys, new_keys):
-        pt_state_dict[new_key] = pt_state_dict.pop(old_key)
+        if new_key is None:
+            new_key = key
+        tf_keys_to_pt_keys[new_key] = key
 
     # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
-    # In PT, the derived models (with heads) use the base model class as the stem instead, and the base model
-    # just contains the stem itself, and there is no MainLayer class. This means that TF base classes have one
+    # In PT, the derived models (with heads) use the base model class as the stem instead,
+    # and there is no MainLayer class. This means that TF base classes have one
     # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
     start_prefix_to_remove = ""
-    if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
+    if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys.keys()):
         start_prefix_to_remove = tf_model.base_model_prefix + "."
 
     symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
     tf_loaded_numel = 0
-    weight_value_tuples = []
-    all_pytorch_weights = set(pt_state_dict.keys())
+    all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
     missing_keys = []
     mismatched_keys = []
+    is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
     for symbolic_weight in symbolic_weights:
         sw_name = symbolic_weight.name
         name, transpose = convert_tf_weight_name_to_pt_weight_name(
@@ -311,7 +308,7 @@ def load_pytorch_state_dict_in_tf2_model(
             name = tf_to_pt_weight_rename(name)
 
         # Find associated numpy array in pytorch model state dict
-        if name not in pt_state_dict:
+        if name not in tf_keys_to_pt_keys:
             if allow_missing_keys:
                 missing_keys.append(name)
                 continue
@@ -320,9 +317,13 @@ def load_pytorch_state_dict_in_tf2_model(
             if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                 continue
             raise AttributeError(f"{name} not found in PyTorch model")
-
+        state_dict_name = tf_keys_to_pt_keys[name]
+        if is_safetensor_archive:
+            array = pt_state_dict.get_tensor(state_dict_name)
+        else:
+            array = pt_state_dict[state_dict_name]
         try:
-            array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
+            array = apply_transpose(transpose, array, symbolic_weight.shape)
         except tf.errors.InvalidArgumentError as e:
             if not ignore_mismatched_sizes:
                 error_msg = str(e)
@@ -331,16 +332,15 @@ def load_pytorch_state_dict_in_tf2_model(
                 )
                 raise tf.errors.InvalidArgumentError(error_msg)
             else:
-                mismatched_keys.append((name, pt_state_dict[name].shape, symbolic_weight.shape))
+                mismatched_keys.append((name, array.shape, symbolic_weight.shape))
                 continue
 
         tf_loaded_numel += tensor_size(array)
-        weight_value_tuples.append((symbolic_weight, array))
+        K.set_value(symbolic_weight, array)
+        del array  # Immediately free memory to keep peak usage as low as possible
         all_pytorch_weights.discard(name)
-    K.batch_set_value(weight_value_tuples)
 
     logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")
 
     unexpected_keys = list(all_pytorch_weights)
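
The loader above accepts either a plain state dict or a lazy safe_open() handle by duck-typing on get_tensor(). A minimal sketch of that check (hypothetical helper, not a transformers API):

```python
# Hypothetical helper illustrating the duck-typed access used above; `weights` may be a
# dict[str, tensor] or a safetensors safe_open() handle.
def get_weight(weights, name):
    if hasattr(weights, "get_tensor"):   # safe_open() archives expose get_tensor()
        return weights.get_tensor(name)  # lazy: read this one tensor from disk
    return weights[name]                 # plain dict: already fully in memory
```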

src/transformers/modeling_tf_utils.py

@@ -87,7 +87,6 @@ else:
 if is_safetensors_available():
     from safetensors import safe_open
-    from safetensors.tensorflow import load_file as safe_load_file
     from safetensors.tensorflow import save_file as safe_save_file
 
 if TYPE_CHECKING:
@@ -1000,42 +999,33 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
 def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     # Read the safetensors file
-    state_dict = safe_load_file(resolved_archive_file)
+    with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
+        mismatched_layers = []
+        weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
+        loaded_weight_names = list(safetensors_archive.keys())
+        # Find the missing layers from the high level list of layers
+        missing_layers = list(set(weight_names) - set(loaded_weight_names))
+        # Find the unexpected layers from the high level list of layers
+        unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
+        weight_value_tuples = []
 
-    mismatched_layers = []
-    weight_names = [format_weight_name(w.name, _prefix=_prefix) for w in model.weights]
-    loaded_weight_names = list(state_dict.keys())
-    # Find the missing layers from the high level list of layers
-    missing_layers = list(set(weight_names) - set(loaded_weight_names))
-    # Find the unexpected layers from the high level list of layers
-    unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
-    weight_value_tuples = []
-    for weight in model.weights:
-        weight_name = format_weight_name(weight.name, _prefix=_prefix)
-        if weight_name in state_dict:
-            weight_value = state_dict[weight_name]
-            # Check if the shape of the current weight and the one from the H5 file are different
-            if K.int_shape(weight) != weight_value.shape:
-                # If yes we reshape the weight from the H5 file accordingly to the current weight
-                # If the two shapes are not compatible we raise an issue
-                try:
-                    weight_value = tf.reshape(weight_value, K.int_shape(weight))
-                except ValueError as e:
-                    if ignore_mismatched_sizes:
-                        mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
-                        continue
-                    else:
-                        raise e
-            weight_value_tuples.append((weight, weight_value))
-    # Load all the weights
-    K.batch_set_value(weight_value_tuples)
+        for weight in model.weights:
+            weight_name = format_weight_name(weight.name, _prefix=_prefix)
+            if weight_name in loaded_weight_names:
+                weight_value = safetensors_archive.get_tensor(weight_name)
+                # Check if the shape of the current weight and the one from the H5 file are different
+                if K.int_shape(weight) != weight_value.shape:
+                    # If yes we reshape the weight from the H5 file accordingly to the current weight
+                    # If the two shapes are not compatible we raise an issue
+                    try:
+                        weight_value = tf.reshape(weight_value, K.int_shape(weight))
+                    except ValueError as e:
+                        if ignore_mismatched_sizes:
+                            mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
+                            continue
+                        else:
+                            raise e
+                K.set_value(weight, weight_value)  # weight.assign() might break if weight is a DTensor
 
     return missing_layers, unexpected_layers, mismatched_layers
@@ -2921,16 +2911,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         if safetensors_from_pt:
             from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
 
-            state_dict = safe_load_file(resolved_archive_file)
-            # Load from a PyTorch checkpoint
-            return load_pytorch_state_dict_in_tf2_model(
-                model,
-                state_dict,
-                allow_missing_keys=True,
-                output_loading_info=output_loading_info,
-                _prefix=load_weight_prefix,
-                ignore_mismatched_sizes=ignore_mismatched_sizes,
-            )
+            with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
+                # Load from a PyTorch checkpoint
+                # We load in TF format here because PT weights often need to be transposed, and this is much
+                # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
+                return load_pytorch_state_dict_in_tf2_model(
+                    model,
+                    safetensors_archive,
+                    tf_inputs=False,  # No need to build the model again
+                    allow_missing_keys=True,
+                    output_loading_info=output_loading_info,
+                    _prefix=load_weight_prefix,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                )
 
         # 'by_name' allow us to do transfer learning by skipping/adding layers
         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
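
The comment in the last hunk is the reason for framework="tf": get_tensor() then returns tf.Tensors, so the transposes needed to convert PyTorch weight layouts run as TensorFlow ops (on GPU when one is visible) rather than as NumPy ops on the CPU. A small sketch with placeholder file and tensor names:

```python
# Placeholder file and tensor names; illustrative only.
import tensorflow as tf
from safetensors import safe_open

with safe_open("toy.safetensors", framework="tf") as archive:
    kernel = archive.get_tensor("dense/kernel:0")  # returned as a tf.Tensor, not an np.ndarray
    kernel_t = tf.transpose(kernel)                # dispatched by TF, runs on GPU when available
```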