From 188574ac5078fc60060c46f01a6ca12c3020a733 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Wed, 3 Mar 2021 08:54:00 -0800
Subject: [PATCH] remap MODEL_FOR_QUESTION_ANSWERING_MAPPING classes to names
 auto-generated file (#10487)

* remap classes to strings

* missing new util

* style

* doc

* move the autogenerated file

* Trigger CI
---
 Makefile                                   |  1 +
 src/transformers/trainer.py                |  4 +-
 .../utils/modeling_auto_mapping.py         | 34 +++++++++++
 utils/class_mapping_update.py              | 60 +++++++++++++++++++
 4 files changed, 97 insertions(+), 2 deletions(-)
 create mode 100644 src/transformers/utils/modeling_auto_mapping.py
 create mode 100644 utils/class_mapping_update.py

diff --git a/Makefile b/Makefile
index 63872a1721c..c3ac1df6239 100644
--- a/Makefile
+++ b/Makefile
@@ -27,6 +27,7 @@ extra_quality_checks: deps_table_update
 	python utils/check_dummies.py
 	python utils/check_repo.py
 	python utils/style_doc.py src/transformers docs/source --max_len 119
+	python utils/class_mapping_update.py
 
 # this target runs checks on all files
 quality:
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 504b852cfe5..874c3ef5230 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -61,7 +61,6 @@ from .file_utils import (
     is_torch_tpu_available,
 )
 from .modeling_utils import PreTrainedModel
-from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
 from .optimization import Adafactor, AdamW, get_scheduler
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
@@ -104,6 +103,7 @@ from .trainer_utils import (
 )
 from .training_args import ParallelMode, TrainingArguments
 from .utils import logging
+from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
 
 
 _is_native_amp_available = False
@@ -420,7 +420,7 @@ class Trainer:
         self.use_tune_checkpoints = False
         default_label_names = (
             ["start_positions", "end_positions"]
-            if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values()
+            if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
             else ["labels"]
         )
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
diff --git a/src/transformers/utils/modeling_auto_mapping.py b/src/transformers/utils/modeling_auto_mapping.py
new file mode 100644
index 00000000000..45424f4f029
--- /dev/null
+++ b/src/transformers/utils/modeling_auto_mapping.py
@@ -0,0 +1,34 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify: models/auto/modeling_auto.py
+# 2. run: python utils/class_mapping_update.py
+from collections import OrderedDict
+
+
+MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
+    [
+        ("ConvBertConfig", "ConvBertForQuestionAnswering"),
+        ("LEDConfig", "LEDForQuestionAnswering"),
+        ("DistilBertConfig", "DistilBertForQuestionAnswering"),
+        ("AlbertConfig", "AlbertForQuestionAnswering"),
+        ("CamembertConfig", "CamembertForQuestionAnswering"),
+        ("BartConfig", "BartForQuestionAnswering"),
+        ("MBartConfig", "MBartForQuestionAnswering"),
+        ("LongformerConfig", "LongformerForQuestionAnswering"),
+        ("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"),
+        ("RobertaConfig", "RobertaForQuestionAnswering"),
+        ("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"),
+        ("BertConfig", "BertForQuestionAnswering"),
+        ("XLNetConfig", "XLNetForQuestionAnsweringSimple"),
+        ("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"),
+        ("MobileBertConfig", "MobileBertForQuestionAnswering"),
+        ("XLMConfig", "XLMForQuestionAnsweringSimple"),
+        ("ElectraConfig", "ElectraForQuestionAnswering"),
+        ("ReformerConfig", "ReformerForQuestionAnswering"),
+        ("FunnelConfig", "FunnelForQuestionAnswering"),
+        ("LxmertConfig", "LxmertForQuestionAnswering"),
+        ("MPNetConfig", "MPNetForQuestionAnswering"),
+        ("DebertaConfig", "DebertaForQuestionAnswering"),
+        ("DebertaV2Config", "DebertaV2ForQuestionAnswering"),
+        ("IBertConfig", "IBertForQuestionAnswering"),
+    ]
+)
diff --git a/utils/class_mapping_update.py b/utils/class_mapping_update.py
new file mode 100644
index 00000000000..126600acd14
--- /dev/null
+++ b/utils/class_mapping_update.py
@@ -0,0 +1,60 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# this script remaps classes to class strings so that it's quick to load such maps and not require
+# loading all possible modeling files
+#
+# it can be extended to auto-generate other dicts that are needed at runtime
+
+
+import os
+import sys
+from os.path import abspath, dirname, join
+
+
+git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
+sys.path.insert(1, git_repo_path)
+
+src = "src/transformers/models/auto/modeling_auto.py"
+dst = "src/transformers/utils/modeling_auto_mapping.py"
+
+if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst):
+    # speed things up by only running this script if the src is newer than dst
+    sys.exit(0)
+
+# only load if needed
+from transformers.models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING  # noqa
+
+
+entries = "\n".join(
+    [f'        ("{k.__name__}", "{v.__name__}"),' for k, v in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items()]
+)
+content = [
+    "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
+    "# 1. modify: models/auto/modeling_auto.py",
+    "# 2. run: python utils/class_mapping_update.py",
+    "from collections import OrderedDict",
+    "",
+    "",
+    "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(",
+    "    [",
+    entries,
+    "    ]",
+    ")",
+    "",
+]
+print(f"updating {dst}")
+with open(dst, "w", encoding="utf-8", newline="\n") as f:
+    f.write("\n".join(content))