Support for transformers explicit filename (#38152)

* Support for transformers explicit filename

* Tests

* Rerun tests
Lysandre Debut 2025-05-19 14:33:47 +02:00 committed by GitHub
parent dbb9813dff
commit 003deb16f1
3 changed files with 101 additions and 2 deletions

src/transformers/configuration_utils.py

@@ -408,6 +408,10 @@ class PretrainedConfig(PushToHubMixin):
             repo_id = self._create_repo(repo_id, **kwargs)
             files_timestamps = self._get_files_timestamps(save_directory)
 
+        # This attribute is important to know on load, but should not be serialized on save.
+        if "transformers_weights" in self:
+            delattr(self, "transformers_weights")
+
         # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
         # loaded from the Hub.
         if self._auto_class is not None:
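
The config-side change above matters on round-trips: if `transformers_weights` survived `save_pretrained`, reloading the saved folder would look for the original custom filename even though the weights were just rewritten under the default `model.safetensors`. A minimal sketch of the intended behaviour, assuming the test repo used further below declares such a filename in its `config.json` (PretrainedConfig keeps unknown config.json keys as attributes):

# Minimal sketch, not part of the diff.
import tempfile

from transformers import AutoConfig

config = AutoConfig.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
print(getattr(config, "transformers_weights", None))  # the custom filename declared in the repo's config.json

with tempfile.TemporaryDirectory() as tmp:
    config.save_pretrained(tmp)
    reloaded = AutoConfig.from_pretrained(tmp)
    print(getattr(reloaded, "transformers_weights", None))  # None: the attribute is not serialized on save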

src/transformers/modeling_utils.py

@@ -881,6 +881,7 @@ def _get_resolved_checkpoint_files(
     user_agent: dict,
     revision: str,
     commit_hash: Optional[str],
+    transformers_explicit_filename: Optional[str] = None,
 ) -> Tuple[Optional[List[str]], Optional[Dict]]:
     """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
     checkpoints are sharded.
@@ -892,7 +893,11 @@
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     is_local = os.path.isdir(pretrained_model_name_or_path)
     if is_local:
-        if from_tf and os.path.isfile(
+        if transformers_explicit_filename is not None:
+            # If the filename is explicitly defined, load this by default.
+            archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
+            is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+        elif from_tf and os.path.isfile(
             os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
         ):
             # Load from a TF 1.0 checkpoint in priority if from_tf
@@ -980,7 +985,10 @@
         resolved_archive_file = download_url(pretrained_model_name_or_path)
     else:
         # set correct filename
-        if from_tf:
+        if transformers_explicit_filename is not None:
+            filename = transformers_explicit_filename
+            is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+        elif from_tf:
             filename = TF2_WEIGHTS_NAME
         elif from_flax:
             filename = FLAX_WEIGHTS_NAME
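
Taken together, the two resolver changes above put the explicit filename ahead of every other heuristic (TF, Flax, safetensors probing), for both local directories and Hub repos, and infer sharding purely from the `.safetensors.index.json` suffix. A simplified, standalone restatement of that priority order (not the actual helper; the constants mirror the library's default filenames):

# Simplified sketch of the selection order, not the real _get_resolved_checkpoint_files.
from typing import Optional, Tuple

SAFE_WEIGHTS_NAME = "model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"

def pick_filename(
    explicit: Optional[str], from_tf: bool = False, from_flax: bool = False
) -> Tuple[str, bool]:
    """Return (filename, is_sharded) in the priority order the resolver now uses."""
    if explicit is not None:
        # An explicit filename from the config wins over every other heuristic.
        return explicit, explicit.endswith(".safetensors.index.json")
    if from_tf:
        return TF2_WEIGHTS_NAME, False
    if from_flax:
        return FLAX_WEIGHTS_NAME, False
    return SAFE_WEIGHTS_NAME, False

print(pick_filename("custom_weights.safetensors"))             # ('custom_weights.safetensors', False)
print(pick_filename("custom_weights.safetensors.index.json"))  # is_sharded becomes True
print(pick_filename(None))                                      # ('model.safetensors', False)
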
@@ -4362,6 +4370,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin
             model_kwargs = kwargs
+
+        transformers_explicit_filename = getattr(config, "transformers_weights", None)
+        if transformers_explicit_filename is not None:
+            if not transformers_explicit_filename.endswith(
+                ".safetensors"
+            ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
+                raise ValueError(
+                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                    f"{transformers_explicit_filename}"
+                )
+
         pre_quantized = hasattr(config, "quantization_config")
         if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
             pre_quantized = False
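
The guard above rejects anything that is not a single-file safetensors checkpoint or a safetensors index, so `.bin`, TF, or Flax filenames cannot be smuggled in through the config. The same rule as a standalone sketch:

# Standalone sketch of the validation rule enforced in from_pretrained.
def validate_explicit_filename(name: str) -> None:
    if not (name.endswith(".safetensors") or name.endswith(".safetensors.index.json")):
        raise ValueError(
            f"`transformers_weights` must point to a *.safetensors or "
            f"*.safetensors.index.json file, got: {name}"
        )

validate_explicit_filename("custom_weights.safetensors")       # ok
validate_explicit_filename("model.safetensors.index.json")     # ok
# validate_explicit_filename("pytorch_model.bin")              # would raise ValueError
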
@@ -4430,6 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin
             user_agent=user_agent,
             revision=revision,
             commit_hash=commit_hash,
+            transformers_explicit_filename=transformers_explicit_filename,
         )
         is_sharded = sharded_metadata is not None
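
When the explicit filename is an index, `is_sharded` is set and the resolver walks the shard map rather than treating the file itself as weights. For reference, a sharded safetensors index has roughly the following shape; the shard names and sizes here are illustrative:

# Illustrative index content; real indexes are produced by save_pretrained when
# max_shard_size is exceeded. The ".safetensors.index.json" suffix is what flips
# is_sharded in the resolver above.
example_index = {
    "metadata": {"total_size": 351716},
    "weight_map": {
        "embeddings.word_embeddings.weight": "custom_weights-00001-of-00002.safetensors",
        "encoder.layer.0.attention.self.query.weight": "custom_weights-00002-of-00002.safetensors",
    },
}

explicit = "custom_weights.safetensors.index.json"
print(explicit.endswith(".safetensors.index.json"))  # True -> every shard in weight_map gets resolved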

tests/utils/test_modeling_utils.py

@@ -1958,6 +1958,80 @@ class ModelUtilsTest(TestCasePlus):
         except subprocess.CalledProcessError as e:
             raise Exception(f"The following error was captured: {e.stderr}")
 
+    def test_explicit_transformers_weights(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+        Here, we ensure that the correct file is loaded.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
+        self.assertEqual(model.num_parameters(), 87929)
+
+    def test_explicit_transformers_weights_index(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+        Here, we ensure that the correct file is loaded when that file is an index of multiple weights files.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
+        self.assertEqual(model.num_parameters(), 87929)
+
+    def test_explicit_transformers_weights_save_and_reload(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
+        otherwise, transformers would try to load from that file whereas it should simply load from the default file.
+        We test this for a non-sharded repo.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
+        explicit_transformers_weights = model.config.transformers_weights
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+
+            # The config should not have a mention of transformers_weights
+            with open(os.path.join(tmpdirname, "config.json")) as f:
+                config = json.loads(f.read())
+            self.assertFalse("transformers_weights" in config)
+
+            # The serialized weights should be in model.safetensors and not the transformers_weights file
+            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
+            self.assertTrue("model.safetensors" in os.listdir(tmpdirname))
+
+    def test_explicit_transformers_weights_index_save_and_reload(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
+        otherwise, transformers would try to load from that file whereas it should simply load from the default file.
+        We test this for a sharded repo.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
+        explicit_transformers_weights = model.config.transformers_weights
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname, max_shard_size="100kb")
+
+            # The config should not have a mention of transformers_weights
+            with open(os.path.join(tmpdirname, "config.json")) as f:
+                config = json.loads(f.read())
+            self.assertFalse("transformers_weights" in config)
+
+            # The serialized weights should be in model.safetensors.index.json and not the transformers_weights file
+            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
+            self.assertTrue("model.safetensors.index.json" in os.listdir(tmpdirname))
+
     @slow
     @require_torch
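
The new tests exercise ready-made Hub repos; the same mechanism can be reproduced locally. A hedged end-to-end sketch (not from the diff) that writes a tiny BERT checkpoint under an illustrative custom filename, declares it in `config.json`, and loads it back through `from_pretrained`:

# End-to-end sketch under stated assumptions: "custom_weights.safetensors" is an
# arbitrary illustrative filename, and the tiny BertConfig only keeps the example fast.
import json
import os
import tempfile

from safetensors.torch import save_file
from transformers import BertConfig, BertModel

with tempfile.TemporaryDirectory() as folder:
    config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
    model = BertModel(config)

    # Weights go to a non-default filename instead of model.safetensors.
    save_file(model.state_dict(), os.path.join(folder, "custom_weights.safetensors"), metadata={"format": "pt"})

    # The config declares the explicit filename so loaders know where to look.
    config_dict = config.to_dict()
    config_dict["transformers_weights"] = "custom_weights.safetensors"
    with open(os.path.join(folder, "config.json"), "w") as f:
        json.dump(config_dict, f)

    # from_pretrained reads transformers_weights from the config and resolves
    # custom_weights.safetensors through the code paths added above.
    reloaded = BertModel.from_pretrained(folder)
    print(reloaded.num_parameters())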