mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Support FlaxPreTrainedModel to load model checkpoint from local subfolder safetensors (#37732)
Support FlaxPreTrainedModel to load model checkpoint from subfolder in local directory as safetensors format Signed-off-by: Yan Zhao <zhao.y4@northeastern.edu>
This commit is contained in:
parent
5b223bbc8c
commit
2c1155519f
@ -702,7 +702,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if is_local:
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
|
||||
# Load from a Flax checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
|
||||
@ -710,6 +710,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
# Load from a sharded Flax checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
|
||||
is_sharded = True
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
|
||||
):
|
||||
# Load from a safetensors checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
|
||||
elif is_safetensors_available() and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
|
||||
):
|
||||
|
Loading…
Reference in New Issue
Block a user