[AutoModels] Fix config params handling of all PT and TF AutoModels (#5665)

* fix auto model causal lm

* leverage given functionality

* apply unused kwargs to all auto models
Patrick von Platen committed on 2020-07-15 09:51:14 +02:00 (via GitHub)
commit ec0a945cf9 (parent 8ab565a4be)
2 changed files with 60 additions and 20 deletions
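
Both files get the same change in every from_pretrained method: the config is now built with return_unused_kwargs=True, so kwargs that AutoConfig consumes as config parameters are stripped from kwargs before the rest is forwarded to the concrete model class. A minimal sketch of what return_unused_kwargs does (the checkpoint name and the extra foo kwarg are just illustrative):

from transformers import AutoConfig

# With return_unused_kwargs=True, AutoConfig.from_pretrained returns a
# (config, unused_kwargs) tuple: kwargs that match config attributes
# (output_attentions) are applied to the config and removed, while anything
# it does not recognize (foo) is handed back so the caller can forward it.
config, unused_kwargs = AutoConfig.from_pretrained(
    "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
)
assert config.output_attentions is True
assert unused_kwargs == {"foo": False}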

src/transformers/modeling_auto.py

@@ -498,7 +498,9 @@ class AutoModel:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_MAPPING.items():
             if isinstance(config, config_class):
@@ -645,7 +647,9 @@ class AutoModelForPreTraining:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
             if isinstance(config, config_class):
@@ -802,7 +806,9 @@ class AutoModelWithLMHead:
         )
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
             if isinstance(config, config_class):
@@ -937,7 +943,9 @@ class AutoModelForCausalLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1078,7 +1086,9 @@ class AutoModelForMaskedLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1209,7 +1219,9 @@ class AutoModelForSeq2SeqLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1359,7 +1371,9 @@ class AutoModelForSequenceClassification:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
             if isinstance(config, config_class):
@@ -1501,7 +1515,9 @@ class AutoModelForQuestionAnswering:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
             if isinstance(config, config_class):
@@ -1651,7 +1667,9 @@ class AutoModelForTokenClassification:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
             if isinstance(config, config_class):
@@ -1703,7 +1721,9 @@ class AutoModelForMultipleChoice:
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
             if isinstance(config, config_class):

src/transformers/modeling_tf_auto.py

@@ -450,7 +450,9 @@ class TFAutoModel(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_MAPPING.items():
             if isinstance(config, config_class):
@@ -601,7 +603,9 @@ class TFAutoModelForPreTraining(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items():
             if isinstance(config, config_class):
@@ -776,7 +780,9 @@ class TFAutoModelWithLMHead(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
             # Not using isinstance() here to do not take into account inheritance
@@ -923,7 +929,9 @@ class TFAutoModelForMultipleChoice:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
             if isinstance(config, config_class):
@@ -1058,7 +1066,9 @@ class TFAutoModelForCausalLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1198,7 +1208,9 @@ class TFAutoModelForMaskedLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1323,7 +1335,9 @@ class TFAutoModelForSeq2SeqLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1482,7 +1496,9 @@ class TFAutoModelForSequenceClassification(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
             if isinstance(config, config_class):
@@ -1644,7 +1660,9 @@ class TFAutoModelForQuestionAnswering(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
             if isinstance(config, config_class):
@@ -1775,7 +1793,9 @@ class TFAutoModelForTokenClassification:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
             if isinstance(config, config_class):
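
With the fix applied, config overrides passed to an Auto class are consumed when the config is built instead of leaking into the model-loading kwargs. A quick check of the intended behavior (checkpoint name again illustrative):

from transformers import AutoModel

# output_attentions is absorbed into the config by the return_unused_kwargs
# step, so only genuinely unused kwargs reach the underlying model class's
# from_pretrained.
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)
assert model.config.output_attentions is True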