diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 1c0bff6cf39..18cc0ff3a5d 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -408,6 +408,10 @@ class PretrainedConfig(PushToHubMixin):
             repo_id = self._create_repo(repo_id, **kwargs)
             files_timestamps = self._get_files_timestamps(save_directory)
 
+        # This attribute is important to know on load, but should not be serialized on save.
+        if "transformers_weights" in self:
+            delattr(self, "transformers_weights")
+
         # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
         # loaded from the Hub.
         if self._auto_class is not None:
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 459cd7aca55..6e4c631d481 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -881,6 +881,7 @@ def _get_resolved_checkpoint_files(
     user_agent: dict,
     revision: str,
     commit_hash: Optional[str],
+    transformers_explicit_filename: Optional[str] = None,
 ) -> Tuple[Optional[List[str]], Optional[Dict]]:
     """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
     checkpoints are sharded.
@@ -892,7 +893,11 @@ def _get_resolved_checkpoint_files(
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     is_local = os.path.isdir(pretrained_model_name_or_path)
     if is_local:
-        if from_tf and os.path.isfile(
+        if transformers_explicit_filename is not None:
+            # If the filename is explicitly defined, load this by default.
+            archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
+            is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+        elif from_tf and os.path.isfile(
             os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
         ):
             # Load from a TF 1.0 checkpoint in priority if from_tf
@@ -980,7 +985,10 @@ def _get_resolved_checkpoint_files(
         resolved_archive_file = download_url(pretrained_model_name_or_path)
     else:
         # set correct filename
-        if from_tf:
+        if transformers_explicit_filename is not None:
+            filename = transformers_explicit_filename
+            is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+        elif from_tf:
             filename = TF2_WEIGHTS_NAME
         elif from_flax:
             filename = FLAX_WEIGHTS_NAME
@@ -4362,6 +4370,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             model_kwargs = kwargs
 
+        transformers_explicit_filename = getattr(config, "transformers_weights", None)
+
+        if transformers_explicit_filename is not None:
+            if not transformers_explicit_filename.endswith(
+                ".safetensors"
+            ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
+                raise ValueError(
+                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                    f"{transformers_explicit_filename}"
+                )
+
         pre_quantized = hasattr(config, "quantization_config")
         if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
             pre_quantized = False
 
@@ -4430,6 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             user_agent=user_agent,
             revision=revision,
             commit_hash=commit_hash,
+            transformers_explicit_filename=transformers_explicit_filename,
         )
 
         is_sharded = sharded_metadata is not None
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index a5aef44c38d..896b8771c41 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -1958,6 +1958,80 @@ class ModelUtilsTest(TestCasePlus):
         except subprocess.CalledProcessError as e:
             raise Exception(f"The following error was captured: {e.stderr}")
 
+    def test_explicit_transformers_weights(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        Here, we ensure that the correct file is loaded.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
+        self.assertEqual(model.num_parameters(), 87929)
+
+    def test_explicit_transformers_weights_index(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        Here, we ensure that the correct file is loaded, given that the file is an index of multiple weight shards.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
+        self.assertEqual(model.num_parameters(), 87929)
+
+    def test_explicit_transformers_weights_save_and_reload(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
+        otherwise, transformers would try to load from that file when it should simply load from the default file.
+
+        We test this for a non-sharded repo.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
+        explicit_transformers_weights = model.config.transformers_weights
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+
+            # The config should not have a mention of transformers_weights
+            with open(os.path.join(tmpdirname, "config.json")) as f:
+                config = json.loads(f.read())
+            self.assertFalse("transformers_weights" in config)
+
+            # The serialized weights should be in model.safetensors and not in the transformers_weights file
+            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
+            self.assertTrue("model.safetensors" in os.listdir(tmpdirname))
+
+    def test_explicit_transformers_weights_index_save_and_reload(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
+        otherwise, transformers would try to load from that file when it should simply load from the default file.
+
+        We test this for a sharded repo.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
+        explicit_transformers_weights = model.config.transformers_weights
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname, max_shard_size="100kb")
+
+            # The config should not have a mention of transformers_weights
+            with open(os.path.join(tmpdirname, "config.json")) as f:
+                config = json.loads(f.read())
+            self.assertFalse("transformers_weights" in config)
+
+            # The serialized weights should be indexed by model.safetensors.index.json, not by the transformers_weights file
+            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
+            self.assertTrue("model.safetensors.index.json" in os.listdir(tmpdirname))
+
     @slow
     @require_torch