Mirror of https://github.com/huggingface/transformers.git
Support for transformers explicit filename (#38152)
* Support for transformers explicit filename
* Tests
* Rerun tests
parent: dbb9813dff
commit: 003deb16f1
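
For orientation, a minimal usage sketch of the feature this commit adds, assuming a Hub repo whose config.json pins a weights file via `transformers_weights`; the repo id below is hypothetical and not part of this commit:

    # Minimal sketch, not part of the commit. The repo id is hypothetical; any repo
    # whose config.json contains a "transformers_weights" entry behaves the same way.
    from transformers import AutoModel

    model = AutoModel.from_pretrained("some-org/model-with-explicit-weights")
    # from_pretrained reads config.transformers_weights and resolves that file
    # (or its *.safetensors.index.json shard index) instead of the default
    # model.safetensors.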
@@ -408,6 +408,10 @@ class PretrainedConfig(PushToHubMixin):
             repo_id = self._create_repo(repo_id, **kwargs)
             files_timestamps = self._get_files_timestamps(save_directory)

+        # This attribute is important to know on load, but should not be serialized on save.
+        if "transformers_weights" in self:
+            delattr(self, "transformers_weights")
+
         # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
         # loaded from the Hub.
         if self._auto_class is not None:
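
For illustration only, a hypothetical config.json excerpt of the kind the hunk above strips on save; the filename is invented:

    # Hypothetical excerpt of a Hub repo's config.json that pins an explicit weights file.
    example_config_entry = {
        "model_type": "bert",
        "transformers_weights": "custom_weights.safetensors",  # invented filename
    }
    # The hunk above deletes this attribute inside PretrainedConfig.save_pretrained,
    # so a re-saved checkpoint writes the default model.safetensors (or its index)
    # and its config.json no longer carries the key.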
@@ -881,6 +881,7 @@ def _get_resolved_checkpoint_files(
     user_agent: dict,
     revision: str,
     commit_hash: Optional[str],
+    transformers_explicit_filename: Optional[str] = None,
 ) -> Tuple[Optional[List[str]], Optional[Dict]]:
     """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
     checkpoints are sharded.
@@ -892,7 +893,11 @@ def _get_resolved_checkpoint_files(
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     is_local = os.path.isdir(pretrained_model_name_or_path)
     if is_local:
-        if from_tf and os.path.isfile(
+        if transformers_explicit_filename is not None:
+            # If the filename is explicitly defined, load this by default.
+            archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
+            is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+        elif from_tf and os.path.isfile(
             os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
         ):
             # Load from a TF 1.0 checkpoint in priority if from_tf
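
A tiny illustration of the local-directory branch above; the path and filename are made up:

    import os

    pretrained_model_name_or_path = "./my_local_checkpoint"  # hypothetical local dir
    subfolder = ""
    transformers_explicit_filename = "model.fp16.safetensors.index.json"  # hypothetical

    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
    is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
    # archive_file -> "./my_local_checkpoint/model.fp16.safetensors.index.json", is_sharded -> True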
@@ -980,7 +985,10 @@ def _get_resolved_checkpoint_files(
         resolved_archive_file = download_url(pretrained_model_name_or_path)
     else:
         # set correct filename
-        if from_tf:
+        if transformers_explicit_filename is not None:
+            filename = transformers_explicit_filename
+            is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+        elif from_tf:
             filename = TF2_WEIGHTS_NAME
         elif from_flax:
             filename = FLAX_WEIGHTS_NAME
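
Taken together, the two hunks above give the explicit filename priority over the TF/Flax/safetensors heuristics. A simplified standalone sketch of that priority order, using stand-in constants rather than the library's own:

    from typing import Optional

    # Stand-ins for transformers' TF2_WEIGHTS_NAME / FLAX_WEIGHTS_NAME / SAFE_WEIGHTS_NAME.
    TF2_WEIGHTS_NAME = "tf_model.h5"
    FLAX_WEIGHTS_NAME = "flax_model.msgpack"
    SAFE_WEIGHTS_NAME = "model.safetensors"

    def pick_filename(explicit: Optional[str], from_tf: bool, from_flax: bool) -> str:
        # Sketch only: the explicit filename from the config wins over every other rule;
        # a *.safetensors.index.json value additionally marks the checkpoint as sharded.
        if explicit is not None:
            return explicit
        if from_tf:
            return TF2_WEIGHTS_NAME
        if from_flax:
            return FLAX_WEIGHTS_NAME
        return SAFE_WEIGHTS_NAME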
@@ -4362,6 +4370,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):

         model_kwargs = kwargs

+        transformers_explicit_filename = getattr(config, "transformers_weights", None)
+
+        if transformers_explicit_filename is not None:
+            if not transformers_explicit_filename.endswith(
+                ".safetensors"
+            ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
+                raise ValueError(
+                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                    f"{transformers_explicit_filename}"
+                )
+
         pre_quantized = hasattr(config, "quantization_config")
         if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
             pre_quantized = False
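
A quick illustration of what the validation above accepts and rejects; the filenames are examples only:

    candidates = [
        "model.fp16.safetensors",        # accepted
        "model.safetensors.index.json",  # accepted (sharded index)
        "pytorch_model.bin",             # rejected -> ValueError in from_pretrained
    ]
    for name in candidates:
        ok = name.endswith(".safetensors") or name.endswith(".safetensors.index.json")
        print(name, "accepted" if ok else "raises ValueError")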
@@ -4430,6 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
             user_agent=user_agent,
             revision=revision,
             commit_hash=commit_hash,
+            transformers_explicit_filename=transformers_explicit_filename,
         )

         is_sharded = sharded_metadata is not None
@@ -1958,6 +1958,80 @@ class ModelUtilsTest(TestCasePlus):
         except subprocess.CalledProcessError as e:
             raise Exception(f"The following error was captured: {e.stderr}")

+    def test_explicit_transformers_weights(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        Here, we ensure that the correct file is loaded.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
+        self.assertEqual(model.num_parameters(), 87929)
+
+    def test_explicit_transformers_weights_index(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        Here, we ensure that the correct file is loaded, given that the file is an index of multiple weights.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
+        self.assertEqual(model.num_parameters(), 87929)
+
+    def test_explicit_transformers_weights_save_and_reload(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
+        otherwise, transformers will try to load from that file whereas it should simply load from the default file.
+
+        We test this for a non-sharded repo.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
+        explicit_transformers_weights = model.config.transformers_weights
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+
+            # The config should not have a mention of transformers_weights
+            with open(os.path.join(tmpdirname, "config.json")) as f:
+                config = json.loads(f.read())
+            self.assertFalse("transformers_weights" in config)
+
+            # The serialized weights should be in model.safetensors and not in the transformers_weights file
+            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
+            self.assertTrue("model.safetensors" in os.listdir(tmpdirname))
+
+    def test_explicit_transformers_weights_index_save_and_reload(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will check whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
+        otherwise, transformers will try to load from that file whereas it should simply load from the default file.
+
+        We test this for a sharded repo.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
+        explicit_transformers_weights = model.config.transformers_weights
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname, max_shard_size="100kb")
+
+            # The config should not have a mention of transformers_weights
+            with open(os.path.join(tmpdirname, "config.json")) as f:
+                config = json.loads(f.read())
+            self.assertFalse("transformers_weights" in config)
+
+            # The serialized weights should be sharded under model.safetensors.index.json and not the transformers_weights file
+            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
+            self.assertTrue("model.safetensors.index.json" in os.listdir(tmpdirname))
+
     @slow
     @require_torch
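
For reproducing the behaviour without the hf-internal-testing repos, a hedged sketch of how an equivalent local fixture could be built; the tiny config, directory layout, and custom filename are assumptions, not the actual contents of those repos:

    import json
    import os
    import shutil
    import tempfile

    from transformers import BertConfig, BertModel

    with tempfile.TemporaryDirectory() as repo:
        # A deliberately tiny BERT so the checkpoint stays small (hypothetical sizes).
        config = BertConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64)
        BertModel(config).save_pretrained(repo)

        # Rename the default weights file and point the config at it explicitly.
        shutil.move(os.path.join(repo, "model.safetensors"), os.path.join(repo, "custom.safetensors"))
        with open(os.path.join(repo, "config.json")) as f:
            config_dict = json.load(f)
        config_dict["transformers_weights"] = "custom.safetensors"  # invented filename
        with open(os.path.join(repo, "config.json"), "w") as f:
            json.dump(config_dict, f)

        # With this commit, from_pretrained resolves custom.safetensors instead of
        # looking for the (now missing) default model.safetensors.
        reloaded = BertModel.from_pretrained(repo)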