mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Better error raised when cloned without lfs (#13401)
* Better error raised when cloned without lfs * add from e
This commit is contained in:
parent
18447c206d
commit
99029ab6b0
@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple, Union
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import msgpack.exceptions
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.serialization import from_bytes, to_bytes
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
@ -348,8 +349,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
with open(resolved_archive_file, "rb") as state_f:
|
||||
try:
|
||||
state = from_bytes(cls, state_f.read())
|
||||
except UnpicklingError:
|
||||
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
|
||||
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
|
||||
try:
|
||||
with open(resolved_archive_file) as f:
|
||||
if f.read().startswith("version"):
|
||||
raise OSError(
|
||||
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
||||
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
||||
"you cloned."
|
||||
)
|
||||
else:
|
||||
raise ValueError from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
|
||||
# make sure all arrays are stored as jnp.arrays
|
||||
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
||||
# https://github.com/google/flax/issues/1261
|
||||
|
@ -1334,11 +1334,22 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
_prefix=load_weight_prefix,
|
||||
)
|
||||
except OSError:
|
||||
raise OSError(
|
||||
"Unable to load weights from h5 file. "
|
||||
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
|
||||
)
|
||||
except OSError as e:
|
||||
try:
|
||||
with open(resolved_archive_file) as f:
|
||||
if f.read().startswith("version"):
|
||||
raise OSError(
|
||||
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
||||
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
||||
"you cloned."
|
||||
)
|
||||
else:
|
||||
raise ValueError from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
"Unable to load weights from h5 file. "
|
||||
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
|
||||
)
|
||||
|
||||
model(model.dummy_inputs) # Make sure restore ops are run
|
||||
|
||||
|
@ -1285,12 +1285,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if state_dict is None:
|
||||
try:
|
||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
||||
except Exception:
|
||||
raise OSError(
|
||||
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
||||
f"at '{resolved_archive_file}'"
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
||||
)
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(resolved_archive_file) as f:
|
||||
if f.read().startswith("version"):
|
||||
raise OSError(
|
||||
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
||||
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
||||
"you cloned."
|
||||
)
|
||||
else:
|
||||
raise ValueError from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
|
||||
f"at '{resolved_archive_file}'"
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
||||
)
|
||||
|
||||
# set dtype to instantiate the model under:
|
||||
# 1. If torch_dtype is not None, we use that dtype
|
||||
|
Loading…
Reference in New Issue
Block a user