[Config, Serialization] more readable config serialization (#3797)

* better config serialization

* finish configuration utils
This commit is contained in:
Patrick von Platen 2020-04-18 02:07:18 +02:00 committed by GitHub
parent 8b63a01d95
commit e9d0bc027a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -141,7 +141,7 @@ class PretrainedConfig(object):
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
self.to_json_file(output_config_file)
self.to_json_file(output_config_file, use_diff=True)
logger.info("Configuration saved in {}".format(output_config_file))
@classmethod
@ -353,6 +353,29 @@ class PretrainedConfig(object):
def __repr__(self):
    """Return the class name followed by the JSON string produced by ``to_json_string()``."""
    class_name = self.__class__.__name__
    serialized = self.to_json_string()
    return "{} {}".format(class_name, serialized)
def to_diff_dict(self):
    """
    Serializes only the non-default part of this configuration to a Python dictionary.

    Attributes whose value equals the one found in a freshly constructed
    ``PretrainedConfig()`` are left out for better readability.

    Returns:
        :obj:`Dict[str, any]`: Dictionary of the attributes of this configuration instance that are
        either missing from the default config or hold a value different from the default one.
    """
    own_attributes = self.to_dict()

    # baseline to compare against: a config built without any arguments
    defaults = PretrainedConfig().to_dict()

    # keep an entry when the default config lacks the key or stores another value
    return {
        name: current
        for name, current in own_attributes.items()
        if name not in defaults or current != defaults[name]
    }
def to_dict(self):
"""
Serializes this instance to a Python dictionary.
@ -365,25 +388,35 @@ class PretrainedConfig(object):
output["model_type"] = self.__class__.model_type
return output
def to_json_string(self, use_diff=True):
    """
    Serializes this instance to a JSON string.

    Args:
        use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If set to True, only the difference between the config instance and the default
            PretrainedConfig() (see ``to_diff_dict()``) is serialized to the JSON string.
            Any other value serializes the full ``to_dict()`` output.

    Returns:
        :obj:`string`: String containing the selected attributes of this configuration instance in
        JSON format (2-space indentation, keys sorted, terminated by a newline).
    """
    # The identity check is kept on purpose: only the literal ``True`` selects the
    # reduced dictionary, every other value falls back to the full one.
    config_dict = self.to_diff_dict() if use_diff is True else self.to_dict()
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path, use_diff=True):
    """
    Save this instance to a json file.

    Args:
        json_file_path (:obj:`string`):
            Path to the JSON file in which this configuration instance's parameters will be saved.
            The file is opened in write mode with UTF-8 encoding.
        use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If set to True, only the difference between the config instance and the default
            PretrainedConfig() is serialized to the JSON file. The flag is forwarded unchanged to
            ``to_json_string()``.
    """
    with open(json_file_path, "w", encoding="utf-8") as writer:
        # single write: the serialized config must appear exactly once in the file
        writer.write(self.to_json_string(use_diff=use_diff))
def update(self, config_dict: Dict):
"""