make torch.load a bit safer (#27282)
* make torch.load a bit safer

* Fixes

---------

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent 74cae670ce
commit dec84b3211
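The change is the same in every file touched below: direct torch.load calls on checkpoint files gain weights_only=True, which replaces the full pickle deserializer with PyTorch's restricted unpickler (tensors, primitive types and containers only, available since PyTorch 1.13), so a crafted checkpoint can no longer run arbitrary code at load time. A minimal sketch of the pattern, not taken from the diff and using a hypothetical file name:

import torch
import torch.nn as nn

# Hypothetical round trip: save a plain tensor state dict, then reload it.
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "checkpoint.bin")

# Default load: the full pickle machinery runs, so a malicious file could execute code.
unsafe = torch.load("checkpoint.bin", map_location="cpu")

# Hardened load: only tensors, primitive types and containers are deserialized.
safe = torch.load("checkpoint.bin", map_location="cpu", weights_only=True)
assert all(torch.equal(unsafe[k], safe[k]) for k in safe)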
@@ -329,7 +329,7 @@ def convert_pt_checkpoint_to_tf(
     if compare_with_pt_model:
         tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

-        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
+        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
         pt_model = pt_model_class.from_pretrained(
             pretrained_model_name_or_path=None, config=config, state_dict=state_dict
         )
@@ -68,7 +68,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
             for k in f.keys():
                 pt_state_dict[k] = f.get_tensor(k)
     else:
-        pt_state_dict = torch.load(pt_path, map_location="cpu")
+        pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
     logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@@ -249,7 +249,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
-        pt_state_dict = torch.load(shard_file)
+        pt_state_dict = torch.load(shard_file, weights_only=True)
         pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

         model_prefix = flax_model.base_model_prefix
@@ -186,7 +186,7 @@ def load_pytorch_checkpoint_in_tf2_model(
         if pt_path.endswith(".safetensors"):
             state_dict = safe_load_file(pt_path)
         else:
-            state_dict = torch.load(pt_path, map_location="cpu")
+            state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)

         pt_state_dict.update(state_dict)

@@ -480,7 +480,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
             error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu")
+    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)

     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
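In load_sharded_checkpoint the call is made through a loader callable rather than directly, so the keyword is baked in with functools.partial. A standalone sketch of that pattern, with a made-up folder and shard names (the real function reads them from the checkpoint index):

import os
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file

# Hypothetical sharded checkpoint, written here only for illustration.
folder = "checkpoint_dir"
shard_files = ["model-00001-of-00002.bin", "model-00002-of-00002.bin"]
os.makedirs(folder, exist_ok=True)
for name in shard_files:
    torch.save({"weight": torch.zeros(2, 2)}, os.path.join(folder, name))

# Either branch yields a callable mapping a file path to a state dict.
load_safe = False  # True when the shards are .safetensors files
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)

for shard_file in shard_files:
    state_dict = loader(os.path.join(folder, shard_file))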
@@ -516,7 +516,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
         else:
             map_location = "cpu"

-        return torch.load(checkpoint_file, map_location=map_location)
+        return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
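What the flag buys in load_state_dict: a checkpoint whose pickle stream references arbitrary Python globals is now rejected instead of being reconstructed (or executed). A self-contained illustration, not part of the diff; depending on the PyTorch version the failure surfaces as a pickle.UnpicklingError or a RuntimeError:

import argparse
import pickle

import torch

# A checkpoint that pickles a non-tensor object (an argparse.Namespace here).
torch.save({"training_args": argparse.Namespace(lr=1e-3)}, "ckpt_with_objects.pt")

# The default loader reconstructs the object without complaint.
print(torch.load("ckpt_with_objects.pt", map_location="cpu"))

# The restricted loader refuses anything outside tensors and primitive containers.
try:
    torch.load("ckpt_with_objects.pt", map_location="cpu", weights_only=True)
except (pickle.UnpicklingError, RuntimeError) as err:
    print(f"rejected: {err}")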
@@ -1333,7 +1333,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
                     cache_dir=cache_dir,
                 )

-                state_dict = torch.load(weight_path, map_location="cpu")
+                state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)

             except EnvironmentError:
                 # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
@@ -2086,7 +2086,7 @@ class Trainer:
                     logger.warning(
                         "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
                     )
-                state_dict = torch.load(weights_file, map_location="cpu")
+                state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
                 # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
@@ -2099,7 +2099,7 @@ class Trainer:
             if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                 state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
             else:
-                state_dict = torch.load(weights_file, map_location="cpu")
+                state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)

             # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
             # which takes *args instead of **kwargs
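The Trainer hunks share one shape: prefer the .safetensors file when it exists, and fall back to the pickle-based .bin file, which is now loaded with weights_only=True. A small sketch of both paths with throwaway file names; safetensors stores raw tensors, so its loader never unpickles anything:

import safetensors.torch
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Preferred path: the safetensors format holds tensors only.
safetensors.torch.save_file(model.state_dict(), "model.safetensors")
state_dict = safetensors.torch.load_file("model.safetensors", device="cpu")

# Fallback path for legacy pickle checkpoints, hardened by this commit.
torch.save(model.state_dict(), "pytorch_model.bin")
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)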
@@ -2167,7 +2167,7 @@ class Trainer:
                     if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                         state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                     else:
-                        state_dict = torch.load(best_model_path, map_location="cpu")
+                        state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                     state_dict["_smp_is_partial"] = False
                     load_result = model.load_state_dict(state_dict, strict=True)
@@ -2196,7 +2196,7 @@ class Trainer:
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
-                    state_dict = torch.load(best_model_path, map_location="cpu")
+                    state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                 # If the model is on the GPU, it still works!
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963