mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[AutoModels] Fix config params handling of all PT and TF AutoModels (#5665)
* fix auto model causal lm * leverage given functionality * apply unused kwargs to all auto models
This commit is contained in:
parent
8ab565a4be
commit
ec0a945cf9
@ -498,7 +498,9 @@ class AutoModel:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -645,7 +647,9 @@ class AutoModelForPreTraining:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -802,7 +806,9 @@ class AutoModelWithLMHead:
|
||||
)
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -937,7 +943,9 @@ class AutoModelForCausalLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1078,7 +1086,9 @@ class AutoModelForMaskedLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1209,7 +1219,9 @@ class AutoModelForSeq2SeqLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1359,7 +1371,9 @@ class AutoModelForSequenceClassification:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1501,7 +1515,9 @@ class AutoModelForQuestionAnswering:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1651,7 +1667,9 @@ class AutoModelForTokenClassification:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1703,7 +1721,9 @@ class AutoModelForMultipleChoice:
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
|
@ -450,7 +450,9 @@ class TFAutoModel(object):
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -601,7 +603,9 @@ class TFAutoModelForPreTraining(object):
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -776,7 +780,9 @@ class TFAutoModelWithLMHead(object):
|
||||
config = kwargs.pop("config", None)
|
||||
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||
# Not using isinstance() here to do not take into account inheritance
|
||||
@ -923,7 +929,9 @@ class TFAutoModelForMultipleChoice:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1058,7 +1066,9 @@ class TFAutoModelForCausalLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1198,7 +1208,9 @@ class TFAutoModelForMaskedLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1323,7 +1335,9 @@ class TFAutoModelForSeq2SeqLM:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1482,7 +1496,9 @@ class TFAutoModelForSequenceClassification(object):
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1644,7 +1660,9 @@ class TFAutoModelForQuestionAnswering(object):
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
@ -1775,7 +1793,9 @@ class TFAutoModelForTokenClassification:
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
|
||||
for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
|
Loading…
Reference in New Issue
Block a user