Add a use_safetensors arg to TFPreTrainedModel.from_pretrained() (#28511)

* Add a use_safetensors arg to TFPreTrainedModel.from_pretrained()

* One more catch!

* One more one more catch
Matt 2024-01-15 17:00:54 +00:00, committed by GitHub
parent 78d767e3c8
commit 72db39c065

@@ -2508,6 +2508,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
+ use_safetensors: bool = None,
**kwargs,
):
r"""
@@ -2601,6 +2602,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
A function that is called to transform the names of weights during the PyTorch to TensorFlow
crossloading process. This is not necessary for most models, but is useful to allow composite models to
be crossloaded correctly.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+     Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
+     is not installed, it will be set to `False`.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
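
A minimal usage sketch of the new flag, for context (the checkpoint name is illustrative and not part of this commit):

from transformers import TFAutoModel

# use_safetensors=True: insist on safetensors weights; loading fails if none exist.
model = TFAutoModel.from_pretrained("bert-base-uncased", use_safetensors=True)

# use_safetensors=False: ignore safetensors files and fall back to TF2 (.h5) or PyTorch weights.
model = TFAutoModel.from_pretrained("bert-base-uncased", use_safetensors=False)

# use_safetensors=None (the default): prefer safetensors when the library is installed.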
@@ -2673,6 +2677,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
+ if use_safetensors is None and not is_safetensors_available():
+     use_safetensors = False
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
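
The two added lines above only resolve the `None` case when `safetensors` is missing; here is a standalone sketch of the effective tri-state behaviour (`resolve_use_safetensors` is a hypothetical helper for illustration; `is_safetensors_available` is the real transformers utility):

def resolve_use_safetensors(use_safetensors, safetensors_installed):
    # None means "auto": without the safetensors library, behave exactly like False.
    if use_safetensors is None and not safetensors_installed:
        return False
    # True, and None with the library installed, keep the safetensors branches
    # below reachable; False disables them outright.
    return use_safetensors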
@@ -2712,7 +2719,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
is_sharded = True
- elif is_safetensors_available() and os.path.isfile(
+ elif use_safetensors is not False and os.path.isfile(
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
):
# Load from a safetensors checkpoint
@@ -2724,7 +2731,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# Load from a sharded TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
is_sharded = True
- elif is_safetensors_available() and os.path.isfile(
+ elif use_safetensors is not False and os.path.isfile(
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
):
# Load from a sharded safetensors checkpoint
@@ -2732,6 +2739,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
is_sharded = True
raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
# At this stage we don't have a weight file so we will raise an error.
+ elif use_safetensors:
+     raise EnvironmentError(
+         f"Error no file named {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}. "
+         f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
+         f"set `use_safetensors=True`."
+     )
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
):
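
A hedged example of the failure mode the new `elif use_safetensors:` branch guards against, assuming a local directory that holds only TF2 weights (the path is hypothetical; `SAFE_WEIGHTS_NAME` resolves to model.safetensors):

# ./my_tf_checkpoint contains tf_model.h5 but no model.safetensors.
model = TFAutoModel.from_pretrained("./my_tf_checkpoint", use_safetensors=True)
# EnvironmentError: Error no file named model.safetensors found in directory
# ./my_tf_checkpoint. Please make sure that the model has been saved with
# `safe_serialization=True` or do not set `use_safetensors=True`.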
@@ -2758,7 +2771,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# set correct filename
if from_pt:
filename = WEIGHTS_NAME
- elif is_safetensors_available():
+ elif use_safetensors is not False:
filename = SAFE_WEIGHTS_NAME
else:
filename = TF2_WEIGHTS_NAME
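
Taken together, the hub-side filename selection after this change, as a comment sketch (in transformers these constants resolve to pytorch_model.bin, model.safetensors, and tf_model.h5; fallback between them is handled elsewhere in the method):

# from_pt=True                  -> WEIGHTS_NAME       (pytorch_model.bin)
# use_safetensors is not False  -> SAFE_WEIGHTS_NAME  (model.safetensors, tried first)
# otherwise                     -> TF2_WEIGHTS_NAME   (tf_model.h5)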