Store transformers version info when saving the model (#9421)

* Store transformers version info when saving the model

* Store transformers version info when saving the model

* fix format

* fix format

* fix format

* Update src/transformers/configuration_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update configuration_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Kevin Canwen Xu 2021-01-06 23:34:48 +08:00 committed by GitHub
parent ecfcac223c
commit 7a9f1b5c99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -21,6 +21,7 @@ import json
import os
from typing import Any, Dict, Tuple, Union
from . import __version__
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging
@ -234,6 +235,9 @@ class PretrainedConfig(object):
# Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", ""))
# Drop the transformers version info
kwargs.pop("transformers_version", None)
# Additional attributes without default values
for key, value in kwargs.items():
try:
@ -520,6 +524,7 @@ class PretrainedConfig(object):
for key, value in config_dict.items():
if (
key not in default_config_dict
or key == "transformers_version"
or value != default_config_dict[key]
or (key in class_config_dict and value != class_config_dict[key])
):
@ -537,6 +542,10 @@ class PretrainedConfig(object):
output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type
# Transformers version when serializing the model
output["transformers_version"] = __version__
return output
def to_json_string(self, use_diff: bool = True) -> str: