Fix FlaxPretTrainedModel pt weights check (#19133)

* Fix FlaxPretTrainedModel pt weights check

* Update src/transformers/modeling_flax_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix raise comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Mishig Davaadorj 2022-09-21 14:17:04 +02:00 committed by GitHub
parent e7fdfc720a
commit 486134e5a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -665,7 +665,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "