From 7938c8c836b2f42d25dfe32f17e5022209b76f9d Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Wed, 20 Dec 2023 14:20:02 +0100
Subject: [PATCH] Fix weights not properly initialized due to shape mismatch (#28122)

* fix

---------

Co-authored-by: ydshieh
---
 src/transformers/modeling_utils.py |  15 +++--
 tests/test_modeling_common.py      | 104 +++++++++++++++++++++++++++++
 2 files changed, 112 insertions(+), 7 deletions(-)
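
A quick illustration of the user-facing behavior this patch fixes (a minimal
sketch; the checkpoint name and label count are illustrative -- any checkpoint
whose classification head differs in size from the requested one hits the same
path):

    from transformers import AutoModelForSequenceClassification

    # The checkpoint ships a 2-label head, so requesting 42 labels makes the
    # head weights mismatch in shape. Without ignore_mismatched_sizes=True the
    # load raises a RuntimeError; with it, the mismatched tensors are dropped
    # from the state dict, and the fix below ensures the head is then freshly
    # initialized via _init_weights instead of keeping uninitialized memory.
    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english",
        num_labels=42,
        ignore_mismatched_sizes=True,
    )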
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 1584e4a5441..3cda4e3bfa0 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3957,13 +3957,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
         if _fast_init:
-            if remove_prefix_from_model:
-                _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
-            elif add_prefix_to_model:
-                _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
-            else:
-                _loaded_keys = loaded_keys
-            set_initialized_submodules(model, _loaded_keys)
+            if not ignore_mismatched_sizes:
+                if remove_prefix_from_model:
+                    _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
+                elif add_prefix_to_model:
+                    _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
+                else:
+                    _loaded_keys = loaded_keys
+                set_initialized_submodules(model, _loaded_keys)
             # This will only initialize submodules that are not marked as initialized by the line above.
             model.apply(model._initialize_weights)
 
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 85e69300516..23071b93d5f 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2889,6 +2889,110 @@ class ModelTesterMixin:
         else:
             new_model_without_prefix(input_ids)
 
+    def test_mismatched_shapes_have_properly_initialized_weights(self):
+        if not self.test_mismatched_shapes:
+            return
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
+                continue
+
+            with self.subTest(msg=f"Testing {model_class}"):
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    model = model_class(configs_no_init)
+                    model.save_pretrained(tmp_dir)
+
+                    # Fails when we don't set ignore_mismatched_sizes=True
+                    with self.assertRaises(RuntimeError):
+                        new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
+
+                    logger = logging.get_logger("transformers.modeling_utils")
+
+                    with CaptureLogger(logger) as cl:
+                        new_model = AutoModelForSequenceClassification.from_pretrained(
+                            tmp_dir, num_labels=42, ignore_mismatched_sizes=True
+                        )
+                    self.assertIn("the shapes did not match", cl.out)
+
+                    for name, param in new_model.named_parameters():
+                        if param.requires_grad:
+                            self.assertIn(
+                                ((param.data.mean() * 1e9).round() / 1e9).item(),
+                                [0.0, 1.0],
+                                msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                            )
+
+    def test_matched_shapes_have_loaded_weights_when_some_mismatched_shapes_exist(self):
+        # 1. Create a dummy class. (Should it have buffers as well, to make sure we test __init__?)
+        class MyClass(PreTrainedModel):
+            config_class = PretrainedConfig
+
+            def __init__(self, config=None):
+                super().__init__(config if config is not None else PretrainedConfig())
+                self.linear = nn.Linear(10, self.config.num_labels, bias=True)
+                self.embedding = nn.Embedding(10, 10)
+                self.std = 1
+
+            def _init_weights(self, module):
+                if isinstance(module, nn.Linear):
+                    module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5))
+                    if module.bias is not None:
+                        module.bias.data = module.bias.data.normal_(mean=0.0, std=self.std)
+
+        # Used to make sure the weights with matched shape are loaded correctly
+        config = PretrainedConfig()
+        config.num_labels = 3
+        model = MyClass(config=config)
+
+        # Used to make sure the weights with mismatched shape are properly initialized
+        set_seed(0)
+        config = PretrainedConfig()
+        config.num_labels = 4
+        # Do not initialize the weights at creation time: this matches the logic in `from_pretrained`, keeping the
+        # same sequence of random ops in the execution path so that `target_model` and `new_model` below can be
+        # compared for the `linear` part.
+        with ContextManagers([no_init_weights(True)]):
+            target_model = MyClass(config=config)
+        target_model.apply(target_model._initialize_weights)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            state_dict = model.state_dict()
+            del state_dict["linear.weight"]
+
+            model.config.save_pretrained(tmpdirname)
+            torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
+
+            set_seed(0)
+            new_model = MyClass.from_pretrained(tmpdirname, num_labels=4, ignore_mismatched_sizes=True)
+
+            for key in new_model.state_dict().keys():
+                # check that weight values for weights with matched shapes are identical
+                # (i.e. correctly loaded from the checkpoint)
+                if key not in ["linear.weight", "linear.bias"]:
+                    max_diff = torch.max(torch.abs(model.state_dict()[key] - new_model.state_dict()[key]))
+                    self.assertLessEqual(
+                        max_diff.item(),
+                        1e-6,
+                        msg=f"the weight values for `{key}` in `new_model` and `model` are not identical",
+                    )
+                else:
+                    # check that we have some mismatched shapes
+                    self.assertNotEqual(
+                        model.state_dict()[key].shape,
+                        new_model.state_dict()[key].shape,
+                        msg=f"the weight shapes for {key} in `model` and `new_model` should differ",
+                    )
+                    # check that the weights with mismatched shape are properly initialized
+                    max_diff = torch.max(torch.abs(new_model.state_dict()[key] - target_model.state_dict()[key]))
+                    self.assertLessEqual(
+                        max_diff.item(),
+                        1e-6,
+                        msg=f"the weight values for `{key}` in `new_model` and `target_model` are not identical",
+                    )
+
     def test_model_is_small(self):
         # Just a consistency check to make sure we are not running tests on 80M parameter models.
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
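
A self-contained toy of the failure mode, for reviewers (hypothetical names,
simplified from -- not copied from -- the transformers internals):

    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self, num_labels):
            super().__init__()
            self.encoder = nn.Linear(4, 4)              # shape-stable part
            self.classifier = nn.Linear(4, num_labels)  # head whose shape changes

    def initialize_weights(module):
        # Mimics PreTrainedModel._initialize_weights: modules already marked
        # as initialized are skipped.
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.zeros_(module.weight)  # stand-in for a model's _init_weights

    ckpt = TinyModel(num_labels=2).state_dict()  # checkpoint saved with 2 labels
    model = TinyModel(num_labels=3)              # new model wants 3 labels

    # Old behavior: all of `classifier`'s keys exist in the checkpoint, so fast
    # init marked the module as initialized -- even though its tensors are about
    # to be dropped for their shape mismatch and never actually loaded.
    if all(k in ckpt for k in ("classifier.weight", "classifier.bias")):
        model.classifier._is_hf_initialized = True

    # This is what ignore_mismatched_sizes=True amounts to: drop the mismatched
    # tensors, then load non-strictly.
    del ckpt["classifier.weight"], ckpt["classifier.bias"]
    model.load_state_dict(ckpt, strict=False)

    model.apply(initialize_weights)
    print(model.classifier.weight.abs().sum())  # nonzero: head never re-initialized

In from_pretrained the model is created without weight initialization, so the
skipped head would hold arbitrary memory rather than nn.Linear's default init.
The patch simply skips the marking step whenever ignore_mismatched_sizes=True,
so model.apply(model._initialize_weights) re-initializes the mismatched head.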