Better error raised when cloned without lfs (#13401)

* Better error raised when cloned without lfs

* add from e
This commit is contained in:
Lysandre Debut 2021-09-08 08:28:22 -04:00 committed by GitHub
parent 18447c206d
commit 99029ab6b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 13 deletions

View File

@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
@ -348,8 +349,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
with open(resolved_archive_file, "rb") as state_f:
try:
state = from_bytes(cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
try:
with open(resolved_archive_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.arrays
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261

View File

@ -1334,11 +1334,22 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
ignore_mismatched_sizes=ignore_mismatched_sizes,
_prefix=load_weight_prefix,
)
except OSError:
raise OSError(
"Unable to load weights from h5 file. "
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
)
except OSError as e:
try:
with open(resolved_archive_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise OSError(
"Unable to load weights from h5 file. "
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
)
model(model.dummy_inputs) # Make sure restore ops are run

View File

@ -1285,12 +1285,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if state_dict is None:
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception:
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'"
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)
except Exception as e:
try:
with open(resolved_archive_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'"
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype