Mirror of https://github.com/huggingface/transformers.git
Fix weights not properly initialized due to shape mismatch (#28122)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

Commit 7938c8c836, parent 769a9542de
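For context, the scenario this fix targets, as a minimal self-contained sketch (not part of the commit; the tiny config sizes and label counts below are arbitrary): a checkpoint saved with one classification-head size is reloaded with another, so the head's tensors cannot be copied over and must be freshly initialized.

import tempfile

from transformers import BertConfig, BertForSequenceClassification

# Save a tiny model whose classification head has 3 labels (all sizes here are arbitrary).
config = BertConfig(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=3,
)
model = BertForSequenceClassification(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)

    # Reload with a different head size. The checkpoint's classifier tensors no longer fit,
    # so they are dropped from loading and the new head has to be re-initialized; this commit
    # makes sure that re-initialization actually happens.
    new_model = BertForSequenceClassification.from_pretrained(
        tmp_dir, num_labels=42, ignore_mismatched_sizes=True
    )

# A freshly initialized head has small, well-behaved statistics rather than leftover memory.
print(new_model.classifier.weight.mean().item(), new_model.classifier.weight.std().item())

Without ignore_mismatched_sizes=True the same reload raises an error instead, which is what the first new test below asserts.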
@@ -3957,13 +3957,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
 
         # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
         if _fast_init:
-            if remove_prefix_from_model:
-                _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
-            elif add_prefix_to_model:
-                _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
-            else:
-                _loaded_keys = loaded_keys
-            set_initialized_submodules(model, _loaded_keys)
+            if not ignore_mismatched_sizes:
+                if remove_prefix_from_model:
+                    _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
+                elif add_prefix_to_model:
+                    _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
+                else:
+                    _loaded_keys = loaded_keys
+                set_initialized_submodules(model, _loaded_keys)
             # This will only initialize submodules that are not marked as initialized by the line above.
             model.apply(model._initialize_weights)
 
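The guard added above changes which modules get marked as already initialized during fast init. As an illustrative toy in plain PyTorch (an assumed reading of the bookkeeping, not the transformers implementation), a key can be present in the checkpoint and still be unusable because its shape differs from the freshly built model:

import torch
import torch.nn as nn

# Toy setup: the checkpoint was written for a 3-label head, the new model has 5 labels.
checkpoint = {"classifier.weight": torch.randn(3, 8), "classifier.bias": torch.zeros(3)}

model = nn.Module()
model.classifier = nn.Linear(8, 5)

params = dict(model.named_parameters())
present = [k for k in checkpoint if k in params]                             # found in the file
mismatched = [k for k in present if checkpoint[k].shape != params[k].shape]  # but not copyable

print(present)     # ['classifier.weight', 'classifier.bias']
print(mismatched)  # ['classifier.weight', 'classifier.bias'] -> still need _init_weights

The reading behind the fix: previously set_initialized_submodules saw such keys in loaded_keys and marked their modules as initialized, so model.apply(model._initialize_weights) skipped them even though the mismatched tensors were later discarded. With ignore_mismatched_sizes=True the marking is now skipped entirely, every module falls through to _initialize_weights, and the matched checkpoint weights are copied in afterwards.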
@@ -2889,6 +2889,110 @@ class ModelTesterMixin:
                     else:
                         new_model_without_prefix(input_ids)
 
+    def test_mismatched_shapes_have_properly_initialized_weights(self):
+        if not self.test_mismatched_shapes:
+            return
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
+                continue
+
+            with self.subTest(msg=f"Testing {model_class}"):
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    model = model_class(configs_no_init)
+                    model.save_pretrained(tmp_dir)
+
+                    # Fails when we don't set ignore_mismatched_sizes=True
+                    with self.assertRaises(RuntimeError):
+                        new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+
+                    logger = logging.get_logger("transformers.modeling_utils")
+
+                    with CaptureLogger(logger) as cl:
+                        new_model = AutoModelForSequenceClassification.from_pretrained(
+                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
+                        )
+                    self.assertIn("the shapes did not match", cl.out)
+
+                    for name, param in new_model.named_parameters():
+                        if param.requires_grad:
+                            self.assertIn(
+                                ((param.data.mean() * 1e9).round() / 1e9).item(),
+                                [0.0, 1.0],
+                                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                            )
+
+    def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
+        # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
+        class MyClass(PreTrainedModel):
+            config_class = PretrainedConfig
+
+            def __init__(self, config=None):
+                super().__init__(config if config is not None else PretrainedConfig())
+                self.linear = nn.Linear(10, config.num_labels, bias=True)
+                self.embedding = nn.Embedding(10, 10)
+                self.std = 1
+
+            def _init_weights(self, module):
+                if isinstance(module, nn.Linear):
+                    module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
+                    if module.bias is not None:
+                        module.bias.data = module.bias.data.normal_(mean=0.0, std=self.std)
+
+        # Used to make sure the weights with matched shape are loaded correctly
+        config = PretrainedConfig()
+        config.num_labels = 3
+        model = MyClass(config=config)
+
+        # Used to make sure the weights with mismatched shape are properly initialized
+        set_seed(0)
+        config = PretrainedConfig()
+        config.num_labels = 4
+        # not to init. the weights during the creation: to match the logic in `from_pretrained`, so we can keep the
+        # same sequence of random ops in the execution path to allow us to compare `target_model` and `new_model` below
+        # for `linear` part.
+        with ContextManagers([no_init_weights(True)]):
+            target_model = MyClass(config=config)
+        target_model.apply(target_model._initialize_weights)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            state_dict = model.state_dict()
+            del state_dict["linear.weight"]
+
+            model.config.save_pretrained(tmpdirname)
+            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
+
+            set_seed(0)
+            new_model = MyClass.from_pretrained(tmpdirname, num_labels=4, ignore_mismatched_sizes=True)
+
+            for key in new_model.state_dict().keys():
+                # check weight values for weights with matched shapes are identical
+                # (i.e. correctly loaded from the checkpoint)
+                if key not in ["linear.weight", "linear.bias"]:
+                    max_diff = torch.max(torch.abs(model.state_dict()[key] - new_model.state_dict()[key]))
+                    self.assertLessEqual(
+                        max_diff.item(),
+                        1e-6,
+                        msg=f"the weight values for `{key}` in `new_model` and `model` are not identical",
+                    )
+                else:
+                    # check we have some mismatched shapes
+                    self.assertNotEqual(
+                        model.state_dict()[key].shape,
+                        new_model.state_dict()[key].shape,
+                        msg=f"the weight shapes for {key} in `model` and `new_model` should differ",
+                    )
+                    # check the weights with mismatched shape are properly initialized
+                    max_diff = torch.max(torch.abs(new_model.state_dict()[key] - target_model.state_dict()[key]))
+                    self.assertLessEqual(
+                        max_diff.item(),
+                        1e-6,
+                        msg=f"the weight values for `{key}` in `new_model` and `target_model` are not identical",
+                    )
+
     def test_model_is_small(self):
         # Just a consistency check to make sure we are not running tests on 80M parameter models.
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
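A note on the assertion in the first test above: `((param.data.mean() * 1e9).round() / 1e9).item()` being in `[0.0, 1.0]` relies on `_config_zero_init`, which is assumed here to shrink the config's initializer std to roughly 1e-10, so weights that actually went through `_init_weights` average to about 0.0, constant-one parameters such as LayerNorm scales average to exactly 1.0, and leftover uninitialized memory would almost surely land elsewhere. A quick standalone check of that arithmetic:

import torch

# With an initializer std around 1e-10 (assumed), the mean of a 42x32 weight matrix is on the
# order of 1e-12, which the *1e9 / round / 1e9 trick collapses to 0.0.
w = torch.empty(42, 32).normal_(mean=0.0, std=1e-10)
print(((w.mean() * 1e9).round() / 1e9).item() in [0.0, 1.0])  # True

# A LayerNorm-style weight of ones passes the same check as exactly 1.0.
ones = torch.ones(32)
print(((ones.mean() * 1e9).round() / 1e9).item() in [0.0, 1.0])  # True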