diff --git a/src/transformers/data/datasets/glue.py b/src/transformers/data/datasets/glue.py
index 43a1b75e518..5541e5927a3 100644
--- a/src/transformers/data/datasets/glue.py
+++ b/src/transformers/data/datasets/glue.py
@@ -24,7 +24,7 @@ from filelock import FileLock
 from torch.utils.data import Dataset
 
 from ...tokenization_utils_base import PreTrainedTokenizerBase
-from ...utils import logging
+from ...utils import check_torch_load_is_safe, logging
 from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
 from ..processors.utils import InputFeatures
 
@@ -122,6 +122,7 @@ class GlueDataset(Dataset):
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not args.overwrite_cache:
                 start = time.time()
+                check_torch_load_is_safe()
                 self.features = torch.load(cached_features_file, weights_only=True)
                 logger.info(
                     f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
diff --git a/src/transformers/data/datasets/squad.py b/src/transformers/data/datasets/squad.py
index 7546d7b49ed..a84f83d8f84 100644
--- a/src/transformers/data/datasets/squad.py
+++ b/src/transformers/data/datasets/squad.py
@@ -24,7 +24,7 @@ from torch.utils.data import Dataset
 
 from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from ...tokenization_utils import PreTrainedTokenizer
-from ...utils import logging
+from ...utils import check_torch_load_is_safe, logging
 from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
 
 
@@ -148,6 +148,7 @@ class SquadDataset(Dataset):
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not args.overwrite_cache:
                 start = time.time()
+                check_torch_load_is_safe()
                 self.old_features = torch.load(cached_features_file, weights_only=True)
 
                 # Legacy cache files have only features, while new cache files
diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py
index 07285065772..1639ce9094d 100644
--- a/src/transformers/modeling_flax_pytorch_utils.py
+++ b/src/transformers/modeling_flax_pytorch_utils.py
@@ -27,7 +27,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict
 import transformers
 
 from . import is_safetensors_available, is_torch_available
-from .utils import logging
+from .utils import check_torch_load_is_safe, logging
 
 
 if is_torch_available():
@@ -71,6 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
                 )
                 raise
 
+            check_torch_load_is_safe()
             pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
             logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
 
@@ -247,6 +248,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
+        check_torch_load_is_safe()
         pt_state_dict = torch.load(shard_file, weights_only=True)
         weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
         pt_state_dict = {
diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py
index 84a6ddaebcc..51c21bb7fa4 100644
--- a/src/transformers/modeling_tf_pytorch_utils.py
+++ b/src/transformers/modeling_tf_pytorch_utils.py
@@ -21,6 +21,7 @@ import numpy
 
 from .utils import (
     ExplicitEnum,
+    check_torch_load_is_safe,
     expand_dims,
     is_numpy_array,
     is_safetensors_available,
@@ -198,6 +199,7 @@ def load_pytorch_checkpoint_in_tf2_model(
         if pt_path.endswith(".safetensors"):
             state_dict = safe_load_file(pt_path)
         else:
+            check_torch_load_is_safe()
             state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
 
         pt_state_dict.update(state_dict)
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 5fbf51c2970..0ee4182f98c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -94,6 +94,7 @@ from .utils import (
     ModelOutput,
     PushToHubMixin,
     cached_file,
+    check_torch_load_is_safe,
     copy_func,
     download_url,
     extract_commit_hash,
@@ -445,7 +446,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)
 
-    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
+    if load_safe:
+        loader = safe_load_file
+    else:
+        check_torch_load_is_safe()
+        loader = partial(torch.load, map_location="cpu", weights_only=True)
 
     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
@@ -490,6 +495,7 @@ def load_state_dict(
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
""" + # Use safetensors if possible if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): with safe_open(checkpoint_file, framework="pt") as f: metadata = f.metadata() @@ -512,6 +518,9 @@ def load_state_dict( state_dict[k] = f.get_tensor(k) return state_dict + # Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe) + if weights_only: + check_torch_load_is_safe() try: if map_location is None: if ( diff --git a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py index f9a5fb7b346..d9a118f54d4 100644 --- a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py @@ -29,6 +29,7 @@ import numpy as np from ....tokenization_utils import PreTrainedTokenizer from ....utils import ( cached_file, + check_torch_load_is_safe, is_sacremoses_available, is_torch_available, logging, @@ -222,6 +223,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer): "from a PyTorch pretrained vocabulary, " "or activate it with environment variables USE_TORCH=1 and USE_TF=0." ) + check_torch_load_is_safe() vocab_dict = torch.load(pretrained_vocab_file, weights_only=True) if vocab_dict is not None: @@ -705,6 +707,7 @@ class TransfoXLCorpus: # Instantiate tokenizer. corpus = cls(*inputs, **kwargs) + check_torch_load_is_safe() corpus_dict = torch.load(resolved_corpus_file, weights_only=True) for key, value in corpus_dict.items(): corpus.__dict__[key] = value @@ -784,6 +787,7 @@ def get_lm_corpus(datadir, dataset): fn_pickle = os.path.join(datadir, "cache.pkl") if os.path.exists(fn): logger.info("Loading cached dataset...") + check_torch_load_is_safe() corpus = torch.load(fn_pickle, weights_only=True) elif os.path.exists(fn): logger.info("Loading cached dataset from pickle...") diff --git a/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py b/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py index 548e2d1aeb3..68e2404475e 100644 --- a/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py +++ b/src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py @@ -26,6 +26,7 @@ from packaging import version from transformers import AutoTokenizer, GPT2Config from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME +from transformers.utils import check_torch_load_is_safe def add_checkpointing_args(parser): @@ -275,6 +276,7 @@ def merge_transformers_sharded_states(path, num_checkpoints): state_dict = {} for i in range(1, num_checkpoints + 1): checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin") + check_torch_load_is_safe() current_chunk = torch.load(checkpoint_path, map_location="cpu", weights_only=True) state_dict.update(current_chunk) return state_dict @@ -298,6 +300,7 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank): checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name) if os.path.isfile(checkpoint_path): break + check_torch_load_is_safe() state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) tp_state_dicts.append(state_dict) return tp_state_dicts @@ -338,6 +341,7 @@ def convert_checkpoint_from_megatron_to_transformers(args): rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name) break print(f"Loading 
Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}") + check_torch_load_is_safe() state_dict = torch.load(rank0_checkpoint_path, map_location="cpu", weights_only=True) megatron_args = state_dict.get("args", None) if megatron_args is None: @@ -634,6 +638,7 @@ def convert_checkpoint_from_transformers_to_megatron(args): sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")] if len(sub_dirs) == 1: checkpoint_name = "pytorch_model.bin" + check_torch_load_is_safe() state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu", weights_only=True) else: num_checkpoints = len(sub_dirs) - 1 diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 17551dc5aeb..19c88c047a9 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -41,6 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + check_torch_load_is_safe, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, @@ -4391,7 +4392,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation self.has_talker = True def load_speakers(self, path): - for key, value in torch.load(path).items(): + check_torch_load_is_safe() + for key, value in torch.load(path, weights_only=True).items(): self.speaker_map[key] = value logger.info("Speaker {} loaded".format(list(self.speaker_map.keys()))) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 9d3f2d9ec6a..2123be29031 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -49,6 +49,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, + check_torch_load_is_safe, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, @@ -4078,7 +4079,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation self.has_talker = True def load_speakers(self, path): - for key, value in torch.load(path).items(): + check_torch_load_is_safe() + for key, value in torch.load(path, weights_only=True).items(): self.speaker_map[key] = value logger.info("Speaker {} loaded".format(list(self.speaker_map.keys()))) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 46c5dd790a3..f3708eaa2d5 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -45,6 +45,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, cached_file, + check_torch_load_is_safe, is_peft_available, is_safetensors_available, logging, @@ -1589,6 +1590,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): cache_dir=cache_dir, ) + check_torch_load_is_safe() state_dict = torch.load( weight_path, map_location="cpu", @@ -1600,6 +1602,9 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): # to the original exception. raise + except ValueError: + raise + except Exception: # For any other exception, we throw a generic error. 
                 raise EnvironmentError(
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index fc82d0b2557..90957024bc6 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -147,6 +147,7 @@ from .utils import (
     PushInProgress,
     PushToHubMixin,
     can_return_loss,
+    check_torch_load_is_safe,
     find_labels,
     is_accelerate_available,
     is_apex_available,
@@ -2831,6 +2832,7 @@ class Trainer:
                     logger.warning(
                         "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
                     )
+                check_torch_load_is_safe()
                 state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
                 # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                 state_dict["_smp_is_partial"] = False
@@ -2850,6 +2852,7 @@ class Trainer:
                 if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                     state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                 else:
+                    check_torch_load_is_safe()
                     state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
 
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@@ -2944,6 +2947,7 @@ class Trainer:
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
+                    check_torch_load_is_safe()
                     state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
 
                 state_dict["_smp_is_partial"] = False
@@ -2999,6 +3003,7 @@ class Trainer:
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
+                    check_torch_load_is_safe()
                     state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
 
                 # If the model is on the GPU, it still works!
@@ -3354,6 +3359,7 @@ class Trainer:
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
                 with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
                     self.lr_scheduler.load_state_dict(
                         torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
                     )
@@ -3386,6 +3392,7 @@ class Trainer:
         if is_torch_xla_available():
             # On TPU we have to take some extra precautions to properly load the states on the right device.
             if self.is_fsdp_xla_v1_enabled:
+                check_torch_load_is_safe()
                 optimizer_state = torch.load(
                     os.path.join(
                         checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
@@ -3396,10 +3403,12 @@ class Trainer:
                 # We only need `optimizer` when resuming from checkpoint
                 optimizer_state = optimizer_state["optimizer"]
             else:
+                check_torch_load_is_safe()
                 optimizer_state = torch.load(
                     os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
                 )
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 lr_scheduler_state = torch.load(
                     os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
                 )
@@ -3443,12 +3452,14 @@ class Trainer:
                     **_get_fsdp_ckpt_kwargs(),
                 )
             else:
+                check_torch_load_is_safe()
                 self.optimizer.load_state_dict(
                     torch.load(
                         os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
                     )
                 )
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 self.lr_scheduler.load_state_dict(
                     torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
                 )
@@ -3486,6 +3497,7 @@ class Trainer:
         # Load in scaler states
         if is_torch_xla_available():
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 scaler_state = torch.load(
                     os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
                 )
@@ -3494,6 +3506,7 @@ class Trainer:
                 self.accelerator.scaler.load_state_dict(scaler_state)
         else:
             with warnings.catch_warnings(record=True) as caught_warnings:
+                check_torch_load_is_safe()
                 self.accelerator.scaler.load_state_dict(
                     torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
                 )
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 0b3e75cc5ab..50b1a57c3b8 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -115,6 +115,7 @@ from .import_utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
     ccl_version,
+    check_torch_load_is_safe,
     direct_transformers_import,
     get_torch_version,
     is_accelerate_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 7291b8f98a0..905781fa81a 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1387,6 +1387,16 @@ def is_rich_available():
     return _rich_available
 
 
+def check_torch_load_is_safe():
+    if not is_torch_greater_or_equal("2.6"):
+        raise ValueError(
+            "Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users "
+            "to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply "
+            "when loading files with safetensors."
+            "\nSee the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434"
+        )
+
+
 # docstyle-ignore
 AV_IMPORT_ERROR = """
 {0} requires the PyAv library but it was not found in your environment. You can install it with:
diff --git a/tests/models/autoformer/test_modeling_autoformer.py b/tests/models/autoformer/test_modeling_autoformer.py
index 578977c00cd..3c1435f33ee 100644
--- a/tests/models/autoformer/test_modeling_autoformer.py
+++ b/tests/models/autoformer/test_modeling_autoformer.py
@@ -21,6 +21,7 @@ from huggingface_hub import hf_hub_download
 
 from transformers import is_torch_available
 from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.utils import check_torch_load_is_safe
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -414,6 +415,7 @@ class AutoformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
 
 def prepare_batch(filename="train-batch.pt"):
     file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
+    check_torch_load_is_safe()
     batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch
 
diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py
index 49408197715..cb197631975 100644
--- a/tests/models/informer/test_modeling_informer.py
+++ b/tests/models/informer/test_modeling_informer.py
@@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download
 
 from transformers import is_torch_available
 from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.utils import check_torch_load_is_safe
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -475,6 +476,7 @@ class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
 
 def prepare_batch(filename="train-batch.pt"):
     file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
+    check_torch_load_is_safe()
     batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch
 
diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index 3b9fc36521e..669a23e109c 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -33,6 +33,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils import check_torch_load_is_safe
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -366,6 +367,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
             filename="llava_1_6_input_ids.pt",
             repo_type="dataset",
         )
+        check_torch_load_is_safe()
         original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True)
         # replace -200 by image_token_index (since we use token ID = 32000 for the image token)
         # remove image token indices because HF impl expands image tokens `image_seq_length` times
@@ -378,6 +380,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
             filename="llava_1_6_pixel_values.pt",
             repo_type="dataset",
         )
+        check_torch_load_is_safe()
         original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
         assert torch.allclose(original_pixel_values, inputs.pixel_values.half())
 
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
index 15e8968cfb4..79791e71513 100644
--- a/tests/models/opt/test_modeling_opt.py
+++ b/tests/models/opt/test_modeling_opt.py
@@ -412,7 +412,6 @@ class OPTEmbeddingsTest(unittest.TestCase):
         # verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
         inputs = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False)
         logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(dim=-1)
-        # logits_meta = torch.load(self.path_logits_meta)
         logits_meta = torch.Tensor(
             [
                 [1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],
diff --git a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py
index b5b92f4bee0..5c5ff131533 100644
--- a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py
+++ b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py
@@ -27,6 +27,7 @@ from parameterized import parameterized
 from transformers import is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.utils import check_torch_load_is_safe
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -451,6 +452,7 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
 def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
     # TODO: Make repo public
     file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
+    check_torch_load_is_safe()
     batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch
 
diff --git a/tests/models/patchtst/test_modeling_patchtst.py b/tests/models/patchtst/test_modeling_patchtst.py
index 77a2ce9addc..8b64da124e5 100644
--- a/tests/models/patchtst/test_modeling_patchtst.py
+++ b/tests/models/patchtst/test_modeling_patchtst.py
@@ -23,6 +23,7 @@ from huggingface_hub import hf_hub_download
 from transformers import is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.utils import check_torch_load_is_safe
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -302,6 +303,7 @@ class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
 
 def prepare_batch(repo_id="hf-internal-testing/etth1-hourly-batch", file="train-batch.pt"):
     file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
+    check_torch_load_is_safe()
     batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch
 
diff --git a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py
index dab8bb6c914..42a663e744e 100644
--- a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py
+++ b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py
@@ -22,6 +22,7 @@ from parameterized import parameterized
 
 from transformers import is_torch_available
 from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.utils import check_torch_load_is_safe
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -480,6 +481,7 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unit
 
 def prepare_batch(filename="train-batch.pt"):
     file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
+    check_torch_load_is_safe()
     batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch
 
diff --git a/tests/models/videomae/test_modeling_videomae.py b/tests/models/videomae/test_modeling_videomae.py
index 6389fbd6774..1b8b08d14b6 100644
--- a/tests/models/videomae/test_modeling_videomae.py
+++ b/tests/models/videomae/test_modeling_videomae.py
@@ -32,7 +32,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import cached_property, is_torch_available, is_vision_available
+from transformers.utils import cached_property, check_torch_load_is_safe, is_torch_available, is_vision_available
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -455,6 +455,7 @@ class VideoMAEModelIntegrationTest(unittest.TestCase):
 
         # add boolean mask, indicating which patches to mask
         local_path = hf_hub_download(repo_id="hf-internal-testing/bool-masked-pos", filename="bool_masked_pos.pt")
+        check_torch_load_is_safe()
         inputs["bool_masked_pos"] = torch.load(local_path, weights_only=True)
 
         # forward pass
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
index 721ec7ea3f6..4a9132fcf6a 100644
--- a/tests/peft_integration/test_peft_integration.py
+++ b/tests/peft_integration/test_peft_integration.py
@@ -38,7 +38,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import is_torch_available
+from transformers.utils import check_torch_load_is_safe, is_torch_available
 
 
 if is_torch_available():
@@ -552,6 +552,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
 
             state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
 
+            check_torch_load_is_safe()
             dummy_state_dict = torch.load(state_dict_path, weights_only=True)
 
             model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)
@@ -577,6 +578,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
             peft_config = LoraConfig()
 
             state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+            check_torch_load_is_safe()
             dummy_state_dict = torch.load(state_dict_path, weights_only=True)
 
             # this should always work
@@ -645,6 +647,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
             peft_config = LoraConfig()
 
             state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+            check_torch_load_is_safe()
             dummy_state_dict = torch.load(state_dict_path, weights_only=True)
 
             # add unexpected key
@@ -672,6 +675,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
             peft_config = LoraConfig()
 
             state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+            check_torch_load_is_safe()
             dummy_state_dict = torch.load(state_dict_path, weights_only=True)
 
             # remove a key so that we have missing keys
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 8e1a1c931e1..d9ee30f7cae 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -113,6 +113,7 @@ from transformers.utils import (
     SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
+    check_torch_load_is_safe,
     is_accelerate_available,
     is_apex_available,
     is_bitsandbytes_available,
@@ -646,6 +647,7 @@ class TrainerIntegrationCommon:
         else:
             best_model = RegressionModel()
         if not safe_weights:
+            check_torch_load_is_safe()
             state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME), weights_only=True)
         else:
             state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))
@@ -678,6 +680,7 @@ class TrainerIntegrationCommon:
             loader = safetensors.torch.load_file
             weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME)
         else:
+            check_torch_load_is_safe()
             loader = torch.load
             weights_file = os.path.join(folder, WEIGHTS_NAME)
 
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index b96678f114d..77d87dc3546 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -74,6 +74,7 @@ from transformers.utils import (
     SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
+    check_torch_load_is_safe,
 )
 from transformers.utils.import_utils import (
     is_flash_attn_2_available,
@@ -739,6 +740,7 @@ class ModelUtilsTest(TestCasePlus):
                 # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
                 # the size asked for (since we count parameters)
                 if size >= max_size_int + 50000:
+                    check_torch_load_is_safe()
                     state_dict = torch.load(shard_file, weights_only=True)
                     self.assertEqual(len(state_dict), 1)
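
For reference, a minimal sketch of the call-site pattern this patch applies, assuming torch and transformers are installed; it is illustrative only and not part of the diff, and the helper name `load_legacy_checkpoint` and its `path` argument are placeholders:

import torch

from transformers.utils import check_torch_load_is_safe


def load_legacy_checkpoint(path):
    # Raises a ValueError on torch < 2.6 (CVE-2025-32434); files loaded via safetensors do not need this check.
    check_torch_load_is_safe()
    return torch.load(path, map_location="cpu", weights_only=True)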