diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b37f968ae92..e831ba36130 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -104,6 +104,8 @@ from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
 XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 
+PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."
+
 if is_accelerate_available():
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
 
@@ -662,8 +664,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
+            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
             new_key = key.replace("gamma", "weight")
         if "beta" in key:
+            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
             new_key = key.replace("beta", "bias")
         if new_key:
             old_keys.append(key)
@@ -807,8 +811,10 @@ def _load_state_dict_into_meta_model(
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
+            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
             new_key = key.replace("gamma", "weight")
         if "beta" in key:
+            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
             new_key = key.replace("beta", "bias")
         if new_key:
             old_keys.append(key)
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index c86c340017b..83c8ec8499b 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -1511,6 +1511,57 @@ class ModelUtilsTest(TestCasePlus):
         outputs_from_saved = new_model(input_ids)
         self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"]))
 
+    def test_warning_for_beta_gamma_parameters(self):
+        class TestModelGamma(PreTrainedModel):
+            def __init__(self, config):
+                super().__init__(config)
+                self.gamma_param = nn.Parameter(torch.ones(10))
+                self.post_init()
+
+            def forward(self):
+                return self.gamma_param.sum()
+
+        logger = logging.get_logger("transformers.modeling_utils")
+        config = PretrainedConfig()
+        warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
+        model = TestModelGamma(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            with LoggingLevel(logging.WARNING):
+                with CaptureLogger(logger) as cl1:
+                    _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)
+
+        missing_keys = loading_info["missing_keys"]
+        unexpected_keys = loading_info["unexpected_keys"]
+        self.assertIn(warning_msg_gamma, cl1.out)
+        self.assertIn("gamma_param", missing_keys)
+        self.assertIn("weight_param", unexpected_keys)
+
+        class TestModelBeta(PreTrainedModel):
+            def __init__(self, config):
+                super().__init__(config)
+                self.beta_param = nn.Parameter(torch.ones(10))
+                self.post_init()
+
+            def forward(self):
+                return self.beta_param.sum()
+
+        warning_msg_beta = "A parameter name that contains `beta` will be renamed internally"
+        model = TestModelBeta(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            with LoggingLevel(logging.WARNING):
+                with CaptureLogger(logger) as cl2:
+                    _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
+
+        missing_keys = loading_info["missing_keys"]
+        unexpected_keys = loading_info["unexpected_keys"]
+        self.assertIn(warning_msg_beta, cl2.out)
+        self.assertIn("beta_param", missing_keys)
+        self.assertIn("bias_param", unexpected_keys)
+
     @slow
     @require_torch