Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)

* fix all main files

* fix test files

* oops, forgot modular

* add link

* update message
Cyril Vallez 2025-04-25 16:57:09 +02:00 committed by GitHub
parent eefc86aa31
commit 0cfbf9c95b
24 changed files with 88 additions and 9 deletions
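For orientation: the commit applies one pattern at every remaining `torch.load` call site, invoking the new `check_torch_load_is_safe()` guard first and forcing `weights_only=True`, while safetensors loads stay untouched. A minimal sketch of that pattern (the `load_checkpoint` wrapper name is illustrative and not part of the diff):

```python
import torch
from safetensors.torch import load_file
from transformers.utils import check_torch_load_is_safe

def load_checkpoint(path: str):
    # safetensors files never go through torch.load, so no version gate applies
    if path.endswith(".safetensors"):
        return load_file(path)
    # raises ValueError on torch < 2.6 (CVE-2025-32434), otherwise a no-op
    check_torch_load_is_safe()
    return torch.load(path, map_location="cpu", weights_only=True)
```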

View File

@ -24,7 +24,7 @@ from filelock import FileLock
from torch.utils.data import Dataset
from ...tokenization_utils_base import PreTrainedTokenizerBase
from ...utils import logging
from ...utils import check_torch_load_is_safe, logging
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures
@ -122,6 +122,7 @@ class GlueDataset(Dataset):
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
check_torch_load_is_safe()
self.features = torch.load(cached_features_file, weights_only=True)
logger.info(
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start

View File

@ -24,7 +24,7 @@ from torch.utils.data import Dataset
from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
from ...utils import check_torch_load_is_safe, logging
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
@ -148,6 +148,7 @@ class SquadDataset(Dataset):
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
check_torch_load_is_safe()
self.old_features = torch.load(cached_features_file, weights_only=True)
# Legacy cache files have only features, while new cache files

View File

@ -27,7 +27,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict
import transformers
from . import is_safetensors_available, is_torch_available
from .utils import logging
from .utils import check_torch_load_is_safe, logging
if is_torch_available():
@ -71,6 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
)
raise
check_torch_load_is_safe()
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
@ -247,6 +248,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
flax_state_dict = {}
for shard_file in shard_filenames:
# load using msgpack utils
check_torch_load_is_safe()
pt_state_dict = torch.load(shard_file, weights_only=True)
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {

View File

@ -21,6 +21,7 @@ import numpy
from .utils import (
ExplicitEnum,
check_torch_load_is_safe,
expand_dims,
is_numpy_array,
is_safetensors_available,
@ -198,6 +199,7 @@ def load_pytorch_checkpoint_in_tf2_model(
if pt_path.endswith(".safetensors"):
state_dict = safe_load_file(pt_path)
else:
check_torch_load_is_safe()
state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
pt_state_dict.update(state_dict)

View File

@ -94,6 +94,7 @@ from .utils import (
ModelOutput,
PushToHubMixin,
cached_file,
check_torch_load_is_safe,
copy_func,
download_url,
extract_commit_hash,
@ -445,7 +446,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
if load_safe:
loader = safe_load_file
else:
check_torch_load_is_safe()
loader = partial(torch.load, map_location="cpu", weights_only=True)
for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))
@ -490,6 +495,7 @@ def load_state_dict(
"""
Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
"""
# Use safetensors if possible
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
@ -512,6 +518,9 @@ def load_state_dict(
state_dict[k] = f.get_tensor(k)
return state_dict
# Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe)
if weights_only:
check_torch_load_is_safe()
try:
if map_location is None:
if (
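The branching added in `load_state_dict` is the one subtle case: the guard runs only when `weights_only=True`, an explicit `weights_only=False` is already understood to be unsafe and is deliberately left ungated, and safetensors files bypass `torch.load` entirely. Condensed into a sketch (illustrative function name, not the real body):

```python
import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import check_torch_load_is_safe, is_safetensors_available

def sketch_load_state_dict(checkpoint_file, weights_only=True, map_location="cpu"):
    # safetensors files never go through torch.load, so no version gate is needed
    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
        return safe_load_file(checkpoint_file)
    # only gate torch.load when weights_only is requested; an explicit
    # weights_only=False is already known to be unsafe and is not checked
    if weights_only:
        check_torch_load_is_safe()
    return torch.load(checkpoint_file, map_location=map_location, weights_only=weights_only)
```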

View File

@ -29,6 +29,7 @@ import numpy as np
from ....tokenization_utils import PreTrainedTokenizer
from ....utils import (
cached_file,
check_torch_load_is_safe,
is_sacremoses_available,
is_torch_available,
logging,
@ -222,6 +223,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
"from a PyTorch pretrained vocabulary, "
"or activate it with environment variables USE_TORCH=1 and USE_TF=0."
)
check_torch_load_is_safe()
vocab_dict = torch.load(pretrained_vocab_file, weights_only=True)
if vocab_dict is not None:
@ -705,6 +707,7 @@ class TransfoXLCorpus:
# Instantiate tokenizer.
corpus = cls(*inputs, **kwargs)
check_torch_load_is_safe()
corpus_dict = torch.load(resolved_corpus_file, weights_only=True)
for key, value in corpus_dict.items():
corpus.__dict__[key] = value
@ -784,6 +787,7 @@ def get_lm_corpus(datadir, dataset):
fn_pickle = os.path.join(datadir, "cache.pkl")
if os.path.exists(fn):
logger.info("Loading cached dataset...")
check_torch_load_is_safe()
corpus = torch.load(fn_pickle, weights_only=True)
elif os.path.exists(fn):
logger.info("Loading cached dataset from pickle...")

View File

@ -26,6 +26,7 @@ from packaging import version
from transformers import AutoTokenizer, GPT2Config
from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
from transformers.utils import check_torch_load_is_safe
def add_checkpointing_args(parser):
@ -275,6 +276,7 @@ def merge_transformers_sharded_states(path, num_checkpoints):
state_dict = {}
for i in range(1, num_checkpoints + 1):
checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin")
check_torch_load_is_safe()
current_chunk = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
state_dict.update(current_chunk)
return state_dict
@ -298,6 +300,7 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
if os.path.isfile(checkpoint_path):
break
check_torch_load_is_safe()
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
tp_state_dicts.append(state_dict)
return tp_state_dicts
@ -338,6 +341,7 @@ def convert_checkpoint_from_megatron_to_transformers(args):
rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name)
break
print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}")
check_torch_load_is_safe()
state_dict = torch.load(rank0_checkpoint_path, map_location="cpu", weights_only=True)
megatron_args = state_dict.get("args", None)
if megatron_args is None:
@ -634,6 +638,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")]
if len(sub_dirs) == 1:
checkpoint_name = "pytorch_model.bin"
check_torch_load_is_safe()
state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu", weights_only=True)
else:
num_checkpoints = len(sub_dirs) - 1

View File

@ -41,6 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
check_torch_load_is_safe,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,
@ -4391,7 +4392,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
self.has_talker = True
def load_speakers(self, path):
for key, value in torch.load(path).items():
check_torch_load_is_safe()
for key, value in torch.load(path, weights_only=True).items():
self.speaker_map[key] = value
logger.info("Speaker {} loaded".format(list(self.speaker_map.keys())))

View File

@ -49,6 +49,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
check_torch_load_is_safe,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
@ -4078,7 +4079,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
self.has_talker = True
def load_speakers(self, path):
for key, value in torch.load(path).items():
check_torch_load_is_safe()
for key, value in torch.load(path, weights_only=True).items():
self.speaker_map[key] = value
logger.info("Speaker {} loaded".format(list(self.speaker_map.keys())))

View File

@ -45,6 +45,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
cached_file,
check_torch_load_is_safe,
is_peft_available,
is_safetensors_available,
logging,
@ -1589,6 +1590,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
cache_dir=cache_dir,
)
check_torch_load_is_safe()
state_dict = torch.load(
weight_path,
map_location="cpu",
@ -1600,6 +1602,9 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
# to the original exception.
raise
except ValueError:
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(

View File

@ -147,6 +147,7 @@ from .utils import (
PushInProgress,
PushToHubMixin,
can_return_loss,
check_torch_load_is_safe,
find_labels,
is_accelerate_available,
is_apex_available,
@ -2831,6 +2832,7 @@ class Trainer:
logger.warning(
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
)
check_torch_load_is_safe()
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict["_smp_is_partial"] = False
@ -2850,6 +2852,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
else:
check_torch_load_is_safe()
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@ -2944,6 +2947,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
check_torch_load_is_safe()
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
state_dict["_smp_is_partial"] = False
@ -2999,6 +3003,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
check_torch_load_is_safe()
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
# If the model is on the GPU, it still works!
@ -3354,6 +3359,7 @@ class Trainer:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
self.lr_scheduler.load_state_dict(
torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
)
@ -3386,6 +3392,7 @@ class Trainer:
if is_torch_xla_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
if self.is_fsdp_xla_v1_enabled:
check_torch_load_is_safe()
optimizer_state = torch.load(
os.path.join(
checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
@ -3396,10 +3403,12 @@ class Trainer:
# We only need `optimizer` when resuming from checkpoint
optimizer_state = optimizer_state["optimizer"]
else:
check_torch_load_is_safe()
optimizer_state = torch.load(
os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
)
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
lr_scheduler_state = torch.load(
os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
)
@ -3443,12 +3452,14 @@ class Trainer:
**_get_fsdp_ckpt_kwargs(),
)
else:
check_torch_load_is_safe()
self.optimizer.load_state_dict(
torch.load(
os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
)
)
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
self.lr_scheduler.load_state_dict(
torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
)
@ -3486,6 +3497,7 @@ class Trainer:
# Load in scaler states
if is_torch_xla_available():
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
scaler_state = torch.load(
os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
)
@ -3494,6 +3506,7 @@ class Trainer:
self.accelerator.scaler.load_state_dict(scaler_state)
else:
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
self.accelerator.scaler.load_state_dict(
torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
)

View File

@ -115,6 +115,7 @@ from .import_utils import (
OptionalDependencyNotAvailable,
_LazyModule,
ccl_version,
check_torch_load_is_safe,
direct_transformers_import,
get_torch_version,
is_accelerate_available,

View File

@ -1387,6 +1387,16 @@ def is_rich_available():
return _rich_available
def check_torch_load_is_safe():
if not is_torch_greater_or_equal("2.6"):
raise ValueError(
"Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users "
"to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply "
"when loading files with safetensors."
"\nSee the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434"
)
# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
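Since every other hunk relies on this gate, here is a quick way it could be exercised in a test; a sketch only, assuming `is_torch_greater_or_equal` sits next to the helper in `transformers.utils.import_utils` as the surrounding hunks suggest:

```python
import pytest
from transformers.utils import check_torch_load_is_safe

def test_gate_rejects_old_torch(monkeypatch):
    # pretend the installed torch is older than 2.6: the guard should raise
    monkeypatch.setattr(
        "transformers.utils.import_utils.is_torch_greater_or_equal",
        lambda *args, **kwargs: False,
    )
    with pytest.raises(ValueError):
        check_torch_load_is_safe()

def test_gate_is_noop_on_new_torch(monkeypatch):
    # with torch >= 2.6 reported, the guard does nothing
    monkeypatch.setattr(
        "transformers.utils.import_utils.is_torch_greater_or_equal",
        lambda *args, **kwargs: True,
    )
    check_torch_load_is_safe()
```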

View File

@ -21,6 +21,7 @@ from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -414,6 +415,7 @@ class AutoformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

View File

@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -475,6 +476,7 @@ class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

View File

@ -33,6 +33,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import check_torch_load_is_safe
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@ -366,6 +367,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
filename="llava_1_6_input_ids.pt",
repo_type="dataset",
)
check_torch_load_is_safe()
original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True)
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
# remove image token indices because HF impl expands image tokens `image_seq_length` times
@ -378,6 +380,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
filename="llava_1_6_pixel_values.pt",
repo_type="dataset",
)
check_torch_load_is_safe()
original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())

View File

@ -412,7 +412,6 @@ class OPTEmbeddingsTest(unittest.TestCase):
# verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
inputs = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False)
logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(dim=-1)
# logits_meta = torch.load(self.path_logits_meta)
logits_meta = torch.Tensor(
[
[1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],

View File

@ -27,6 +27,7 @@ from parameterized import parameterized
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -451,6 +452,7 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
# TODO: Make repo public
file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

View File

@ -23,6 +23,7 @@ from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -302,6 +303,7 @@ class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def prepare_batch(repo_id="hf-internal-testing/etth1-hourly-batch", file="train-batch.pt"):
file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

View File

@ -22,6 +22,7 @@ from parameterized import parameterized
from transformers import is_torch_available
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -480,6 +481,7 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unit
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

View File

@ -32,7 +32,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.utils import cached_property, check_torch_load_is_safe, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@ -455,6 +455,7 @@ class VideoMAEModelIntegrationTest(unittest.TestCase):
# add boolean mask, indicating which patches to mask
local_path = hf_hub_download(repo_id="hf-internal-testing/bool-masked-pos", filename="bool_masked_pos.pt")
check_torch_load_is_safe()
inputs["bool_masked_pos"] = torch.load(local_path, weights_only=True)
# forward pass

View File

@ -38,7 +38,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import is_torch_available
from transformers.utils import check_torch_load_is_safe, is_torch_available
if is_torch_available():
@ -552,6 +552,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
check_torch_load_is_safe()
dummy_state_dict = torch.load(state_dict_path, weights_only=True)
model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)
@ -577,6 +578,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
peft_config = LoraConfig()
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
check_torch_load_is_safe()
dummy_state_dict = torch.load(state_dict_path, weights_only=True)
# this should always work
@ -645,6 +647,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
peft_config = LoraConfig()
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
check_torch_load_is_safe()
dummy_state_dict = torch.load(state_dict_path, weights_only=True)
# add unexpected key
@ -672,6 +675,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
peft_config = LoraConfig()
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
check_torch_load_is_safe()
dummy_state_dict = torch.load(state_dict_path, weights_only=True)
# remove a key so that we have missing keys

View File

@ -113,6 +113,7 @@ from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
check_torch_load_is_safe,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
@ -646,6 +647,7 @@ class TrainerIntegrationCommon:
else:
best_model = RegressionModel()
if not safe_weights:
check_torch_load_is_safe()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME), weights_only=True)
else:
state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))
@ -678,6 +680,7 @@ class TrainerIntegrationCommon:
loader = safetensors.torch.load_file
weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME)
else:
check_torch_load_is_safe()
loader = torch.load
weights_file = os.path.join(folder, WEIGHTS_NAME)

View File

@ -74,6 +74,7 @@ from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
check_torch_load_is_safe,
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,
@ -739,6 +740,7 @@ class ModelUtilsTest(TestCasePlus):
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
check_torch_load_is_safe()
state_dict = torch.load(shard_file, weights_only=True)
self.assertEqual(len(state_dict), 1)