Mirror of https://github.com/huggingface/transformers.git
Reverting Deta cloning mechanism. (#22656)
* Fixed the revert by making sure that even the regexp can cover all duplicates.
* Code simplification using hash.
* Fixing the `ident`.
* Fixing ignoring patterned duplicate names.
* Using `accelerate@find_tied_parameters` for from_pretrained. This is more correct there, since it handles the meta device seamlessly and we don't need to handle "non-duplicate" tensors (slices of each other).
* Protecting accelerate.
* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d6f1da6b71
commit 6e32959329
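As background for the `find_tied_parameters` change described in the commit message above, here is a minimal sketch, outside the diff, of what accelerate reports for a model with tied weights. The toy `TiedLM` module is an illustrative assumption, and the exact return type of `find_tied_parameters` has varied across accelerate versions; it behaves as a collection of groups of tied parameter names.

# Minimal sketch: how accelerate groups tied parameters by name.
import torch.nn as nn
from accelerate.utils import find_tied_parameters


class TiedLM(nn.Module):  # toy model, only for illustration
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # tie the two parameters

    def forward(self, x):
        return self.lm_head(self.embed(x))


# Expected to report one group containing "embed.weight" and "lm_head.weight",
# roughly [["embed.weight", "lm_head.weight"]] on recent accelerate versions.
print(find_tied_parameters(TiedLM()))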
@@ -83,6 +83,7 @@ if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
     from accelerate.utils import (
+        find_tied_parameters,
         load_offloaded_weights,
         offload_weight,
         save_offload_index,
@@ -93,6 +94,8 @@ if is_accelerate_available():
         from accelerate.utils import get_balanced_memory
     else:
         get_balanced_memory = None
+else:
+    find_tied_parameters = None

 if is_safetensors_available():
     from safetensors import safe_open
@@ -1776,7 +1779,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # We're going to remove aliases before saving
             ptrs = collections.defaultdict(list)
             for name, tensor in state_dict.items():
-                ptrs[tensor.data_ptr()].append(name)
+                ident = (tensor.data_ptr(), tensor.device, tensor.shape, tensor.stride())
+                ptrs[ident].append(name)

             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
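A quick illustration, outside the diff, of why the hunk above keys the alias map on `(data_ptr, device, shape, stride)` rather than on `data_ptr()` alone: a slice can start at the same storage address as its parent tensor without being a full duplicate, so the richer key only groups true aliases. Variable names below are illustrative.

# Sketch: data_ptr() alone conflates a tensor with a slice of itself.
import torch

full = torch.zeros(10)
view = full[:5]    # shares storage and starts at the same address as `full`
alias = full       # a genuine alias of the whole tensor

ident = lambda t: (t.data_ptr(), t.device, t.shape, t.stride())
print(full.data_ptr() == view.data_ptr())  # True  -> the naive key would call this a duplicate
print(ident(full) == ident(view))          # False -> the slice keeps its own entry
print(ident(full) == ident(alias))         # True  -> the real alias is still detected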
@@ -1785,10 +1789,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
                 if self._keys_to_ignore_on_load_missing is not None:
-                    for name in names:
+                    found = 0
+                    for name in sorted(names):
                         matches_pattern = any(re.search(pat, name) for pat in self._keys_to_ignore_on_load_missing)
                         if matches_pattern and name in state_dict:
-                            del state_dict[name]
+                            found += 1
+                            if found < len(names):
+                                del state_dict[name]

                 # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
                 # If the link between tensors was done at runtime then `from_pretrained` will not get
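The `found < len(names)` guard introduced above ensures that at least one name from every group of aliased tensors survives in the saved state dict, even when the ignore patterns match all of them. Below is a standalone sketch of that pruning step; the helper name and toy values are illustrative, not library API.

# Sketch: drop duplicate names matched by ignore patterns, but never the last one.
import re


def prune_shared_names(state_dict, names, ignore_patterns):
    found = 0
    for name in sorted(names):
        matches = any(re.search(pat, name) for pat in ignore_patterns)
        if matches and name in state_dict:
            found += 1
            if found < len(names):  # keep at least one copy of the shared tensor
                del state_dict[name]
    return state_dict


state = {"embed.weight": "tensor", "lm_head.weight": "tensor"}
# Both names match the pattern, yet the last one in sorted order is kept.
print(prune_shared_names(state, ["embed.weight", "lm_head.weight"], [r"\.weight$"]))
# -> {'lm_head.weight': 'tensor'}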
@@ -2934,12 +2941,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))

         # Some tensors maybe have been already filled by another key (tied weights).
+        # TODO: Sylvain -> make this work even on meta device.
+        # existing_ptrs = {model_state_dict[k].data_ptr() for k in loaded_keys if k in model_state_dict}
+        # missing_keys = [
+        #     k for k in missing_keys if k in model_state_dict and model_state_dict[k].data_ptr() not in existing_ptrs
+        # ]
+        if find_tied_parameters is not None:
+            tied_params = find_tied_parameters(model)
+        else:
+            tied_params = []
+        _missing = []
+        for k in missing_keys:
+            found = False
+            for group in tied_params:
+                if k in group:
+                    found = True
+                    if len(group) > 2:
+                        group.remove(k)
+                    else:
+                        _missing.append(k)
+            if not found:
+                _missing.append(k)
+        missing_keys = _missing

         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
         if cls._keys_to_ignore_on_load_missing is not None:
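Taken together, the save-side deduplication and the load-side `find_tied_parameters` filtering mean that a model with tied weights should round-trip through safetensors without spurious missing-key warnings. A hedged usage sketch; the checkpoint name and output directory are placeholders, not part of the commit.

# Sketch: save with safetensors (aliases removed) and reload (tied keys not reported missing).
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("hf-internal-testing/tiny-random-bert")
model.save_pretrained("./tiny-bert-safetensors", safe_serialization=True)
reloaded = AutoModelForMaskedLM.from_pretrained("./tiny-bert-safetensors")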