Quality + Fix PoC

This commit is contained in:
Sylvain Gugger 2022-01-20 16:05:51 -05:00
parent 5b7636c0fe
commit 3ba6d0d4ca
2 changed files with 15 additions and 1 deletions

View File

@ -546,6 +546,21 @@ class PretrainedConfig(PushToHubMixin):
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
"""
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
if "configuration_files" in config_dict:
# We may have to load another config
configuration_file = get_configuration_file(config_dict["configuration_files"])
if configuration_file != CONFIG_NAME:
config_dict, kwargs = cls._get_config_dict(
pretrained_model_name_or_path, _configuration_file=configuration_file, **kwargs
)
return config_dict, kwargs
@classmethod
def _get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)

View File

@ -17,7 +17,6 @@ import importlib
import io
import unittest
import requests
import transformers
# Try to import everything from transformers to ensure every object can be loaded.