🚨🚨🚨 An attempt to fix #29554. Include 'LayerNorm.' in gamma/beta rename scope, optimize string search. (#35615)

* An attempt to fix #29554. Include 'LayerNorm.' in gamma/beta rename scope, considerably reducing the number of characters searched on every load.

* Fix the fix-on-load issue

* Fix gamma/beta warning test

* A style complaint

* Improve efficiency of weight norm key rename. Add better comments about weight norm and layer norm renaming.

* Habitual elif redundant with the return
Ross Wightman 2025-01-16 17:25:44 -08:00 committed by GitHub
parent 02a492a838
commit 8c1b5d3782
3 changed files with 60 additions and 65 deletions
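
To make the scope change in the commit message concrete before reading the diff below: the old rename rewrote any key containing "gamma" or "beta", while the new one only touches keys ending in "LayerNorm.gamma"/"LayerNorm.beta" and tells the caller whether anything changed. A minimal before/after sketch (the helper names are made up for illustration, not the transformers code):

from typing import Tuple


def old_fix_key(key: str) -> str:
    # pre-change behavior: any occurrence of "beta"/"gamma" was rewritten
    if "beta" in key:
        return key.replace("beta", "bias")
    if "gamma" in key:
        return key.replace("gamma", "weight")
    return key


def new_fix_key(key: str) -> Tuple[str, bool]:
    # post-change behavior: only legacy TF-ported LayerNorm params are renamed,
    # and the caller learns whether the key actually changed
    if key.endswith("LayerNorm.beta"):
        return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
    if key.endswith("LayerNorm.gamma"):
        return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
    return key, False


print(old_fix_key("model.beta_proj.weight"))  # 'model.bias_proj.weight' -- unwanted rename of an unrelated parameter
print(new_fix_key("model.beta_proj.weight"))  # ('model.beta_proj.weight', False)
print(new_fix_key("bert.encoder.layer.0.output.LayerNorm.gamma"))
# ('bert.encoder.layer.0.output.LayerNorm.weight', True)

The endswith check also inspects only the tail of each key, which is presumably what the commit's note about reducing the number of characters searched on every load refers to.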


@@ -4367,26 +4367,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return model
 
     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
-        if "beta" in key:
-            return key.replace("beta", "bias")
-        if "gamma" in key:
-            return key.replace("gamma", "weight")
+        # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
+        # This rename is logged.
+        if key.endswith("LayerNorm.beta"):
+            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
+        if key.endswith("LayerNorm.gamma"):
+            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
 
-        # to avoid logging parametrized weight norm renaming
+        # Rename weight norm parametrizations to match changes across torch versions.
+        # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
+        # This rename is not logged.
         if hasattr(nn.utils.parametrizations, "weight_norm"):
-            if "weight_g" in key:
-                return key.replace("weight_g", "parametrizations.weight.original0")
-            if "weight_v" in key:
-                return key.replace("weight_v", "parametrizations.weight.original1")
+            if key.endswith("weight_g"):
+                return key.replace("weight_g", "parametrizations.weight.original0"), True
+            if key.endswith("weight_v"):
+                return key.replace("weight_v", "parametrizations.weight.original1"), True
         else:
-            if "parametrizations.weight.original0" in key:
-                return key.replace("parametrizations.weight.original0", "weight_g")
-            if "parametrizations.weight.original1" in key:
-                return key.replace("parametrizations.weight.original1", "weight_v")
-        return key
+            if key.endswith("parametrizations.weight.original0"):
+                return key.replace("parametrizations.weight.original0", "weight_g"), True
+            if key.endswith("parametrizations.weight.original1"):
+                return key.replace("parametrizations.weight.original1", "weight_v"), True
+        return key, False
 
     @classmethod
     def _fix_state_dict_keys_on_load(cls, state_dict):
@@ -4397,15 +4402,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         renamed_keys = {}
         state_dict_keys = list(state_dict.keys())
         for key in state_dict_keys:
-            new_key = cls._fix_state_dict_key_on_load(key)
-            if new_key != key:
+            new_key, has_changed = cls._fix_state_dict_key_on_load(key)
+            if has_changed:
                 state_dict[new_key] = state_dict.pop(key)
 
-                # add it once for logging
-                if "gamma" in key and "gamma" not in renamed_keys:
-                    renamed_keys["gamma"] = (key, new_key)
-                if "beta" in key and "beta" not in renamed_keys:
-                    renamed_keys["beta"] = (key, new_key)
+                # track gamma/beta rename for logging
+                if key.endswith("LayerNorm.gamma"):
+                    renamed_keys["LayerNorm.gamma"] = (key, new_key)
+                elif key.endswith("LayerNorm.beta"):
+                    renamed_keys["LayerNorm.beta"] = (key, new_key)
 
         if renamed_keys:
             warning_msg = f"A pretrained model of type `{cls.__name__}` "
@@ -4418,19 +4423,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return state_dict
 
     @staticmethod
-    def _fix_state_dict_key_on_save(key):
+    def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
         """
         Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
-        Do nothing by default, but can be overriden in particular models.
+        Do nothing by default, but can be overridden in particular models.
         """
-        return key
+        return key, False
 
     def _fix_state_dict_keys_on_save(self, state_dict):
         """
         Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
         Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
         """
-        return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
+        return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}
 
     @classmethod
     def _load_pretrained_model(
@@ -4488,7 +4493,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
 
         original_loaded_keys = loaded_keys
-        loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
+        loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]
 
         if len(prefix) > 0:
             has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
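
For reference, a toy walk-through of the new `(new_key, has_changed)` contract from the hunks above. It calls the private `PreTrainedModel._fix_state_dict_key_on_load` static method purely for illustration and assumes a transformers version that already includes this change; the state dict below is made up.

import torch
from transformers import PreTrainedModel

# Made-up checkpoint with two legacy TF-style LayerNorm keys.
state_dict = {
    "encoder.LayerNorm.gamma": torch.ones(4),
    "encoder.LayerNorm.beta": torch.zeros(4),
    "encoder.dense.weight": torch.randn(4, 4),
}

renamed_keys = {}
for key in list(state_dict.keys()):
    new_key, has_changed = PreTrainedModel._fix_state_dict_key_on_load(key)
    if has_changed:
        # move the tensor under its modern name and remember the rename
        state_dict[new_key] = state_dict.pop(key)
        renamed_keys[key] = new_key

print(sorted(state_dict))
# ['encoder.LayerNorm.bias', 'encoder.LayerNorm.weight', 'encoder.dense.weight']
print(renamed_keys)
# {'encoder.LayerNorm.gamma': 'encoder.LayerNorm.weight', 'encoder.LayerNorm.beta': 'encoder.LayerNorm.bias'}

In `_fix_state_dict_keys_on_load` itself, only the LayerNorm renames are recorded for the warning; the weight-norm parametrization renames stay silent, as the comments in the hunk note.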


@@ -90,22 +90,22 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
         super().__init__(*args, **kwargs)
 
     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """
         Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
         We don't want this behavior for timm wrapped models. Instead, this method adds a
         "timm_model." prefix to enable loading official timm Hub checkpoints.
         """
         if "timm_model." not in key:
-            return f"timm_model.{key}"
-        return key
+            return f"timm_model.{key}", True
+        return key, False
 
     def _fix_state_dict_key_on_save(self, key):
         """
         Overrides original method to remove "timm_model." prefix from state_dict keys.
         Makes the saved checkpoint compatible with the `timm` library.
         """
-        return key.replace("timm_model.", "")
+        return key.replace("timm_model.", ""), True
 
     def load_state_dict(self, state_dict, *args, **kwargs):
         """


@@ -1618,57 +1618,47 @@ class ModelUtilsTest(TestCasePlus):
         self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
 
     def test_warning_for_beta_gamma_parameters(self):
-        class TestModelGamma(PreTrainedModel):
+        class TestGammaBetaNorm(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gamma = torch.nn.Parameter(torch.ones(1))
+                self.beta = torch.nn.Parameter(torch.zeros(1))
+
+            def forward(self):
+                return self.gamma.sum() + self.beta.sum()
+
+        class TestModelGammaBeta(PreTrainedModel):
             def __init__(self, config):
                 super().__init__(config)
-                self.gamma_param = nn.Parameter(torch.ones(10))
+                self.LayerNorm = TestGammaBetaNorm()
                 self.post_init()
 
             def forward(self):
-                return self.gamma_param.sum()
+                return self.LayerNorm()
 
         logger = logging.get_logger("transformers.modeling_utils")
         config = PretrainedConfig()
-        warning_msg_gamma = "`gamma_param` -> `weight_param`"
-        model = TestModelGamma(config)
+        warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"
+        warning_msg_beta = "`LayerNorm.beta` -> `LayerNorm.bias`"
+        model = TestModelGammaBeta(config)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            with LoggingLevel(logging.INFO):
                with CaptureLogger(logger) as cl1:
-                    _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
+                    _, loading_info = TestModelGammaBeta.from_pretrained(
+                        tmp_dir, config=config, output_loading_info=True
+                    )
                missing_keys = loading_info["missing_keys"]
                unexpected_keys = loading_info["unexpected_keys"]
            self.assertIn("`TestModelGamma`", cl1.out)
-            self.assertIn("`TestModelGamma`", cl1.out)
+            self.assertIn("`TestModelGammaBeta`", cl1.out)
             self.assertIn(warning_msg_gamma, cl1.out)
-            self.assertIn("gamma_param", missing_keys)
-            self.assertIn("weight_param", unexpected_keys)
-
-        class TestModelBeta(PreTrainedModel):
-            def __init__(self, config):
-                super().__init__(config)
-                self.beta_param = nn.Parameter(torch.ones(10))
-                self.post_init()
-
-            def forward(self):
-                return self.beta_param.sum()
-
-        warning_msg_beta = "`beta_param` -> `bias_param`"
-        model = TestModelBeta(config)
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model.save_pretrained(tmp_dir)
-            with LoggingLevel(logging.INFO):
-                with CaptureLogger(logger) as cl2:
-                    _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
-                missing_keys = loading_info["missing_keys"]
-                unexpected_keys = loading_info["unexpected_keys"]
-            self.assertIn("`TestModelBeta`", cl2.out)
-            self.assertIn(warning_msg_beta, cl2.out)
-            self.assertIn("beta_param", missing_keys)
-            self.assertIn("bias_param", unexpected_keys)
+            self.assertIn(warning_msg_beta, cl1.out)
+            self.assertIn("LayerNorm.gamma", missing_keys)
+            self.assertIn("LayerNorm.weight", unexpected_keys)
+            self.assertIn("LayerNorm.beta", missing_keys)
+            self.assertIn("LayerNorm.bias", unexpected_keys)
 
     def test_isin_mps_friendly(self):
         """tests that our custom `isin_mps_friendly` matches `torch.isin`"""