Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00.
Commit 41f5c3216c (parent bc2dea3f54).
@ -761,9 +761,6 @@ def _load_state_dict_into_meta_model(
|
||||
if is_meta_state_dict:
|
||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||
|
||||
# Used to fix the issue mentioned in #37031: when loading a model with tied weights in state_dict + `tie_word_embeddings = False`,
|
||||
# we need to make sure they are not loaded as tied weights!
|
||||
data_ptrs = set()
|
||||
for param_name, empty_param in state_dict.items():
|
||||
if param_name not in expected_keys:
|
||||
continue
|
||||
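The retained `file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)` line relies on safetensors' lazy shard access. Below is a minimal sketch of that pattern using only the public `safetensors.safe_open` API; the shard path and the loop over keys are placeholders for illustration, not part of this commit.

```python
# Minimal sketch: lazily read tensors from a safetensors shard.
# "model-00001-of-00002.safetensors" is a placeholder path, not from this commit.
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)  # only this tensor is materialized in memory
        print(name, tuple(tensor.shape), tensor.dtype)
```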
@@ -833,14 +830,8 @@ def _load_state_dict_into_meta_model(
                 if is_fsdp_enabled():
                     param_device = "cpu" if is_local_dist_rank_0() else "meta"
 
-                # avoid tied weights
-                if param.data_ptr() in data_ptrs:
-                    param = param.clone()
-
                 _load_parameter_into_model(model, param_name, param.to(param_device))
 
-                # Add `data_ptr` of `model.state_dict()[param_name]` to avoid tied weights
-                data_ptrs.add(model.state_dict()[param_name].data_ptr())
             else:
                 hf_quantizer.create_quantized_param(
                     model, param, param_name, param_device, state_dict, unexpected_keys
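The lines dropped in this hunk deduplicated tied weights by tracking storage pointers: two tied entries in a state_dict report the same `data_ptr()`, and cloning the second occurrence before loading gives it independent storage. Here is a minimal sketch of that idea, not the transformers implementation; the tensor names `embed_tokens.weight` and `lm_head.weight` are hypothetical.

```python
# Minimal sketch: tied weights in a state_dict share storage, so they share data_ptr().
# Cloning the second occurrence breaks the tie before it would be loaded.
import torch

embed = torch.randn(8, 4)
state_dict = {"embed_tokens.weight": embed, "lm_head.weight": embed}  # tied on purpose

data_ptrs = set()
for name, tensor in state_dict.items():
    if tensor.data_ptr() in data_ptrs:
        tensor = tensor.clone()  # now backed by its own storage
    data_ptrs.add(tensor.data_ptr())
    print(name, tensor.data_ptr())
```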