Fix weights not properly initialized due to shape mismatch (#28122)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-12-20 14:20:02 +01:00 committed by GitHub
parent 769a9542de
commit 7938c8c836
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 7 deletions

View File

@ -3957,13 +3957,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init: if _fast_init:
if remove_prefix_from_model: if not ignore_mismatched_sizes:
_loaded_keys = [f"{prefix}.{k}" for k in loaded_keys] if remove_prefix_from_model:
elif add_prefix_to_model: _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
_loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] elif add_prefix_to_model:
else: _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
_loaded_keys = loaded_keys else:
set_initialized_submodules(model, _loaded_keys) _loaded_keys = loaded_keys
set_initialized_submodules(model, _loaded_keys)
# This will only initialize submodules that are not marked as initialized by the line above. # This will only initialize submodules that are not marked as initialized by the line above.
model.apply(model._initialize_weights) model.apply(model._initialize_weights)

View File

@ -2889,6 +2889,110 @@ class ModelTesterMixin:
else: else:
new_model_without_prefix(input_ids) new_model_without_prefix(input_ids)
def test_mismatched_shapes_have_properly_initialized_weights(self):
if not self.test_mismatched_shapes:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
continue
with self.subTest(msg=f"Testing {model_class}"):
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(configs_no_init)
model.save_pretrained(tmp_dir)
# Fails when we don't set ignore_mismatched_sizes=True
with self.assertRaises(RuntimeError):
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
new_model = AutoModelForSequenceClassification.from_pretrained(
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
)
self.assertIn("the shapes did not match", cl.out)
for name, param in new_model.named_parameters():
if param.requires_grad:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
class MyClass(PreTrainedModel):
config_class = PretrainedConfig
def __init__(self, config=None):
super().__init__(config if config is not None else PretrainedConfig())
self.linear = nn.Linear(10, config.num_labels, bias=True)
self.embedding = nn.Embedding(10, 10)
self.std = 1
def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
if module.bias is not None:
module.bias.data = module.bias.data.normal_(mean=0.0, std=self.std)
# Used to make sure the weights with matched shape are loaded correctly
config = PretrainedConfig()
config.num_labels = 3
model = MyClass(config=config)
# Used to make sure the weights with mismatched shape are properly initialized
set_seed(0)
config = PretrainedConfig()
config.num_labels = 4
# not to init. the weights during the creation: to match the logic in `from_pretrained`, so we can keep the
# same sequence of random ops in the execution path to allow us to compare `target_model` and `new_model` below
# for `linear` part.
with ContextManagers([no_init_weights(True)]):
target_model = MyClass(config=config)
target_model.apply(target_model._initialize_weights)
with tempfile.TemporaryDirectory() as tmpdirname:
state_dict = model.state_dict()
del state_dict["linear.weight"]
model.config.save_pretrained(tmpdirname)
torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
set_seed(0)
new_model = MyClass.from_pretrained(tmpdirname, num_labels=4, ignore_mismatched_sizes=True)
for key in new_model.state_dict().keys():
# check weight values for weights with matched shapes are identical
# (i.e. correctly loaded from the checkpoint)
if key not in ["linear.weight", "linear.bias"]:
max_diff = torch.max(torch.abs(model.state_dict()[key] - new_model.state_dict()[key]))
self.assertLessEqual(
max_diff.item(),
1e-6,
msg=f"the weight values for `{key}` in `new_model` and `model` are not identical",
)
else:
# check we have some mismatched shapes
self.assertNotEqual(
model.state_dict()[key].shape,
new_model.state_dict()[key].shape,
msg=f"the weight shapes for {key} in `model` and `new_model` should differ",
)
# check the weights with mismatched shape are properly initialized
max_diff = torch.max(torch.abs(new_model.state_dict()[key] - target_model.state_dict()[key]))
self.assertLessEqual(
max_diff.item(),
1e-6,
msg=f"the weight values for `{key}` in `new_model` and `target_model` are not identical",
)
def test_model_is_small(self): def test_model_is_small(self):
# Just a consistency check to make sure we are not running tests on 80M parameter models. # Just a consistency check to make sure we are not running tests on 80M parameter models.
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()