Update JukeboxConfig.from_pretrained (#24443)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-06-23 15:00:52 +02:00 committed by GitHub
parent 8767958fc1
commit 1e9da2b0a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -353,6 +353,8 @@ class JukeboxPriorConfig(PretrainedConfig):
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the prior config dict if we are loading from JukeboxConfig
@ -486,6 +488,8 @@ class JukeboxVQVAEConfig(PretrainedConfig):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from CLIPConfig