mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add a use_safetensors arg to TFPreTrainedModel.from_pretrained() (#28511)
* Add a use_safetensors arg to TFPreTrainedModel.from_pretrained()
* One more catch!
* One more one more catch
This commit is contained in:
parent
78d767e3c8
commit
72db39c065
@ -2508,6 +2508,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
local_files_only: bool = False,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
revision: str = "main",
|
||||
use_safetensors: bool = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@ -2601,6 +2602,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
A function that is called to transform the names of weights during the PyTorch to TensorFlow
|
||||
crossloading process. This is not necessary for most models, but is useful to allow composite models to
|
||||
be crossloaded correctly.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
|
||||
is not installed, it will be set to `False`.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
||||
@ -2673,6 +2677,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
if use_safetensors is None and not is_safetensors_available():
|
||||
use_safetensors = False
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config_path = config if config is not None else pretrained_model_name_or_path
|
||||
@ -2712,7 +2719,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# Load from a sharded PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
elif use_safetensors is not False and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||
):
|
||||
# Load from a safetensors checkpoint
|
||||
@ -2724,7 +2731,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# Load from a sharded TF 2.0 checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
elif use_safetensors is not False and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
|
||||
):
|
||||
# Load from a sharded safetensors checkpoint
|
||||
@ -2732,6 +2739,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
is_sharded = True
|
||||
raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif use_safetensors:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}. "
|
||||
f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
|
||||
f"set `use_safetensors=True`."
|
||||
)
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
|
||||
):
|
||||
@ -2758,7 +2771,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
# set correct filename
|
||||
if from_pt:
|
||||
filename = WEIGHTS_NAME
|
||||
elif is_safetensors_available():
|
||||
elif use_safetensors is not False:
|
||||
filename = SAFE_WEIGHTS_NAME
|
||||
else:
|
||||
filename = TF2_WEIGHTS_NAME
|
||||
|
Loading…
Reference in New Issue
Block a user