mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Store transformers version info when saving the model (#9421)
* Store transformers version info when saving the model

* Store transformers version info when saving the model

* fix format

* fix format

* fix format

* Update src/transformers/configuration_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update configuration_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
ecfcac223c
commit
7a9f1b5c99
@ -21,6 +21,7 @@ import json
|
||||
import os
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from . import __version__
|
||||
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||
from .utils import logging
|
||||
|
||||
@ -234,6 +235,9 @@ class PretrainedConfig(object):
|
||||
# Name or path to the pretrained checkpoint
|
||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||
|
||||
# Drop the transformers version info
|
||||
kwargs.pop("transformers_version", None)
|
||||
|
||||
# Additional attributes without default values
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
@ -520,6 +524,7 @@ class PretrainedConfig(object):
|
||||
for key, value in config_dict.items():
|
||||
if (
|
||||
key not in default_config_dict
|
||||
or key == "transformers_version"
|
||||
or value != default_config_dict[key]
|
||||
or (key in class_config_dict and value != class_config_dict[key])
|
||||
):
|
||||
@ -537,6 +542,10 @@ class PretrainedConfig(object):
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
if hasattr(self.__class__, "model_type"):
|
||||
output["model_type"] = self.__class__.model_type
|
||||
|
||||
# Transformers version when serializing the model
|
||||
output["transformers_version"] = __version__
|
||||
|
||||
return output
|
||||
|
||||
def to_json_string(self, use_diff: bool = True) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user