Set weights_only in torch.load (#36991)

commit 41a0e58e5b
parent de77f5b1ec
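Note on the change: by default torch.load deserializes with Python's pickle, which can execute arbitrary code embedded in a checkpoint file. Passing weights_only=True switches to a restricted unpickler that only rebuilds tensors, primitive containers, and a small allowlist of types. A minimal sketch of the difference, assuming PyTorch >= 1.13 (where the argument was introduced); the file name is illustrative:

    import torch

    # Benign checkpoint: plain tensors round-trip fine under the restricted loader.
    torch.save({"weight": torch.randn(2, 2)}, "demo.pt")
    state_dict = torch.load("demo.pt", weights_only=True)

    # A pickle stream that references arbitrary Python callables (the classic
    # torch.load attack vector) raises pickle.UnpicklingError here instead of
    # executing code during load.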
@@ -36,12 +36,12 @@ import optax
 # for dataset and preprocessing
 import torch
 import torchvision
-import torchvision.transforms as transforms
 from flax import jax_utils
 from flax.jax_utils import pad_shard_unpad, unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import HfApi
+from torchvision import transforms
 from tqdm import tqdm

 import transformers
@@ -113,7 +113,7 @@ if is_torch_available():
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not overwrite_cache:
                 logger.info(f"Loading features from cached file {cached_features_file}")
-                self.features = torch.load(cached_features_file)
+                self.features = torch.load(cached_features_file, weights_only=True)
             else:
                 logger.info(f"Creating features from dataset file at {data_dir}")
                 label_list = processor.get_labels()
@@ -81,7 +81,7 @@ class GLUETransformer(BaseTransformer):

         cached_features_file = self._feature_file(mode)
         logger.info("Loading features from cached file %s", cached_features_file)
-        features = torch.load(cached_features_file)
+        features = torch.load(cached_features_file, weights_only=True)
         all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
         all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
         all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
@@ -63,7 +63,7 @@ class NERTransformer(BaseTransformer):
         cached_features_file = self._feature_file(mode)
         if os.path.exists(cached_features_file) and not args.overwrite_cache:
             logger.info("Loading features from cached file %s", cached_features_file)
-            features = torch.load(cached_features_file)
+            features = torch.load(cached_features_file, weights_only=True)
         else:
             logger.info("Creating features from dataset file at %s", args.data_dir)
             examples = self.token_classification_task.read_examples_from_file(args.data_dir, mode)
@@ -89,7 +89,7 @@ class NERTransformer(BaseTransformer):
         "Load datasets. Called after prepare data."
         cached_features_file = self._feature_file(mode)
         logger.info("Loading features from cached file %s", cached_features_file)
-        features = torch.load(cached_features_file)
+        features = torch.load(cached_features_file, weights_only=True)
         all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
         all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
         if features[0].token_type_ids is not None:
@@ -105,8 +105,8 @@ def train(args, train_dataset, model, tokenizer):
         os.path.join(args.model_name_or_path, "scheduler.pt")
     ):
         # Load in optimizer and scheduler states
-        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
-        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
+        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"), weights_only=True))
+        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"), weights_only=True))

     if args.fp16:
         try:
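The hunk above hardens the resume path in a legacy example's train(). A self-contained sketch of the same save/resume round trip; the AdamW/LambdaLR choices and the checkpoint directory are illustrative, not from the patch:

    import os

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

    ckpt = "checkpoint"  # hypothetical directory
    os.makedirs(ckpt, exist_ok=True)
    torch.save(optimizer.state_dict(), os.path.join(ckpt, "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(ckpt, "scheduler.pt"))

    # Optimizer/scheduler state dicts are plain tensors and primitives,
    # so the restricted loader accepts them unchanged.
    optimizer.load_state_dict(torch.load(os.path.join(ckpt, "optimizer.pt"), weights_only=True))
    scheduler.load_state_dict(torch.load(os.path.join(ckpt, "scheduler.pt"), weights_only=True))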
@@ -417,7 +417,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     # Init features and dataset from cache if it exists
     if os.path.exists(cached_features_file) and not args.overwrite_cache:
         logger.info("Loading features from cached file %s", cached_features_file)
-        features_and_dataset = torch.load(cached_features_file)
+        features_and_dataset = torch.load(cached_features_file, weights_only=True)
         features, dataset, examples = (
             features_and_dataset["features"],
             features_and_dataset["dataset"],
@@ -244,7 +244,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     )
     if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
         logger.info("Loading features from cached file %s", cached_features_file)
-        features = torch.load(cached_features_file)
+        features = torch.load(cached_features_file, weights_only=True)
     else:
         logger.info("Creating features from dataset file at %s", input_file)
         examples = read_swag_examples(input_file)
@@ -22,7 +22,7 @@ from tqdm import tqdm

 def convert(src_path: str, map_location: str = "cpu", save_path: Union[str, None] = None) -> None:
     """Convert a pytorch_model.bin or model.pt file to torch.float16 for faster downloads, less disk space."""
-    state_dict = torch.load(src_path, map_location=map_location)
+    state_dict = torch.load(src_path, map_location=map_location, weights_only=True)
     for k, v in tqdm(state_dict.items()):
         if not isinstance(v, torch.Tensor):
             raise TypeError("FP16 conversion only works on paths that are saved state dicts, like pytorch_model.bin")
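For context, convert() goes on to halve each tensor and re-save the file. A hedged sketch of the whole flow — the loop body and the final save are assumptions; only the torch.load line is part of this patch:

    from typing import Union

    import torch
    from tqdm import tqdm

    def convert_sketch(src_path: str, map_location: str = "cpu", save_path: Union[str, None] = None) -> None:
        state_dict = torch.load(src_path, map_location=map_location, weights_only=True)
        for k, v in tqdm(state_dict.items()):
            if not isinstance(v, torch.Tensor):
                raise TypeError("FP16 conversion only works on saved state dicts")
            state_dict[k] = v.half()  # assumed: cast each tensor to float16
        torch.save(state_dict, save_path if save_path is not None else src_path)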
@@ -242,7 +242,7 @@ if is_torch_available():
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not overwrite_cache:
                 logger.info(f"Loading features from cached file {cached_features_file}")
-                self.features = torch.load(cached_features_file)
+                self.features = torch.load(cached_features_file, weights_only=True)
             else:
                 logger.info(f"Creating features from dataset file at {data_dir}")
                 examples = token_classification_task.read_examples_from_file(data_dir, mode)
@@ -277,12 +277,7 @@ def convert_pt_checkpoint_to_tf(
     if compare_with_pt_model:
         tfo = tf_model(tf_model.dummy_inputs, training=False)  # build the network

-        weights_only_kwarg = {"weights_only": True}
-        state_dict = torch.load(
-            pytorch_checkpoint_path,
-            map_location="cpu",
-            **weights_only_kwarg,
-        )
+        state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu", weights_only=True)
         pt_model = pt_model_class.from_pretrained(
             pretrained_model_name_or_path=None, config=config, state_dict=state_dict
         )
@@ -148,7 +148,7 @@ class SquadDataset(Dataset):
         with FileLock(lock_path):
             if os.path.exists(cached_features_file) and not args.overwrite_cache:
                 start = time.time()
-                self.old_features = torch.load(cached_features_file)
+                self.old_features = torch.load(cached_features_file, weights_only=True)

                 # Legacy cache files have only features, while new cache files
                 # will have dataset and examples also.
@@ -71,8 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
            )
            raise

-        weights_only_kwarg = {"weights_only": True}
-        pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
+        pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
         logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
@@ -248,8 +247,7 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
     flax_state_dict = {}
     for shard_file in shard_filenames:
         # load using msgpack utils
-        weights_only_kwarg = {"weights_only": True}
-        pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
+        pt_state_dict = torch.load(shard_file, weights_only=True)
         weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
         pt_state_dict = {
             k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
@@ -198,8 +198,7 @@ def load_pytorch_checkpoint_in_tf2_model(
             if pt_path.endswith(".safetensors"):
                 state_dict = safe_load_file(pt_path)
             else:
-                weights_only_kwarg = {"weights_only": True}
-                state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
+                state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)

             pt_state_dict.update(state_dict)

@@ -504,8 +504,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    weights_only_kwarg = {"weights_only": True}
-    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)
+    loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", weights_only=True)

     for shard_file in shard_files:
         state_dict = loader(os.path.join(folder, shard_file))
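The loader above is built once with functools.partial so safetensors files and pickle shards can be consumed through a uniform callable. A sketch of the pattern; the shard name is hypothetical:

    from functools import partial

    import torch

    loader = partial(torch.load, map_location="cpu", weights_only=True)
    # state_dict = loader("pytorch_model-00001-of-00002.bin")  # hypothetical shard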
@@ -598,11 +597,10 @@ def load_state_dict(
             and is_zipfile(checkpoint_file)
         ):
             extra_args = {"mmap": True}
-        weights_only_kwarg = {"weights_only": weights_only}
         return torch.load(
             checkpoint_file,
             map_location=map_location,
-            **weights_only_kwarg,
+            weights_only=weights_only,
             **extra_args,
         )
     except Exception as e:
@@ -1216,7 +1214,7 @@ def _get_torch_dtype(
     weights_only: bool,
 ) -> Tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
     """Find the correct `torch_dtype` to use based on provided arguments. Also update the `config` based on the
-    infered dtype. We do the following:
+    inferred dtype. We do the following:
     1. If torch_dtype is not None, we use that dtype
     2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
        weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
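The docstring's "auto" rule can be read as: adopt the dtype of the first floating-point entry in the state dict. A sketch of that detection step under stated assumptions; this is not the library's actual implementation:

    import torch

    state_dict = {"embed": torch.randn(2, 2, dtype=torch.float16), "ids": torch.arange(3)}
    torch_dtype = next(
        (t.dtype for t in state_dict.values() if t.is_floating_point()),
        torch.get_default_dtype(),  # assumed fallback when no float entry exists
    )
    assert torch_dtype == torch.float16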
@@ -207,7 +207,7 @@ def convert_wav2vec2_checkpoint(
     hf_wav2vec = Data2VecAudioModel(config)
     data2vec_checkpoint_dir = os.path.dirname(checkpoint_path)

-    state_dict = torch.load(checkpoint_path)
+    state_dict = torch.load(checkpoint_path, weights_only=True)
     state_dict["model"]["final_proj.weight"] = state_dict["model"].pop("final_proj.0.weight")
     state_dict["model"]["final_proj.bias"] = state_dict["model"].pop("final_proj.0.bias")
     converted_ckpt = os.path.join(data2vec_checkpoint_dir, "converted.pt")
@@ -121,7 +121,7 @@ def convert_phi_weights(
         if model_path.endswith("safetensors"):
             loaded_weights = safetensors.torch.load_file(model_path, device=device)
         else:
-            loaded_weights = torch.load(model_path, map_location=device)
+            loaded_weights = torch.load(model_path, map_location=device, weights_only=True)
         model_checkpoint.update(**loaded_weights)

     model_type = model_name.split("/")[1]  # phi-1 or phi-1_5 or phi-2
@@ -1589,11 +1589,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
                 cache_dir=cache_dir,
             )

-            weights_only_kwarg = {"weights_only": True}
             state_dict = torch.load(
                 weight_path,
                 map_location="cpu",
-                **weights_only_kwarg,
+                weights_only=True,
             )

         except EnvironmentError:
@@ -2820,7 +2820,6 @@ class Trainer:
         )

         if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
-            weights_only_kwarg = {"weights_only": True}
             # If the model is on the GPU, it still works!
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2836,11 +2835,7 @@ class Trainer:
                     logger.warning(
                         "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
                     )
-                state_dict = torch.load(
-                    weights_file,
-                    map_location="cpu",
-                    **weights_only_kwarg,
-                )
+                state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
                 # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
@@ -2859,11 +2854,7 @@ class Trainer:
                 if self.args.save_safetensors and os.path.isfile(safe_weights_file):
                     state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
                 else:
-                    state_dict = torch.load(
-                        weights_file,
-                        map_location="cpu",
-                        **weights_only_kwarg,
-                    )
+                    state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)

                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                 # which takes *args instead of **kwargs
@@ -2941,7 +2932,6 @@ class Trainer:
             or os.path.exists(best_safe_adapter_model_path)
         ):
             has_been_loaded = True
-            weights_only_kwarg = {"weights_only": True}
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
                     # If the 'user_content.pt' file exists, load with the new smp api.
@@ -2958,11 +2948,7 @@ class Trainer:
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
-                    state_dict = torch.load(
-                        best_model_path,
-                        map_location="cpu",
-                        **weights_only_kwarg,
-                    )
+                    state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
@@ -3017,11 +3003,7 @@ class Trainer:
                 if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
                     state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                 else:
-                    state_dict = torch.load(
-                        best_model_path,
-                        map_location="cpu",
-                        **weights_only_kwarg,
-                    )
+                    state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)

                 # If the model is on the GPU, it still works!
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
@@ -3142,7 +3124,7 @@ class Trainer:
             return

         with safe_globals():
-            checkpoint_rng_state = torch.load(rng_file)
+            checkpoint_rng_state = torch.load(rng_file, weights_only=True)
         random.setstate(checkpoint_rng_state["python"])
         np.random.set_state(checkpoint_rng_state["numpy"])
         torch.random.set_rng_state(checkpoint_rng_state["cpu"])
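RNG checkpoints contain numpy objects that the restricted unpickler rejects by default, hence the safe_globals() context around the load. A sketch of the allowlisting mechanism, assuming torch >= 2.5 (where torch.serialization.safe_globals is public) and numpy < 2 module paths; transformers' own safe_globals() helper wraps the same idea, and the file name here is hypothetical:

    import numpy as np
    import torch

    # Allowlist the numpy internals needed to rebuild RNG state objects.
    allowlist = [np.core.multiarray._reconstruct, np.ndarray, np.dtype]
    with torch.serialization.safe_globals(allowlist):
        rng_state = torch.load("rng_state.pth", weights_only=True)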
@@ -3375,7 +3357,9 @@ class Trainer:
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
                 with warnings.catch_warnings(record=True) as caught_warnings:
-                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+                    self.lr_scheduler.load_state_dict(
+                        torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
+                    )
                 reissue_pt_warnings(caught_warnings)
             return

@@ -3410,13 +3394,18 @@ class Trainer:
                     checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
                 ),
                 map_location="cpu",
+                weights_only=True,
             )
             # We only need `optimizer` when resuming from checkpoint
             optimizer_state = optimizer_state["optimizer"]
         else:
-            optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+            optimizer_state = torch.load(
+                os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
+            )
         with warnings.catch_warnings(record=True) as caught_warnings:
-            lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
+            lr_scheduler_state = torch.load(
+                os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
+            )
         reissue_pt_warnings(caught_warnings)

         xm.send_cpu_data_to_device(optimizer_state, self.args.device)
@@ -3458,10 +3447,14 @@ class Trainer:
                 )
             else:
                 self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    torch.load(
+                        os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
+                    )
                 )
             with warnings.catch_warnings(record=True) as caught_warnings:
-                self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+                self.lr_scheduler.load_state_dict(
+                    torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
+                )
             reissue_pt_warnings(caught_warnings)

     def _save_scaler(self, output_dir):
@@ -3496,13 +3489,17 @@ class Trainer:
         # Load in scaler states
         if is_torch_xla_available():
             with warnings.catch_warnings(record=True) as caught_warnings:
-                scaler_state = torch.load(os.path.join(checkpoint, SCALER_NAME), map_location="cpu")
+                scaler_state = torch.load(
+                    os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
+                )
             reissue_pt_warnings(caught_warnings)
             xm.send_cpu_data_to_device(scaler_state, self.args.device)
             self.accelerator.scaler.load_state_dict(scaler_state)
         else:
             with warnings.catch_warnings(record=True) as caught_warnings:
-                self.accelerator.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
+                self.accelerator.scaler.load_state_dict(
+                    torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
+                )
             reissue_pt_warnings(caught_warnings)

     def _load_callback_state(self):
@@ -415,7 +415,7 @@ class AutoformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

 def prepare_batch(filename="train-batch.pt"):
     file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
-    batch = torch.load(file, map_location=torch_device)
+    batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch


@@ -28,7 +28,7 @@ if is_torch_available():
     import torch

 if is_torchvision_available():
-    import torchvision.transforms as transforms
+    from torchvision import transforms

 if is_vision_available():
     from PIL import Image
@@ -476,7 +476,7 @@ class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

 def prepare_batch(filename="train-batch.pt"):
     file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
-    batch = torch.load(file, map_location=torch_device)
+    batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch


@@ -408,7 +408,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
             filename="llava_1_6_input_ids.pt",
             repo_type="dataset",
         )
-        original_input_ids = torch.load(filepath, map_location="cpu")
+        original_input_ids = torch.load(filepath, map_location="cpu", weights_only=True)
         # replace -200 by image_token_index (since we use token ID = 32000 for the image token)
         # remove image token indices because HF impl expands image tokens `image_seq_length` times
         original_input_ids = original_input_ids[original_input_ids != -200]
@@ -420,7 +420,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
             filename="llava_1_6_pixel_values.pt",
             repo_type="dataset",
         )
-        original_pixel_values = torch.load(filepath, map_location="cpu")
+        original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)
         assert torch.allclose(original_pixel_values, inputs.pixel_values.half())

         # verify generation
@@ -452,7 +452,7 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
 def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
     # TODO: Make repo public
     file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
-    batch = torch.load(file, map_location=torch_device)
+    batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch


@@ -303,7 +303,7 @@ class PatchTSTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

 def prepare_batch(repo_id="hf-internal-testing/etth1-hourly-batch", file="train-batch.pt"):
     file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
-    batch = torch.load(file, map_location=torch_device)
+    batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch


@@ -481,7 +481,7 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

 def prepare_batch(filename="train-batch.pt"):
     file = hf_hub_download(repo_id="hf-internal-testing/tourism-monthly-batch", filename=filename, repo_type="dataset")
-    batch = torch.load(file, map_location=torch_device)
+    batch = torch.load(file, map_location=torch_device, weights_only=True)
     return batch


@@ -456,7 +456,7 @@ class VideoMAEModelIntegrationTest(unittest.TestCase):

         # add boolean mask, indicating which patches to mask
         local_path = hf_hub_download(repo_id="hf-internal-testing/bool-masked-pos", filename="bool_masked_pos.pt")
-        inputs["bool_masked_pos"] = torch.load(local_path)
+        inputs["bool_masked_pos"] = torch.load(local_path, weights_only=True)

         # forward pass
         with torch.no_grad():
@@ -554,7 +554,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):

             state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")

-            dummy_state_dict = torch.load(state_dict_path)
+            dummy_state_dict = torch.load(state_dict_path, weights_only=True)

             model.load_adapter(adapter_state_dict=dummy_state_dict, peft_config=peft_config)
             with self.assertRaises(ValueError):
@@ -579,7 +579,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):

             peft_config = LoraConfig()
             state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
-            dummy_state_dict = torch.load(state_dict_path)
+            dummy_state_dict = torch.load(state_dict_path, weights_only=True)

             # this should always work
             model.load_adapter(
@@ -647,7 +647,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):

             peft_config = LoraConfig()
             state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
-            dummy_state_dict = torch.load(state_dict_path)
+            dummy_state_dict = torch.load(state_dict_path, weights_only=True)

             # add unexpected key
             dummy_state_dict["foobar"] = next(iter(dummy_state_dict.values()))
@@ -674,7 +674,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):

             peft_config = LoraConfig()
             state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
-            dummy_state_dict = torch.load(state_dict_path)
+            dummy_state_dict = torch.load(state_dict_path, weights_only=True)

             # remove a key so that we have missing keys
             key = next(iter(dummy_state_dict.keys()))
@@ -648,7 +648,7 @@ class TrainerIntegrationCommon:
         else:
             best_model = RegressionModel()
             if not safe_weights:
-                state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME))
+                state_dict = torch.load(os.path.join(checkpoint, WEIGHTS_NAME), weights_only=True)
             else:
                 state_dict = safetensors.torch.load_file(os.path.join(checkpoint, SAFE_WEIGHTS_NAME))
             best_model.load_state_dict(state_dict)
@@ -765,7 +765,7 @@ class ModelUtilsTest(TestCasePlus):
                 # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
                 # the size asked for (since we count parameters)
                 if size >= max_size_int + 50000:
-                    state_dict = torch.load(shard_file)
+                    state_dict = torch.load(shard_file, weights_only=True)
                     self.assertEqual(len(state_dict), 1)

             # Check the index and the shard files found match