mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Fix weights not properly initialized due to shape mismatch (#28122)
* fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 769a9542de
commit 7938c8c836
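For context (not part of the commit itself), a minimal sketch of the scenario the fix targets: loading a checkpoint whose classification head has a different number of labels than the config used at load time. The tiny BertConfig values below are illustrative assumptions, not taken from the commit.

    import tempfile

    from transformers import AutoModelForSequenceClassification, BertConfig, BertForSequenceClassification

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Save a tiny classifier with the default 2-label head (sizes chosen only to keep the model small).
        config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
        BertForSequenceClassification(config).save_pretrained(tmp_dir)

        # Reload with a 42-label head: the classifier weights in the checkpoint no longer match,
        # so they are discarded and the head must be re-initialized via `_init_weights`.
        # Before this commit, the mismatched head could instead be left not properly initialized.
        model = AutoModelForSequenceClassification.from_pretrained(
            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
        )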
@@ -3957,13 +3957,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
         if _fast_init:
-            if remove_prefix_from_model:
-                _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
-            elif add_prefix_to_model:
-                _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
-            else:
-                _loaded_keys = loaded_keys
-            set_initialized_submodules(model, _loaded_keys)
+            if not ignore_mismatched_sizes:
+                if remove_prefix_from_model:
+                    _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
+                elif add_prefix_to_model:
+                    _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
+                else:
+                    _loaded_keys = loaded_keys
+                set_initialized_submodules(model, _loaded_keys)
             # This will only initialize submodules that are not marked as initialized by the line above.
             model.apply(model._initialize_weights)
@@ -2889,6 +2889,110 @@ class ModelTesterMixin:
                     else:
                         new_model_without_prefix(input_ids)
 
+    def test_mismatched_shapes_have_properly_initialized_weights(self):
+        if not self.test_mismatched_shapes:
+            return
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
+                continue
+
+            with self.subTest(msg=f"Testing {model_class}"):
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    model = model_class(configs_no_init)
+                    model.save_pretrained(tmp_dir)
+
+                    # Fails when we don't set ignore_mismatched_sizes=True
+                    with self.assertRaises(RuntimeError):
+                        new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+
+                    logger = logging.get_logger("transformers.modeling_utils")
+
+                    with CaptureLogger(logger) as cl:
+                        new_model = AutoModelForSequenceClassification.from_pretrained(
+                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
+                        )
+                    self.assertIn("the shapes did not match", cl.out)
+
+                    for name, param in new_model.named_parameters():
+                        if param.requires_grad:
+                            self.assertIn(
+                                ((param.data.mean() * 1e9).round() / 1e9).item(),
+                                [0.0, 1.0],
+                                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                            )
+
+    def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
+        # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
+        class MyClass(PreTrainedModel):
+            config_class = PretrainedConfig
+
+            def __init__(self, config=None):
+                super().__init__(config if config is not None else PretrainedConfig())
+                self.linear = nn.Linear(10, config.num_labels, bias=True)
+                self.embedding = nn.Embedding(10, 10)
+                self.std = 1
+
+            def _init_weights(self, module):
+                if isinstance(module, nn.Linear):
+                    module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
+                    if module.bias is not None:
+                        module.bias.data = module.bias.data.normal_(mean=0.0, std=self.std)
+
+        # Used to make sure the weights with matched shape are loaded correctly
+        config = PretrainedConfig()
+        config.num_labels = 3
+        model = MyClass(config=config)
+
+        # Used to make sure the weights with mismatched shape are properly initialized
+        set_seed(0)
+        config = PretrainedConfig()
+        config.num_labels = 4
+        # not to init. the weights during the creation: to match the logic in `from_pretrained`, so we can keep the
+        # same sequence of random ops in the execution path to allow us to compare `target_model` and `new_model` below
+        # for `linear` part.
+        with ContextManagers([no_init_weights(True)]):
+            target_model = MyClass(config=config)
+        target_model.apply(target_model._initialize_weights)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            state_dict = model.state_dict()
+            del state_dict["linear.weight"]
+
+            model.config.save_pretrained(tmpdirname)
+            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
+
+            set_seed(0)
+            new_model = MyClass.from_pretrained(tmpdirname, num_labels=4, ignore_mismatched_sizes=True)
+
+            for key in new_model.state_dict().keys():
+                # check weight values for weights with matched shapes are identical
+                # (i.e. correctly loaded from the checkpoint)
+                if key not in ["linear.weight", "linear.bias"]:
+                    max_diff = torch.max(torch.abs(model.state_dict()[key] - new_model.state_dict()[key]))
+                    self.assertLessEqual(
+                        max_diff.item(),
+                        1e-6,
+                        msg=f"the weight values for `{key}` in `new_model` and `model` are not identical",
+                    )
+                else:
+                    # check we have some mismatched shapes
+                    self.assertNotEqual(
+                        model.state_dict()[key].shape,
+                        new_model.state_dict()[key].shape,
+                        msg=f"the weight shapes for {key} in `model` and `new_model` should differ",
+                    )
+                    # check the weights with mismatched shape are properly initialized
+                    max_diff = torch.max(torch.abs(new_model.state_dict()[key] - target_model.state_dict()[key]))
+                    self.assertLessEqual(
+                        max_diff.item(),
+                        1e-6,
+                        msg=f"the weight values for `{key}` in `new_model` and `target_model` are not identical",
+                    )
+
     def test_model_is_small(self):
         # Just a consistency check to make sure we are not running tests on 80M parameter models.
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
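Since `ModelTesterMixin` is shared by the per-model test suites rather than collected directly, the two new tests run as part of each model's test class; an assumed invocation (the file path is inferred, not stated in the diff) would be, for example, `python -m pytest tests/models/bert/test_modeling_bert.py -k "mismatched"`.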