From 8c1b5d37827a6691fef4b2d926f2d04fb6f5a9e3 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 16 Jan 2025 17:25:44 -0800
Subject: [PATCH] 🚨🚨🚨 An attempt to fix #29554. Include 'LayerNorm.' in
 gamma/beta rename scope, optimize string search. (#35615)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* An attempt to fix #29554. Include 'LayerNorm.' in gamma/beta rename scope, reduce number of characters searched on every load considerably.

* Fix fix on load issue

* Fix gamma/beta warning test

* A style complaint

* Improve efficiency of weight norm key rename. Add better comments about weight norm and layer norm renaming.

* Habitual elif redundant with the return
---
 src/transformers/modeling_utils.py            | 59 ++++++++++---------
 .../timm_wrapper/modeling_timm_wrapper.py     |  8 +--
 tests/utils/test_modeling_utils.py            | 58 ++++++++----------
 3 files changed, 60 insertions(+), 65 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 3e8d73f94b1..d003c02c85e 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4367,26 +4367,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return model
 
     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
-        if "beta" in key:
-            return key.replace("beta", "bias")
-        if "gamma" in key:
-            return key.replace("gamma", "weight")
+        # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
+        # This rename is logged.
+        if key.endswith("LayerNorm.beta"):
+            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
+        if key.endswith("LayerNorm.gamma"):
+            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
 
-        # to avoid logging parametrized weight norm renaming
+        # Rename weight norm parametrizations to match changes across torch versions.
+        # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
+        # This rename is not logged.
         if hasattr(nn.utils.parametrizations, "weight_norm"):
-            if "weight_g" in key:
-                return key.replace("weight_g", "parametrizations.weight.original0")
-            if "weight_v" in key:
-                return key.replace("weight_v", "parametrizations.weight.original1")
+            if key.endswith("weight_g"):
+                return key.replace("weight_g", "parametrizations.weight.original0"), True
+            if key.endswith("weight_v"):
+                return key.replace("weight_v", "parametrizations.weight.original1"), True
         else:
-            if "parametrizations.weight.original0" in key:
-                return key.replace("parametrizations.weight.original0", "weight_g")
-            if "parametrizations.weight.original1" in key:
-                return key.replace("parametrizations.weight.original1", "weight_v")
-        return key
+            if key.endswith("parametrizations.weight.original0"):
+                return key.replace("parametrizations.weight.original0", "weight_g"), True
+            if key.endswith("parametrizations.weight.original1"):
+                return key.replace("parametrizations.weight.original1", "weight_v"), True
+
+        return key, False
 
     @classmethod
     def _fix_state_dict_keys_on_load(cls, state_dict):
@@ -4397,15 +4402,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         renamed_keys = {}
         state_dict_keys = list(state_dict.keys())
         for key in state_dict_keys:
-            new_key = cls._fix_state_dict_key_on_load(key)
-            if new_key != key:
+            new_key, has_changed = cls._fix_state_dict_key_on_load(key)
+            if has_changed:
                 state_dict[new_key] = state_dict.pop(key)
 
-                # add it once for logging
-                if "gamma" in key and "gamma" not in renamed_keys:
-                    renamed_keys["gamma"] = (key, new_key)
-                if "beta" in key and "beta" not in renamed_keys:
-                    renamed_keys["beta"] = (key, new_key)
+                # track gamma/beta rename for logging
+                if key.endswith("LayerNorm.gamma"):
+                    renamed_keys["LayerNorm.gamma"] = (key, new_key)
+                elif key.endswith("LayerNorm.beta"):
+                    renamed_keys["LayerNorm.beta"] = (key, new_key)
 
         if renamed_keys:
             warning_msg = f"A pretrained model of type `{cls.__name__}` "
@@ -4418,19 +4423,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return state_dict
 
     @staticmethod
-    def _fix_state_dict_key_on_save(key):
+    def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
         """
         Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
-        Do nothing by default, but can be overriden in particular models.
+        Do nothing by default, but can be overridden in particular models.
         """
-        return key
+        return key, False
 
     def _fix_state_dict_keys_on_save(self, state_dict):
         """
         Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
         Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
""" - return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()} + return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()} @classmethod def _load_pretrained_model( @@ -4488,7 +4493,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) original_loaded_keys = loaded_keys - loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys] + loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys] if len(prefix) > 0: has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 47e8944583b..a74202ce5aa 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -90,22 +90,22 @@ class TimmWrapperPreTrainedModel(PreTrainedModel): super().__init__(*args, **kwargs) @staticmethod - def _fix_state_dict_key_on_load(key): + def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]: """ Overrides original method that renames `gamma` and `beta` to `weight` and `bias`. We don't want this behavior for timm wrapped models. Instead, this method adds a "timm_model." prefix to enable loading official timm Hub checkpoints. """ if "timm_model." not in key: - return f"timm_model.{key}" - return key + return f"timm_model.{key}", True + return key, False def _fix_state_dict_key_on_save(self, key): """ Overrides original method to remove "timm_model." prefix from state_dict keys. Makes the saved checkpoint compatible with the `timm` library. 
""" - return key.replace("timm_model.", "") + return key.replace("timm_model.", ""), True def load_state_dict(self, state_dict, *args, **kwargs): """ diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 63f8e7ec465..84b5ebbb24c 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1618,57 +1618,47 @@ class ModelUtilsTest(TestCasePlus): self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"])) def test_warning_for_beta_gamma_parameters(self): - class TestModelGamma(PreTrainedModel): + class TestGammaBetaNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.gamma = torch.nn.Parameter(torch.ones(1)) + self.beta = torch.nn.Parameter(torch.zeros(1)) + + def forward(self): + return self.gamma.sum() + self.beta.sum() + + class TestModelGammaBeta(PreTrainedModel): def __init__(self, config): super().__init__(config) - self.gamma_param = nn.Parameter(torch.ones(10)) + self.LayerNorm = TestGammaBetaNorm() self.post_init() def forward(self): - return self.gamma_param.sum() + return self.LayerNorm() logger = logging.get_logger("transformers.modeling_utils") config = PretrainedConfig() - warning_msg_gamma = "`gamma_param` -> `weight_param`" - model = TestModelGamma(config) + warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`" + warning_msg_beta = "`LayerNorm.beta` -> `LayerNorm.bias`" + model = TestModelGammaBeta(config) with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) with LoggingLevel(logging.INFO): with CaptureLogger(logger) as cl1: - _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True) + _, loading_info = TestModelGammaBeta.from_pretrained( + tmp_dir, config=config, output_loading_info=True + ) missing_keys = loading_info["missing_keys"] unexpected_keys = loading_info["unexpected_keys"] - self.assertIn("`TestModelGamma`", cl1.out) + self.assertIn("`TestModelGammaBeta`", cl1.out) self.assertIn(warning_msg_gamma, cl1.out) - self.assertIn("gamma_param", missing_keys) - self.assertIn("weight_param", unexpected_keys) - - class TestModelBeta(PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.beta_param = nn.Parameter(torch.ones(10)) - self.post_init() - - def forward(self): - return self.beta_param.sum() - - warning_msg_beta = "`beta_param` -> `bias_param`" - model = TestModelBeta(config) - - with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(tmp_dir) - with LoggingLevel(logging.INFO): - with CaptureLogger(logger) as cl2: - _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True) - - missing_keys = loading_info["missing_keys"] - unexpected_keys = loading_info["unexpected_keys"] - self.assertIn("`TestModelBeta`", cl2.out) - self.assertIn(warning_msg_beta, cl2.out) - self.assertIn("beta_param", missing_keys) - self.assertIn("bias_param", unexpected_keys) + self.assertIn(warning_msg_beta, cl1.out) + self.assertIn("LayerNorm.gamma", missing_keys) + self.assertIn("LayerNorm.weight", unexpected_keys) + self.assertIn("LayerNorm.beta", missing_keys) + self.assertIn("LayerNorm.bias", unexpected_keys) def test_isin_mps_friendly(self): """tests that our custom `isin_mps_friendly` matches `torch.isin`"""