Support for transformers explicit filename (#38152)
* Support for transformers explicit filename

* Tests

* Rerun tests
commit 003deb16f1 (parent dbb9813dff)
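In short, a repository's config.json may now name its weights file explicitly through a `transformers_weights` entry, and `from_pretrained` will resolve that file (or sharded index) instead of the default `model.safetensors`. A minimal sketch of the intended usage, assuming a repo laid out like the `hf-internal-testing/explicit_transformers_weight_in_config` fixture used in the tests below (the `custom_model.safetensors` name is illustrative, not taken from the PR):

    # Illustrative config.json for such a repo (field values are assumptions):
    # {
    #     "model_type": "bert",
    #     "transformers_weights": "custom_model.safetensors",
    #     ...
    # }

    from transformers import BertModel

    # from_pretrained reads the config, finds `transformers_weights`, and loads
    # the weights from that file rather than the default model.safetensors.
    model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
    print(model.num_parameters())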
@@ -408,6 +408,10 @@ class PretrainedConfig(PushToHubMixin):
             repo_id = self._create_repo(repo_id, **kwargs)
             files_timestamps = self._get_files_timestamps(save_directory)
 
+        # This attribute is important to know on load, but should not be serialized on save.
+        if "transformers_weights" in self:
+            delattr(self, "transformers_weights")
+
         # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
         # loaded from the Hub.
         if self._auto_class is not None:
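Because the attribute is stripped in save_pretrained, a model loaded from such a repo and re-saved produces a plain checkpoint again. A short sketch mirroring the tests at the end of this diff (temporary-directory handling is illustrative):

    import json
    import os
    import tempfile

    from transformers import BertModel

    model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")

    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)
        with open(os.path.join(tmpdirname, "config.json")) as f:
            saved_config = json.load(f)
        # The attribute is dropped before serialization, so the re-saved checkpoint
        # uses the default weight file name again.
        assert "transformers_weights" not in saved_config
        assert "model.safetensors" in os.listdir(tmpdirname)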
@@ -881,6 +881,7 @@ def _get_resolved_checkpoint_files(
     user_agent: dict,
     revision: str,
     commit_hash: Optional[str],
+    transformers_explicit_filename: Optional[str] = None,
 ) -> Tuple[Optional[List[str]], Optional[Dict]]:
     """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
     checkpoints are sharded.
@@ -892,7 +893,11 @@ def _get_resolved_checkpoint_files(
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     is_local = os.path.isdir(pretrained_model_name_or_path)
     if is_local:
-        if from_tf and os.path.isfile(
+        if transformers_explicit_filename is not None:
+            # If the filename is explicitly defined, load this by default.
+            archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
+            is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+        elif from_tf and os.path.isfile(
             os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
         ):
             # Load from a TF 1.0 checkpoint in priority if from_tf
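Read on its own, the local-directory branch boils down to the following (a hypothetical standalone helper for illustration only; the real function also probes TF, Flax, and sharded PyTorch files):

    import os
    from typing import Optional, Tuple

    def resolve_local_checkpoint(
        model_dir: str, subfolder: str = "", explicit_filename: Optional[str] = None
    ) -> Tuple[Optional[str], bool]:
        """Return (archive_file, is_sharded) for a local directory, preferring an explicit filename."""
        if explicit_filename is not None:
            # An explicitly configured file takes priority over the usual probing order.
            archive_file = os.path.join(model_dir, subfolder, explicit_filename)
            is_sharded = explicit_filename.endswith(".safetensors.index.json")
            return archive_file, is_sharded
        # ... otherwise fall through to the existing TF/Flax/safetensors checks ...
        return None, False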
@@ -980,7 +985,10 @@ def _get_resolved_checkpoint_files(
         resolved_archive_file = download_url(pretrained_model_name_or_path)
     else:
         # set correct filename
-        if from_tf:
+        if transformers_explicit_filename is not None:
+            filename = transformers_explicit_filename
+            is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+        elif from_tf:
             filename = TF2_WEIGHTS_NAME
         elif from_flax:
             filename = FLAX_WEIGHTS_NAME
@@ -4362,6 +4370,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
 
         model_kwargs = kwargs
 
+        transformers_explicit_filename = getattr(config, "transformers_weights", None)
+
+        if transformers_explicit_filename is not None:
+            if not transformers_explicit_filename.endswith(
+                ".safetensors"
+            ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
+                raise ValueError(
+                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                    f"{transformers_explicit_filename}"
+                )
+
         pre_quantized = hasattr(config, "quantization_config")
         if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
             pre_quantized = False
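The new guard only accepts safetensors artifacts; anything else fails fast, before the checkpoint files are resolved. As a standalone sketch (the helper name is hypothetical):

    from typing import Optional

    def validate_explicit_weights_filename(filename: Optional[str]) -> None:
        # Only *.safetensors files or *.safetensors.index.json index files are accepted.
        if filename is None:
            return
        if not (filename.endswith(".safetensors") or filename.endswith(".safetensors.index.json")):
            raise ValueError(
                "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
                f"(*.safetensors) nor a safetensors index file (*.safetensors.index.json): {filename}"
            )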
@@ -4430,6 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
             user_agent=user_agent,
             revision=revision,
             commit_hash=commit_hash,
+            transformers_explicit_filename=transformers_explicit_filename,
         )
 
         is_sharded = sharded_metadata is not None
@@ -1958,6 +1958,80 @@ class ModelUtilsTest(TestCasePlus):
         except subprocess.CalledProcessError as e:
             raise Exception(f"The following error was captured: {e.stderr}")
 
+    def test_explicit_transformers_weights(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        Here, we ensure that the correct file is loaded.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
+        self.assertEqual(model.num_parameters(), 87929)
+
+    def test_explicit_transformers_weights_index(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        Here, we ensure that the correct file is loaded, given the file is an index of multiple weights.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
+        self.assertEqual(model.num_parameters(), 87929)
+
+    def test_explicit_transformers_weights_save_and_reload(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
+        otherwise, transformers will try to load from that file whereas it should simply load from the default file.
+
+        We test that for a non-sharded repo.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config")
+        explicit_transformers_weights = model.config.transformers_weights
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+
+            # The config should not have a mention of transformers_weights
+            with open(os.path.join(tmpdirname, "config.json")) as f:
+                config = json.loads(f.read())
+            self.assertFalse("transformers_weights" in config)
+
+            # The serialized weights should be in model.safetensors and not the transformers_weights
+            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
+            self.assertTrue("model.safetensors" in os.listdir(tmpdirname))
+
+    def test_explicit_transformers_weights_index_save_and_reload(self):
+        """
+        Transformers supports loading from repos where the weights file is explicitly set in the config.
+        When loading a config file, transformers will see whether `transformers_weights` is defined in the config.
+        If so, it will load from that file.
+
+        When saving the model, we should be careful not to save the `transformers_weights` attribute in the config;
+        otherwise, transformers will try to load from that file whereas it should simply load from the default file.
+
+        We test that for a sharded repo.
+        """
+        model = BertModel.from_pretrained("hf-internal-testing/explicit_transformers_weight_in_config_sharded")
+        explicit_transformers_weights = model.config.transformers_weights
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname, max_shard_size="100kb")
+
+            # The config should not have a mention of transformers_weights
+            with open(os.path.join(tmpdirname, "config.json")) as f:
+                config = json.loads(f.read())
+            self.assertFalse("transformers_weights" in config)
+
+            # The serialized weights should be in model.safetensors and not the transformers_weights
+            self.assertTrue(explicit_transformers_weights not in os.listdir(tmpdirname))
+            self.assertTrue("model.safetensors.index.json" in os.listdir(tmpdirname))
+
+
     @slow
     @require_torch
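To run only the new tests locally, something along these lines should work (the test module path is an assumption, not part of the diff):

    # Select the new tests by keyword; adjust the path to wherever ModelUtilsTest lives.
    import pytest

    pytest.main(["tests/utils/test_modeling_utils.py", "-k", "explicit_transformers_weights", "-v"])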