Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Force torch>=2.6 with torch.load to avoid vulnerability issue (#37785)
* fix all main files
* fix test files
* oops, forgot modular
* add link
* update message
This commit is contained in:
parent eefc86aa31
commit 0cfbf9c95b
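The pattern this commit enforces is uniform: every direct `torch.load()` call is preceded by `check_torch_load_is_safe()`, a new helper (defined near the end of this diff and re-exported from `transformers.utils`) that raises unless torch >= 2.6 is installed, and `weights_only=True` is passed explicitly. A minimal sketch of the calling convention, with a hypothetical local checkpoint path:

    import torch
    from transformers.utils import check_torch_load_is_safe

    # Raises a ValueError referencing CVE-2025-32434 when torch < 2.6 is installed.
    check_torch_load_is_safe()
    # weights_only=True restricts unpickling to tensors and allow-listed types.
    state_dict = torch.load("checkpoint.bin", map_location="cpu", weights_only=True)

Safetensors files are unaffected; the guard only applies to the pickle-based `.bin` path.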
@@ -24,7 +24,7 @@ from filelock import FileLock
from torch.utils.data import Dataset
from ...tokenization_utils_base import PreTrainedTokenizerBase
from ...utils import logging
from ...utils import check_torch_load_is_safe, logging
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures

@@ -122,6 +122,7 @@ class GlueDataset(Dataset):
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
check_torch_load_is_safe()
self.features = torch.load(cached_features_file, weights_only=True)
logger.info(
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start

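The dataset-cache hunks above follow a load-or-rebuild pattern behind a file lock; the new line only gates the cached-features load. A trimmed, self-contained sketch of that pattern (the cache path and the feature-building step are stand-ins, not the library's code):

    import os
    import time

    import torch
    from filelock import FileLock
    from transformers.utils import check_torch_load_is_safe

    def build_features():
        # Stand-in for the real featurization step.
        return torch.arange(10)

    cached_features_file = "cache/features.pt"  # hypothetical cache location
    lock_path = cached_features_file + ".lock"
    overwrite_cache = False

    os.makedirs("cache", exist_ok=True)
    with FileLock(lock_path):
        if os.path.exists(cached_features_file) and not overwrite_cache:
            start = time.time()
            check_torch_load_is_safe()  # refuse to unpickle on torch < 2.6
            features = torch.load(cached_features_file, weights_only=True)
            print(f"Loaded cached features [took {time.time() - start:.3f} s]")
        else:
            features = build_features()
            torch.save(features, cached_features_file)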
@@ -24,7 +24,7 @@ from torch.utils.data import Dataset
from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
from ...utils import check_torch_load_is_safe, logging
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features

@@ -148,6 +148,7 @@ class SquadDataset(Dataset):
with FileLock(lock_path):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
start = time.time()
check_torch_load_is_safe()
self.old_features = torch.load(cached_features_file, weights_only=True)

# Legacy cache files have only features, while new cache files

@@ -27,7 +27,7 @@ from flax.traverse_util import flatten_dict, unflatten_dict
import transformers

from . import is_safetensors_available, is_torch_available
from .utils import logging
from .utils import check_torch_load_is_safe, logging


if is_torch_available():

@@ -71,6 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
)
raise

check_torch_load_is_safe()
pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

@@ -247,6 +248,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
flax_state_dict = {}
for shard_file in shard_filenames:
# load using msgpack utils
check_torch_load_is_safe()
pt_state_dict = torch.load(shard_file, weights_only=True)
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {

@@ -21,6 +21,7 @@ import numpy
from .utils import (
ExplicitEnum,
check_torch_load_is_safe,
expand_dims,
is_numpy_array,
is_safetensors_available,

@@ -198,6 +199,7 @@ def load_pytorch_checkpoint_in_tf2_model(
if pt_path.endswith(".safetensors"):
state_dict = safe_load_file(pt_path)
else:
check_torch_load_is_safe()
state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)

pt_state_dict.update(state_dict)

@@ -94,6 +94,7 @@ from .utils import (
ModelOutput,
PushToHubMixin,
cached_file,
check_torch_load_is_safe,
copy_func,
download_url,
extract_commit_hash,

@@ -445,7 +446,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)

loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)
if load_safe:
loader = safe_load_file
else:
check_torch_load_is_safe()
loader = partial(torch.load, map_location="cpu", weights_only=True)

for shard_file in shard_files:
state_dict = loader(os.path.join(folder, shard_file))

@@ -490,6 +495,7 @@ def load_state_dict(
"""
Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
"""
# Use safetensors if possible
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()

@@ -512,6 +518,9 @@ def load_state_dict(
state_dict[k] = f.get_tensor(k)
return state_dict

# Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe)
if weights_only:
check_torch_load_is_safe()
try:
if map_location is None:
if (

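The `load_sharded_checkpoint` hunk above replaces the one-line loader ternary with an explicit branch so the version check only runs when the pickle loader is actually chosen. A rough standalone sketch of that selection logic (folder and shard names are illustrative, not the library's internals):

    import os
    from functools import partial

    import torch
    from safetensors.torch import load_file as safe_load_file
    from transformers.utils import check_torch_load_is_safe

    def pick_loader(load_safe: bool):
        # Prefer safetensors when the shards exist in that format.
        if load_safe:
            return safe_load_file
        # Only the pickle path needs the torch>=2.6 guard.
        check_torch_load_is_safe()
        return partial(torch.load, map_location="cpu", weights_only=True)

    # Example usage over a hypothetical checkpoint folder:
    # loader = pick_loader(load_safe=False)
    # for shard_file in ["pytorch_model-00001-of-00002.bin"]:
    #     state_dict = loader(os.path.join("my-checkpoint", shard_file))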
@@ -29,6 +29,7 @@ import numpy as np
from ....tokenization_utils import PreTrainedTokenizer
from ....utils import (
cached_file,
check_torch_load_is_safe,
is_sacremoses_available,
is_torch_available,
logging,

@@ -222,6 +223,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
"from a PyTorch pretrained vocabulary, "
"or activate it with environment variables USE_TORCH=1 and USE_TF=0."
)
check_torch_load_is_safe()
vocab_dict = torch.load(pretrained_vocab_file, weights_only=True)

if vocab_dict is not None:

@@ -705,6 +707,7 @@ class TransfoXLCorpus:
# Instantiate tokenizer.
corpus = cls(*inputs, **kwargs)
check_torch_load_is_safe()
corpus_dict = torch.load(resolved_corpus_file, weights_only=True)
for key, value in corpus_dict.items():
corpus.__dict__[key] = value

@@ -784,6 +787,7 @@ def get_lm_corpus(datadir, dataset):
fn_pickle = os.path.join(datadir, "cache.pkl")
if os.path.exists(fn):
logger.info("Loading cached dataset...")
check_torch_load_is_safe()
corpus = torch.load(fn_pickle, weights_only=True)
elif os.path.exists(fn):
logger.info("Loading cached dataset from pickle...")

@@ -26,6 +26,7 @@ from packaging import version
from transformers import AutoTokenizer, GPT2Config
from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME
from transformers.utils import check_torch_load_is_safe


def add_checkpointing_args(parser):

@@ -275,6 +276,7 @@ def merge_transformers_sharded_states(path, num_checkpoints):
state_dict = {}
for i in range(1, num_checkpoints + 1):
checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin")
check_torch_load_is_safe()
current_chunk = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
state_dict.update(current_chunk)
return state_dict

@@ -298,6 +300,7 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
if os.path.isfile(checkpoint_path):
break
check_torch_load_is_safe()
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
tp_state_dicts.append(state_dict)
return tp_state_dicts

@@ -338,6 +341,7 @@ def convert_checkpoint_from_megatron_to_transformers(args):
rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name)
break
print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}")
check_torch_load_is_safe()
state_dict = torch.load(rank0_checkpoint_path, map_location="cpu", weights_only=True)
megatron_args = state_dict.get("args", None)
if megatron_args is None:

@@ -634,6 +638,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
sub_dirs = [x for x in os.listdir(args.load_path) if x.startswith("pytorch_model")]
if len(sub_dirs) == 1:
checkpoint_name = "pytorch_model.bin"
check_torch_load_is_safe()
state_dict = torch.load(os.path.join(args.load_path, checkpoint_name), map_location="cpu", weights_only=True)
else:
num_checkpoints = len(sub_dirs) - 1

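`merge_transformers_sharded_states` above simply accumulates each shard into a single dict, with the guard placed before each pickle load. A self-contained approximation, assuming the `pytorch_model-XXXXX-of-XXXXX.bin` naming the script already uses:

    import os

    import torch
    from transformers.utils import check_torch_load_is_safe

    def merge_sharded_states(path: str, num_checkpoints: int) -> dict:
        # Merge sharded .bin checkpoints into one state dict (sketch of the script's helper).
        state_dict = {}
        for i in range(1, num_checkpoints + 1):
            shard = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin")
            check_torch_load_is_safe()  # cheap, so calling it per shard is harmless
            state_dict.update(torch.load(shard, map_location="cpu", weights_only=True))
        return state_dict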
@@ -41,6 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
check_torch_load_is_safe,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_flex_attn_available,

@@ -4391,7 +4392,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
self.has_talker = True

def load_speakers(self, path):
for key, value in torch.load(path).items():
check_torch_load_is_safe()
for key, value in torch.load(path, weights_only=True).items():
self.speaker_map[key] = value
logger.info("Speaker {} loaded".format(list(self.speaker_map.keys())))

@@ -49,6 +49,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
check_torch_load_is_safe,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,

@@ -4078,7 +4079,8 @@ class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, Generation
self.has_talker = True

def load_speakers(self, path):
for key, value in torch.load(path).items():
check_torch_load_is_safe()
for key, value in torch.load(path, weights_only=True).items():
self.speaker_map[key] = value
logger.info("Speaker {} loaded".format(list(self.speaker_map.keys())))

@@ -45,6 +45,7 @@ from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
cached_file,
check_torch_load_is_safe,
is_peft_available,
is_safetensors_available,
logging,

@@ -1589,6 +1590,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
cache_dir=cache_dir,
)

check_torch_load_is_safe()
state_dict = torch.load(
weight_path,
map_location="cpu",

@@ -1600,6 +1602,9 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
# to the original exception.
raise

except ValueError:
raise

except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(

@@ -147,6 +147,7 @@ from .utils import (
PushInProgress,
PushToHubMixin,
can_return_loss,
check_torch_load_is_safe,
find_labels,
is_accelerate_available,
is_apex_available,

@@ -2831,6 +2832,7 @@ class Trainer:
logger.warning(
"Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
)
check_torch_load_is_safe()
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
# Required for smp to not auto-translate state_dict from hf to smp (is already smp).
state_dict["_smp_is_partial"] = False

@@ -2850,6 +2852,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(safe_weights_file):
state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
else:
check_torch_load_is_safe()
state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)

# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963

@@ -2944,6 +2947,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
check_torch_load_is_safe()
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

state_dict["_smp_is_partial"] = False

@@ -2999,6 +3003,7 @@ class Trainer:
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
check_torch_load_is_safe()
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

# If the model is on the GPU, it still works!

@@ -3354,6 +3359,7 @@ class Trainer:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
self.lr_scheduler.load_state_dict(
torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
)

@@ -3386,6 +3392,7 @@ class Trainer:
if is_torch_xla_available():
# On TPU we have to take some extra precautions to properly load the states on the right device.
if self.is_fsdp_xla_v1_enabled:
check_torch_load_is_safe()
optimizer_state = torch.load(
os.path.join(
checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"

@@ -3396,10 +3403,12 @@ class Trainer:
# We only need `optimizer` when resuming from checkpoint
optimizer_state = optimizer_state["optimizer"]
else:
check_torch_load_is_safe()
optimizer_state = torch.load(
os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
)
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
lr_scheduler_state = torch.load(
os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
)

@@ -3443,12 +3452,14 @@ class Trainer:
**_get_fsdp_ckpt_kwargs(),
)
else:
check_torch_load_is_safe()
self.optimizer.load_state_dict(
torch.load(
os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
)
)
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
self.lr_scheduler.load_state_dict(
torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
)

@@ -3486,6 +3497,7 @@ class Trainer:
# Load in scaler states
if is_torch_xla_available():
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
scaler_state = torch.load(
os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
)

@@ -3494,6 +3506,7 @@ class Trainer:
self.accelerator.scaler.load_state_dict(scaler_state)
else:
with warnings.catch_warnings(record=True) as caught_warnings:
check_torch_load_is_safe()
self.accelerator.scaler.load_state_dict(
torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
)

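In the `Trainer` hunks the same guard is added wherever optimizer, scheduler, and grad-scaler states are read back during a resume. A condensed sketch of that resume path, assuming the usual `optimizer.pt` / `scheduler.pt` file names (the real code uses the `OPTIMIZER_NAME` / `SCHEDULER_NAME` constants and has extra branches for XLA, FSDP, and DeepSpeed):

    import os
    import warnings

    import torch
    from transformers.utils import check_torch_load_is_safe

    def load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint: str):
        # Guard the pickle-based loads; raises on torch < 2.6.
        check_torch_load_is_safe()
        optimizer_state = torch.load(
            os.path.join(checkpoint, "optimizer.pt"), map_location="cpu", weights_only=True
        )
        optimizer.load_state_dict(optimizer_state)
        # Scheduler reloads can emit warnings on some torch versions; the Trainer records them.
        with warnings.catch_warnings(record=True):
            scheduler_state = torch.load(
                os.path.join(checkpoint, "scheduler.pt"), map_location="cpu", weights_only=True
            )
        lr_scheduler.load_state_dict(scheduler_state)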
@@ -115,6 +115,7 @@ from .import_utils import (
OptionalDependencyNotAvailable,
_LazyModule,
ccl_version,
check_torch_load_is_safe,
direct_transformers_import,
get_torch_version,
is_accelerate_available,

@@ -1387,6 +1387,16 @@ def is_rich_available():
return _rich_available


def check_torch_load_is_safe():
if not is_torch_greater_or_equal("2.6"):
raise ValueError(
"Due to a serious vulnerability issue in `torch.load`, even with `weights_only=True`, we now require users "
"to upgrade torch to at least v2.6 in order to use the function. This version restriction does not apply "
"when loading files with safetensors."
"\nSee the vulnerability report here https://nvd.nist.gov/vuln/detail/CVE-2025-32434"
)


# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:

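For code outside transformers that wants the same guarantee, the gate boils down to a version comparison against torch 2.6. A rough standalone equivalent of the helper above (the library itself uses its internal `is_torch_greater_or_equal` utility rather than `packaging` directly):

    import torch
    from packaging import version

    def torch_load_is_safe() -> bool:
        # CVE-2025-32434 is fixed in torch 2.6; older torch.load is unsafe even
        # with weights_only=True.
        return version.parse(torch.__version__) >= version.parse("2.6")

    if not torch_load_is_safe():
        raise ValueError(
            "torch >= 2.6 is required to call torch.load safely; "
            "see https://nvd.nist.gov/vuln/detail/CVE-2025-32434"
        )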
@@ -21,6 +21,7 @@ from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@@ -414,6 +415,7 @@ class AutoformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

@@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@@ -475,6 +476,7 @@ class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

@@ -33,6 +33,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import check_torch_load_is_safe

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester

@@ -366,6 +367,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
filename="llava_1_6_input_ids.pt",
repo_type="dataset",
)
check_torch_load_is_safe()
original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True)
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
# remove image token indices because HF impl expands image tokens `image_seq_length` times

@@ -378,6 +380,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
filename="llava_1_6_pixel_values.pt",
repo_type="dataset",
)
check_torch_load_is_safe()
original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
assert torch.allclose(original_pixel_values, inputs.pixel_values.half())

@@ -412,7 +412,6 @@ class OPTEmbeddingsTest(unittest.TestCase):
# verify that prompt without BOS token is identical to Metaseq -> add_special_tokens=False
inputs = tokenizer(prompts, return_tensors="pt", padding=True, add_special_tokens=False)
logits = model(inputs.input_ids, attention_mask=inputs.attention_mask)[0].mean(dim=-1)
# logits_meta = torch.load(self.path_logits_meta)
logits_meta = torch.Tensor(
[
[1.3851, -13.8923, -10.5229, -10.7533, -0.2309, -10.2384, -0.5365, -9.0947, -5.1670],

@@ -27,6 +27,7 @@ from parameterized import parameterized
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@@ -451,6 +452,7 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
# TODO: Make repo public
file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

@@ -23,6 +23,7 @@ from huggingface_hub import hf_hub_download
from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@@ -302,6 +303,7 @@ class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def prepare_batch(repo_id="hf-internal-testing/etth1-hourly-batch", file="train-batch.pt"):
file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

@@ -22,6 +22,7 @@ from parameterized import parameterized
from transformers import is_torch_available
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.utils import check_torch_load_is_safe

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@@ -480,6 +481,7 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unit
def prepare_batch(filename="train-batch.pt"):
file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
check_torch_load_is_safe()
batch = torch.load(file, map_location=torch_device, weights_only=True)
return batch

@@ -32,7 +32,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.utils import cached_property, check_torch_load_is_safe, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor

@@ -455,6 +455,7 @@ class VideoMAEModelIntegrationTest(unittest.TestCase):
# add boolean mask, indicating which patches to mask
local_path = hf_hub_download(repo_id="hf-internal-testing/bool-masked-pos", filename="bool_masked_pos.pt")
check_torch_load_is_safe()
inputs["bool_masked_pos"] = torch.load(local_path, weights_only=True)

# forward pass

@@ -38,7 +38,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import is_torch_available
from transformers.utils import check_torch_load_is_safe, is_torch_available


if is_torch_available():

@@ -552,6 +552,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")

check_torch_load_is_safe()
dummy_state_dict = torch.load(state_dict_path, weights_only=True)

model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)

@@ -577,6 +578,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
peft_config = LoraConfig()
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
check_torch_load_is_safe()
dummy_state_dict = torch.load(state_dict_path, weights_only=True)

# this should always work

@@ -645,6 +647,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
peft_config = LoraConfig()
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
check_torch_load_is_safe()
dummy_state_dict = torch.load(state_dict_path, weights_only=True)

# add unexpected key

@@ -672,6 +675,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
peft_config = LoraConfig()
state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
check_torch_load_is_safe()
dummy_state_dict = torch.load(state_dict_path, weights_only=True)

# remove a key so that we have missing keys

@@ -113,6 +113,7 @@ from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
check_torch_load_is_safe,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,

@@ -646,6 +647,7 @@ class TrainerIntegrationCommon:
else:
best_model = RegressionModel()
if not safe_weights:
check_torch_load_is_safe()
state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME), weights_only=True)
else:
state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))

@@ -678,6 +680,7 @@ class TrainerIntegrationCommon:
loader = safetensors.torch.load_file
weights_file = os.path.join(folder, SAFE_WEIGHTS_NAME)
else:
check_torch_load_is_safe()
loader = torch.load
weights_file = os.path.join(folder, WEIGHTS_NAME)

@@ -74,6 +74,7 @@ from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
check_torch_load_is_safe,
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,

@@ -739,6 +740,7 @@ class ModelUtilsTest(TestCasePlus):
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
check_torch_load_is_safe()
state_dict = torch.load(shard_file, weights_only=True)
self.assertEqual(len(state_dict), 1)