Mirror of https://github.com/huggingface/transformers.git
🚨🚨🚨 An attempt to fix #29554. Include 'LayerNorm.' in gamma/beta rename scope, optimize string search. (#35615)
* An attempt to fix #29554. Include 'LayerNorm.' in the gamma/beta rename scope; considerably reduce the number of characters searched on every load.
* Fix the fix-on-load issue
* Fix the gamma/beta warning test
* A style complaint
* Improve efficiency of the weight norm key rename. Add better comments about weight norm and layer norm renaming.
* Habitual elif, redundant with the return
parent 02a492a838
commit 8c1b5d3782
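For context, a minimal sketch of the failure mode behind #29554 (illustration only, not the library code): the old load-time rename used a bare substring match, so any checkpoint key that merely contained "beta" or "gamma" was rewritten on load. The key name below is made up for the example.

# Sketch of the pre-#35615 rename logic: a plain substring replace.
def old_fix_key(key: str) -> str:
    if "beta" in key:
        return key.replace("beta", "bias")
    if "gamma" in key:
        return key.replace("gamma", "weight")
    return key

# A hypothetical parameter that only *contains* "beta" gets silently corrupted:
print(old_fix_key("decoder.beta_proj.weight"))  # -> "decoder.bias_proj.weight"
# After this commit only true LayerNorm legacy names are touched, and the hook
# also reports whether anything changed, e.g.
# _fix_state_dict_key_on_load("encoder.LayerNorm.gamma") == ("encoder.LayerNorm.weight", True)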
@@ -4367,26 +4367,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return model

     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""

-        if "beta" in key:
-            return key.replace("beta", "bias")
-        if "gamma" in key:
-            return key.replace("gamma", "weight")
+        # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
+        # This rename is logged.
+        if key.endswith("LayerNorm.beta"):
+            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
+        if key.endswith("LayerNorm.gamma"):
+            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

-        # to avoid logging parametrized weight norm renaming
+        # Rename weight norm parametrizations to match changes across torch versions.
+        # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
+        # This rename is not logged.
         if hasattr(nn.utils.parametrizations, "weight_norm"):
-            if "weight_g" in key:
-                return key.replace("weight_g", "parametrizations.weight.original0")
-            if "weight_v" in key:
-                return key.replace("weight_v", "parametrizations.weight.original1")
+            if key.endswith("weight_g"):
+                return key.replace("weight_g", "parametrizations.weight.original0"), True
+            if key.endswith("weight_v"):
+                return key.replace("weight_v", "parametrizations.weight.original1"), True
         else:
-            if "parametrizations.weight.original0" in key:
-                return key.replace("parametrizations.weight.original0", "weight_g")
-            if "parametrizations.weight.original1" in key:
-                return key.replace("parametrizations.weight.original1", "weight_v")
-        return key
+            if key.endswith("parametrizations.weight.original0"):
+                return key.replace("parametrizations.weight.original0", "weight_g"), True
+            if key.endswith("parametrizations.weight.original1"):
+                return key.replace("parametrizations.weight.original1", "weight_v"), True
+
+        return key, False

     @classmethod
     def _fix_state_dict_keys_on_load(cls, state_dict):
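The weight-norm comments above refer to a layout change in torch itself. A quick, hedged way to see the two state-dict layouts the hook has to bridge (requires only torch; which branch runs and which keys appear depend on the installed version):

import torch.nn as nn

# Newer torch stores weight norm via parametrizations (original0/original1);
# older torch stores weight_g/weight_v. A checkpoint saved with one layout must
# be remapped when loaded under the other, which is what the hunk above does.
layer = nn.Linear(4, 4)
if hasattr(nn.utils.parametrizations, "weight_norm"):
    layer = nn.utils.parametrizations.weight_norm(layer)
else:
    layer = nn.utils.weight_norm(layer)
print(sorted(layer.state_dict().keys()))
# newer torch: ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']
# older torch: ['bias', 'weight_g', 'weight_v']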
@@ -4397,15 +4402,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         renamed_keys = {}
         state_dict_keys = list(state_dict.keys())
         for key in state_dict_keys:
-            new_key = cls._fix_state_dict_key_on_load(key)
-            if new_key != key:
+            new_key, has_changed = cls._fix_state_dict_key_on_load(key)
+            if has_changed:
                 state_dict[new_key] = state_dict.pop(key)

-                # add it once for logging
-                if "gamma" in key and "gamma" not in renamed_keys:
-                    renamed_keys["gamma"] = (key, new_key)
-                if "beta" in key and "beta" not in renamed_keys:
-                    renamed_keys["beta"] = (key, new_key)
+                # track gamma/beta rename for logging
+                if key.endswith("LayerNorm.gamma"):
+                    renamed_keys["LayerNorm.gamma"] = (key, new_key)
+                elif key.endswith("LayerNorm.beta"):
+                    renamed_keys["LayerNorm.beta"] = (key, new_key)

         if renamed_keys:
             warning_msg = f"A pretrained model of type `{cls.__name__}` "
@@ -4418,19 +4423,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return state_dict

     @staticmethod
-    def _fix_state_dict_key_on_save(key):
+    def _fix_state_dict_key_on_save(key) -> Tuple[str, bool]:
         """
         Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
-        Do nothing by default, but can be overriden in particular models.
+        Do nothing by default, but can be overridden in particular models.
         """
-        return key
+        return key, False

     def _fix_state_dict_keys_on_save(self, state_dict):
         """
         Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
         Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
         """
-        return {self._fix_state_dict_key_on_save(key): value for key, value in state_dict.items()}
+        return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}

     @classmethod
     def _load_pretrained_model(
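Since both hooks now return a (key, changed) tuple, a model-specific override has to follow the same contract. A hypothetical sketch of such an override (the class name and key prefix here are made up; the TimmWrapper hunk further below is the real in-tree example):

from transformers import PreTrainedModel

class MyModelWithRenamedHead(PreTrainedModel):  # hypothetical example model
    @staticmethod
    def _fix_state_dict_key_on_save(key):
        # Rename a made-up legacy prefix on save and report whether anything changed.
        if key.startswith("legacy_head."):
            return key.replace("legacy_head.", "head.", 1), True
        return key, False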
@@ -4488,7 +4493,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)

         original_loaded_keys = loaded_keys
-        loaded_keys = [cls._fix_state_dict_key_on_load(key) for key in loaded_keys]
+        loaded_keys = [cls._fix_state_dict_key_on_load(key)[0] for key in loaded_keys]

         if len(prefix) > 0:
             has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
@@ -90,22 +90,22 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
         super().__init__(*args, **kwargs)

     @staticmethod
-    def _fix_state_dict_key_on_load(key):
+    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
         """
         Overrides original method that renames `gamma` and `beta` to `weight` and `bias`.
         We don't want this behavior for timm wrapped models. Instead, this method adds a
         "timm_model." prefix to enable loading official timm Hub checkpoints.
         """
         if "timm_model." not in key:
-            return f"timm_model.{key}"
-        return key
+            return f"timm_model.{key}", True
+        return key, False

     def _fix_state_dict_key_on_save(self, key):
         """
         Overrides original method to remove "timm_model." prefix from state_dict keys.
         Makes the saved checkpoint compatible with the `timm` library.
         """
-        return key.replace("timm_model.", "")
+        return key.replace("timm_model.", ""), True

     def load_state_dict(self, state_dict, *args, **kwargs):
         """
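A string-level round trip of the two TimmWrapper hooks under the new tuple contract (illustration only; the checkpoint key below is made up, and the real methods live on TimmWrapperPreTrainedModel):

# Mirror of the two hooks above, reduced to plain functions.
def on_load(key):
    if "timm_model." not in key:
        return f"timm_model.{key}", True
    return key, False

def on_save(key):
    return key.replace("timm_model.", ""), True

key = "blocks.0.attn.qkv.weight"      # made-up timm-style checkpoint key
loaded, changed = on_load(key)        # ("timm_model.blocks.0.attn.qkv.weight", True)
saved, _ = on_save(loaded)            # "blocks.0.attn.qkv.weight"
assert changed and saved == key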
@@ -1618,57 +1618,47 @@ class ModelUtilsTest(TestCasePlus):
         self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))

     def test_warning_for_beta_gamma_parameters(self):
-        class TestModelGamma(PreTrainedModel):
+        class TestGammaBetaNorm(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gamma = torch.nn.Parameter(torch.ones(1))
+                self.beta = torch.nn.Parameter(torch.zeros(1))
+
+            def forward(self):
+                return self.gamma.sum() + self.beta.sum()
+
+        class TestModelGammaBeta(PreTrainedModel):
             def __init__(self, config):
                 super().__init__(config)
-                self.gamma_param = nn.Parameter(torch.ones(10))
+                self.LayerNorm = TestGammaBetaNorm()
                 self.post_init()

             def forward(self):
-                return self.gamma_param.sum()
+                return self.LayerNorm()

         logger = logging.get_logger("transformers.modeling_utils")
         config = PretrainedConfig()
-        warning_msg_gamma = "`gamma_param` -> `weight_param`"
-        model = TestModelGamma(config)
+        warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`"
+        warning_msg_beta = "`LayerNorm.beta` -> `LayerNorm.bias`"
+        model = TestModelGammaBeta(config)

         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(tmp_dir)
             with LoggingLevel(logging.INFO):
                 with CaptureLogger(logger) as cl1:
-                    _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
+                    _, loading_info = TestModelGammaBeta.from_pretrained(
+                        tmp_dir, config=config, output_loading_info=True
+                    )

         missing_keys = loading_info["missing_keys"]
         unexpected_keys = loading_info["unexpected_keys"]
-        self.assertIn("`TestModelGamma`", cl1.out)
+        self.assertIn("`TestModelGammaBeta`", cl1.out)
         self.assertIn(warning_msg_gamma, cl1.out)
-        self.assertIn("gamma_param", missing_keys)
-        self.assertIn("weight_param", unexpected_keys)
-
-        class TestModelBeta(PreTrainedModel):
-            def __init__(self, config):
-                super().__init__(config)
-                self.beta_param = nn.Parameter(torch.ones(10))
-                self.post_init()
-
-            def forward(self):
-                return self.beta_param.sum()
-
-        warning_msg_beta = "`beta_param` -> `bias_param`"
-        model = TestModelBeta(config)
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model.save_pretrained(tmp_dir)
-            with LoggingLevel(logging.INFO):
-                with CaptureLogger(logger) as cl2:
-                    _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
-
-        missing_keys = loading_info["missing_keys"]
-        unexpected_keys = loading_info["unexpected_keys"]
-        self.assertIn("`TestModelBeta`", cl2.out)
-        self.assertIn(warning_msg_beta, cl2.out)
-        self.assertIn("beta_param", missing_keys)
-        self.assertIn("bias_param", unexpected_keys)
+        self.assertIn(warning_msg_beta, cl1.out)
+        self.assertIn("LayerNorm.gamma", missing_keys)
+        self.assertIn("LayerNorm.weight", unexpected_keys)
+        self.assertIn("LayerNorm.beta", missing_keys)
+        self.assertIn("LayerNorm.bias", unexpected_keys)

     def test_isin_mps_friendly(self):
         """tests that our custom `isin_mps_friendly` matches `torch.isin`"""