make torch.load a bit safer (#27282)

* make torch.load a bit safer

* Fixes

---------

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Julien Chaumond 2023-12-15 16:01:18 +01:00 committed by GitHub
parent 74cae670ce
commit dec84b3211
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 11 additions and 11 deletions

View File

@ -329,7 +329,7 @@ def convert_pt_checkpoint_to_tf(
if compare_with_pt_model:
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
pt_model = pt_model_class.from_pretrained(
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
)

View File

@ -68,7 +68,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
for k in f.keys():
pt_state_dict[k] = f.get_tensor(k)
else:
pt_state_dict = torch.load(pt_path, map_location="cpu")
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@ -249,7 +249,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
flax_state_dict = {}
for shard_file in shard_filenames:
# load using msgpack utils
pt_state_dict = torch.load(shard_file)
pt_state_dict = torch.load(shard_file, weights_only=True)
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
model_prefix = flax_model.base_model_prefix

View File

@ -186,7 +186,7 @@ def load_pytorch_checkpoint_in_tf2_model(
if pt_path.endswith(".safetensors"):
state_dict = safe_load_file(pt_path)
else:
state_dict = torch.load(pt_path, map_location="cpu")
state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
pt_state_dict.update(state_dict)

View File

@ -480,7 +480,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu")
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))
@ -516,7 +516,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
else:
map_location = "cpu"
return torch.load(checkpoint_file, map_location=map_location)
return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
except Exception as e:
try:
with open(checkpoint_file) as f:

View File

@ -1333,7 +1333,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
cache_dir=cache_dir,
)
state_dict = torch.load(weight_path, map_location="cpu")
state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted

View File

@ -2086,7 +2086,7 @@ class Trainer:
logger.warning(
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
)
state_dict = torch.load(weights_file, map_location="cpu")
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
@ -2099,7 +2099,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
else:
state_dict = torch.load(weights_file, map_location="cpu")
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
@ -2167,7 +2167,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
state_dict = torch.load(best_model_path, map_location="cpu")
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
state_dict["_smp_is_partial"] = False
load_result = model.load_state_dict(state_dict, strict=True)
@ -2196,7 +2196,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
state_dict = torch.load(best_model_path, map_location="cpu")
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
# If the model is on the GPU, it still works!
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963