mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Config, Serialization] more readable config serialization (#3797)
* better config serialization * finish configuration utils
This commit is contained in:
parent
8b63a01d95
commit
e9d0bc027a
@ -141,7 +141,7 @@ class PretrainedConfig(object):
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||
|
||||
self.to_json_file(output_config_file)
|
||||
self.to_json_file(output_config_file, use_diff=True)
|
||||
logger.info("Configuration saved in {}".format(output_config_file))
|
||||
|
||||
@classmethod
|
||||
@ -353,6 +353,29 @@ class PretrainedConfig(object):
|
||||
def __repr__(self):
    """Return the class name followed by the output of ``to_json_string()``."""
    class_name = self.__class__.__name__
    json_text = self.to_json_string()
    return "{} {}".format(class_name, json_text)
|
||||
|
||||
def to_diff_dict(self):
    """
    Serializes only the non-default attributes of this instance to a Python dictionary.

    Attributes whose value equals the one found in a default-constructed
    ``PretrainedConfig()`` are left out for better readability.

    Returns:
        :obj:`Dict[str, any]`: Dictionary of the attributes of this configuration instance that are missing
        from, or differ from, the default ``PretrainedConfig()``.
    """
    config_dict = self.to_dict()

    # The baseline is the base-class defaults, not the defaults of type(self).
    default_config_dict = PretrainedConfig().to_dict()

    # Keep a key when the baseline lacks it or holds a different value for it.
    return {
        key: value
        for key, value in config_dict.items()
        if key not in default_config_dict or value != default_config_dict[key]
    }
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
@ -365,25 +388,35 @@ class PretrainedConfig(object):
|
||||
output["model_type"] = self.__class__.model_type
|
||||
return output
|
||||
|
||||
def to_json_string(self, use_diff=True):
    """
    Serializes this instance to a JSON string.

    Args:
        use_diff (:obj:`bool`, defaults to True):
            If truthy, only the difference between the config instance and the default PretrainedConfig() is
            serialized to JSON string (via ``to_diff_dict()``); otherwise the full ``to_dict()`` output is used.

    Returns:
        :obj:`string`: String containing the selected attributes of this configuration instance in JSON format
        (2-space indent, keys sorted, trailing newline).
    """
    # Plain truthiness instead of `is True`, so truthy non-bool flags also select the diff.
    if use_diff:
        config_dict = self.to_diff_dict()
    else:
        config_dict = self.to_dict()
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path, use_diff=True):
    """
    Save this instance to a json file.

    Args:
        json_file_path (:obj:`string`):
            Path to the JSON file in which this configuration instance's parameters will be saved.
            The file is opened in write mode with UTF-8 encoding.
        use_diff (:obj:`bool`, defaults to True):
            If set to True, only the difference between the config instance and the default PretrainedConfig()
            is serialized to JSON file. The flag is forwarded unchanged to ``to_json_string()``.
    """
    with open(json_file_path, "w", encoding="utf-8") as writer:
        # Single write; the serialization mode is decided by to_json_string().
        writer.write(self.to_json_string(use_diff=use_diff))
|
||||
|
||||
def update(self, config_dict: Dict):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user