Support FlaxPreTrainedModel to load model checkpoint from local subfolder safetensors (#37732)

Support FlaxPreTrainedModel to load model checkpoint from subfolder in local directory as safetensors format

Signed-off-by: Yan Zhao <zhao.y4@northeastern.edu>
This commit is contained in:
Yan Zhao 2025-04-30 07:13:23 -07:00 committed by GitHub
parent 5b223bbc8c
commit 2c1155519f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -702,7 +702,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if is_local:
if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
@ -710,6 +710,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
# Load from a sharded Flax checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True
elif is_safetensors_available() and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
):
# Load from a safetensors checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
elif is_safetensors_available() and os.path.isfile(
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
):