remap MODEL_FOR_QUESTION_ANSWERING_MAPPING classes to names auto-generated file (#10487)

* remap classes to strings

* missing new util

* style

* doc

* move the autogenerated file

* Trigger CI
This commit is contained in:
Stas Bekman 2021-03-03 08:54:00 -08:00 committed by GitHub
parent 801ff969ce
commit 188574ac50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 2 deletions

View File

@ -27,6 +27,7 @@ extra_quality_checks: deps_table_update
python utils/check_dummies.py python utils/check_dummies.py
python utils/check_repo.py python utils/check_repo.py
python utils/style_doc.py src/transformers docs/source --max_len 119 python utils/style_doc.py src/transformers docs/source --max_len 119
python utils/class_mapping_update.py
# this target runs checks on all files # this target runs checks on all files
quality: quality:

View File

@ -61,7 +61,6 @@ from .file_utils import (
is_torch_tpu_available, is_torch_tpu_available,
) )
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from .optimization import Adafactor, AdamW, get_scheduler from .optimization import Adafactor, AdamW, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import ( from .trainer_callback import (
@ -104,6 +103,7 @@ from .trainer_utils import (
) )
from .training_args import ParallelMode, TrainingArguments from .training_args import ParallelMode, TrainingArguments
from .utils import logging from .utils import logging
from .utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
_is_native_amp_available = False _is_native_amp_available = False
@ -420,7 +420,7 @@ class Trainer:
self.use_tune_checkpoints = False self.use_tune_checkpoints = False
default_label_names = ( default_label_names = (
["start_positions", "end_positions"] ["start_positions", "end_positions"]
if type(self.model) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values() if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
else ["labels"] else ["labels"]
) )
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names self.label_names = default_label_names if self.args.label_names is None else self.args.label_names

View File

@ -0,0 +1,34 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify: models/auto/modeling_auto.py
# 2. run: python utils/class_mapping_update.py
from collections import OrderedDict
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
("ConvBertConfig", "ConvBertForQuestionAnswering"),
("LEDConfig", "LEDForQuestionAnswering"),
("DistilBertConfig", "DistilBertForQuestionAnswering"),
("AlbertConfig", "AlbertForQuestionAnswering"),
("CamembertConfig", "CamembertForQuestionAnswering"),
("BartConfig", "BartForQuestionAnswering"),
("MBartConfig", "MBartForQuestionAnswering"),
("LongformerConfig", "LongformerForQuestionAnswering"),
("XLMRobertaConfig", "XLMRobertaForQuestionAnswering"),
("RobertaConfig", "RobertaForQuestionAnswering"),
("SqueezeBertConfig", "SqueezeBertForQuestionAnswering"),
("BertConfig", "BertForQuestionAnswering"),
("XLNetConfig", "XLNetForQuestionAnsweringSimple"),
("FlaubertConfig", "FlaubertForQuestionAnsweringSimple"),
("MobileBertConfig", "MobileBertForQuestionAnswering"),
("XLMConfig", "XLMForQuestionAnsweringSimple"),
("ElectraConfig", "ElectraForQuestionAnswering"),
("ReformerConfig", "ReformerForQuestionAnswering"),
("FunnelConfig", "FunnelForQuestionAnswering"),
("LxmertConfig", "LxmertForQuestionAnswering"),
("MPNetConfig", "MPNetForQuestionAnswering"),
("DebertaConfig", "DebertaForQuestionAnswering"),
("DebertaV2Config", "DebertaV2ForQuestionAnswering"),
("IBertConfig", "IBertForQuestionAnswering"),
]
)

View File

@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this script remaps classes to class strings so that it's quick to load such maps and not require
# loading all possible modeling files
#
# it can be extended to auto-generate other dicts that are needed at runtime
import os
import sys
from os.path import abspath, dirname, join
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)
src = "src/transformers/models/auto/modeling_auto.py"
dst = "src/transformers/utils/modeling_auto_mapping.py"
if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst):
# speed things up by only running this script if the src is newer than dst
sys.exit(0)
# only load if needed
from transformers.models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING # noqa
entries = "\n".join(
[f' ("{k.__name__}", "{v.__name__}"),' for k, v in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items()]
)
content = [
"# THIS FILE HAS BEEN AUTOGENERATED. To update:",
"# 1. modify: models/auto/modeling_auto.py",
"# 2. run: python utils/class_mapping_update.py",
"from collections import OrderedDict",
"",
"",
"MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(",
" [",
entries,
" ]",
")",
"",
]
print(f"updating {dst}")
with open(dst, "w", encoding="utf-8", newline="\n") as f:
f.write("\n".join(content))