Clean imports to fix test_fetcher (#17531)

* Clean imports to fix test_fetcher

* Add dependencies printer

* Update utils/tests_fetcher.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Fix Perceiver import

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Sylvain Gugger 2022-06-03 12:34:41 -04:00 committed by GitHub
parent 254d9c068e
commit c4e58cd8ba
4 changed files with 124 additions and 72 deletions

src/transformers/modeling_tf_utils.py

@@ -1177,7 +1177,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        dataset_args: Optional[Union[str, List[str]]] = None,
    ):
        # Avoids a circular import by doing this when necessary.
-       from .modelcard import TrainingSummary
+       from .modelcard import TrainingSummary  # tests_ignore
        training_summary = TrainingSummary.from_keras(
            self,
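
Note on the hunk above: the `# tests_ignore` marker is what utils/tests_fetcher.py looks for to skip an import line when scanning module dependencies, so this deliberate local import no longer registers a circular modeling_tf_utils -> modelcard edge. A minimal sketch of that kind of filter (hypothetical helper, not the repository's actual implementation):

import re

def find_relative_imports(source: str) -> list:
    """Collect relative imports from a module's source, skipping any line
    marked with `# tests_ignore` so it creates no dependency edge."""
    imports = []
    for line in source.splitlines():
        if "# tests_ignore" in line:
            continue  # explicitly excluded from dependency tracking
        match = re.match(r"\s*from\s+(\.+\w*)\s+import", line)
        if match:
            imports.append(match.group(1))
    return imports

# The marked line is skipped, so TrainingSummary's module is not
# recorded as a dependency of modeling_tf_utils.
assert find_relative_imports("from .modelcard import TrainingSummary  # tests_ignore") == []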

src/transformers/models/perceiver/__init__.py

@@ -27,7 +27,7 @@ from ...utils import (
_import_structure = {
-    "configuration_perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig"],
+    "configuration_perceiver": ["PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP", "PerceiverConfig", "PerceiverOnnxConfig"],
    "tokenization_perceiver": ["PerceiverTokenizer"],
}

@@ -61,7 +61,7 @@ else:
if TYPE_CHECKING:
-    from .configuration_perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig
+    from .configuration_perceiver import PERCEIVER_PRETRAINED_CONFIG_ARCHIVE_MAP, PerceiverConfig, PerceiverOnnxConfig
    from .tokenization_perceiver import PerceiverTokenizer

    try:
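
For context, names registered in `_import_structure` are imported lazily on first attribute access, which is why `PerceiverOnnxConfig` must be added both to the structure and to the `TYPE_CHECKING` branch. A rough sketch of that lazy-module pattern (a simplified stand-in for transformers' `_LazyModule`; details assumed):

import importlib
import types

class LazyModule(types.ModuleType):
    """Defer submodule imports until one of their attributes is accessed."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name to the submodule that defines it.
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module("." + self._name_to_module[attr], self.__name__)
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so __getattr__ only runs once per name
        return value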

src/transformers/onnx/features.py

@@ -1,39 +1,9 @@
from functools import partial, reduce
from typing import Callable, Dict, Optional, Tuple, Type, Union
+import transformers
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
-from ..models.albert import AlbertOnnxConfig
-from ..models.bart import BartOnnxConfig
-from ..models.beit import BeitOnnxConfig
-from ..models.bert import BertOnnxConfig
-from ..models.big_bird import BigBirdOnnxConfig
-from ..models.bigbird_pegasus import BigBirdPegasusOnnxConfig
-from ..models.blenderbot import BlenderbotOnnxConfig
-from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
-from ..models.camembert import CamembertOnnxConfig
-from ..models.convbert import ConvBertOnnxConfig
-from ..models.data2vec import Data2VecTextOnnxConfig
-from ..models.deit import DeiTOnnxConfig
-from ..models.distilbert import DistilBertOnnxConfig
-from ..models.electra import ElectraOnnxConfig
-from ..models.flaubert import FlaubertOnnxConfig
-from ..models.gpt2 import GPT2OnnxConfig
-from ..models.gpt_neo import GPTNeoOnnxConfig
-from ..models.gptj import GPTJOnnxConfig
-from ..models.ibert import IBertOnnxConfig
-from ..models.layoutlm import LayoutLMOnnxConfig
-from ..models.m2m_100 import M2M100OnnxConfig
-from ..models.marian import MarianOnnxConfig
-from ..models.mbart import MBartOnnxConfig
-from ..models.mobilebert import MobileBertOnnxConfig
-from ..models.perceiver.configuration_perceiver import PerceiverOnnxConfig
-from ..models.roberta import RobertaOnnxConfig
-from ..models.roformer import RoFormerOnnxConfig
-from ..models.squeezebert import SqueezeBertOnnxConfig
-from ..models.t5 import T5OnnxConfig
-from ..models.vit import ViTOnnxConfig
-from ..models.xlm import XLMOnnxConfig
-from ..models.xlm_roberta import XLMRobertaOnnxConfig
from ..utils import logging
from .config import OnnxConfig
@@ -72,14 +42,14 @@ if not is_torch_available() and not is_tf_available():
def supported_features_mapping(
-    *supported_features: str, onnx_config_cls: Type[OnnxConfig] = None
+    *supported_features: str, onnx_config_cls: str = None
) -> Dict[str, Callable[[PretrainedConfig], OnnxConfig]]:
"""
Generate the mapping between the supported features and their corresponding OnnxConfig for a given model.
Args:
*supported_features: The names of the supported features.
-onnx_config_cls: The OnnxConfig class corresponding to the model.
+onnx_config_cls: The OnnxConfig full name corresponding to the model.
Returns:
The dictionary mapping a feature to an OnnxConfig constructor.
@@ -87,13 +57,16 @@ def supported_features_mapping(
if onnx_config_cls is None:
raise ValueError("An OnnxConfig class must be provided")
+config_cls = transformers
+for attr_name in onnx_config_cls.split("."):
+    config_cls = getattr(config_cls, attr_name)
mapping = {}
for feature in supported_features:
if "-with-past" in feature:
task = feature.replace("-with-past", "")
-mapping[feature] = partial(onnx_config_cls.with_past, task=task)
+mapping[feature] = partial(config_cls.with_past, task=task)
else:
-mapping[feature] = partial(onnx_config_cls.from_model_config, task=feature)
+mapping[feature] = partial(config_cls.from_model_config, task=feature)
return mapping
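
Net effect of the change in this file: onnx/features.py no longer imports every model's OnnxConfig at load time; each mapping entry stores a dotted path under the transformers package that is resolved only when a feature constructor is requested. A standalone sketch of the same resolution written with functools.reduce (illustrative, not the code above):

from functools import reduce

import transformers

def resolve_onnx_config(dotted_path: str):
    """Walk a dotted attribute path, e.g. "models.bert.BertOnnxConfig",
    starting from the top-level transformers package."""
    return reduce(getattr, dotted_path.split("."), transformers)

# Triggers the import of the BERT configuration module only at this point,
# instead of whenever onnx.features is first imported.
bert_onnx_config_cls = resolve_onnx_config("models.bert.BertOnnxConfig")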
@@ -135,7 +108,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=AlbertOnnxConfig,
+onnx_config_cls="models.albert.AlbertOnnxConfig",
),
"bart": supported_features_mapping(
"default",
@@ -146,10 +119,12 @@ class FeaturesManager:
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
-onnx_config_cls=BartOnnxConfig,
+onnx_config_cls="models.bart.BartOnnxConfig",
),
# BEiT cannot be used with the masked image modeling autoclass, so this feature is excluded here
-"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
+"beit": supported_features_mapping(
+    "default", "image-classification", onnx_config_cls="models.beit.BeitOnnxConfig"
+),
"bert": supported_features_mapping(
"default",
"masked-lm",
@@ -158,7 +133,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=BertOnnxConfig,
+onnx_config_cls="models.bert.BertOnnxConfig",
),
"big-bird": supported_features_mapping(
"default",
@@ -168,7 +143,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=BigBirdOnnxConfig,
+onnx_config_cls="models.big_bird.BigBirdOnnxConfig",
),
"bigbird-pegasus": supported_features_mapping(
"default",
@@ -179,7 +154,7 @@ class FeaturesManager:
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
-onnx_config_cls=BigBirdPegasusOnnxConfig,
+onnx_config_cls="models.bigbird_pegasus.BigBirdPegasusOnnxConfig",
),
"blenderbot": supported_features_mapping(
"default",
@@ -188,7 +163,7 @@ class FeaturesManager:
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
-onnx_config_cls=BlenderbotOnnxConfig,
+onnx_config_cls="models.blenderbot.BlenderbotOnnxConfig",
),
"blenderbot-small": supported_features_mapping(
"default",
@@ -197,7 +172,7 @@ class FeaturesManager:
"causal-lm-with-past",
"seq2seq-lm",
"seq2seq-lm-with-past",
-onnx_config_cls=BlenderbotSmallOnnxConfig,
+onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
),
"camembert": supported_features_mapping(
"default",
@@ -207,7 +182,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=CamembertOnnxConfig,
+onnx_config_cls="models.camembert.CamembertOnnxConfig",
),
"convbert": supported_features_mapping(
"default",
@@ -216,7 +191,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=ConvBertOnnxConfig,
+onnx_config_cls="models.convbert.ConvBertOnnxConfig",
),
"data2vec-text": supported_features_mapping(
"default",
@@ -225,10 +200,10 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=Data2VecTextOnnxConfig,
+onnx_config_cls="models.data2vec.Data2VecTextOnnxConfig",
),
"deit": supported_features_mapping(
-"default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
+"default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig"
),
"distilbert": supported_features_mapping(
"default",
@@ -237,7 +212,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=DistilBertOnnxConfig,
+onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
),
"electra": supported_features_mapping(
"default",
@@ -247,7 +222,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=ElectraOnnxConfig,
+onnx_config_cls="models.electra.ElectraOnnxConfig",
),
"flaubert": supported_features_mapping(
"default",
@@ -257,7 +232,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=FlaubertOnnxConfig,
+onnx_config_cls="models.flaubert.FlaubertOnnxConfig",
),
"gpt2": supported_features_mapping(
"default",
@@ -266,7 +241,7 @@ class FeaturesManager:
"causal-lm-with-past",
"sequence-classification",
"token-classification",
-onnx_config_cls=GPT2OnnxConfig,
+onnx_config_cls="models.gpt2.GPT2OnnxConfig",
),
"gptj": supported_features_mapping(
"default",
@@ -275,7 +250,7 @@ class FeaturesManager:
"causal-lm-with-past",
"question-answering",
"sequence-classification",
-onnx_config_cls=GPTJOnnxConfig,
+onnx_config_cls="models.gptj.GPTJOnnxConfig",
),
"gpt-neo": supported_features_mapping(
"default",
@@ -283,7 +258,7 @@ class FeaturesManager:
"causal-lm",
"causal-lm-with-past",
"sequence-classification",
-onnx_config_cls=GPTNeoOnnxConfig,
+onnx_config_cls="models.gpt_neo.GPTNeoOnnxConfig",
),
"ibert": supported_features_mapping(
"default",
@@ -292,14 +267,14 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=IBertOnnxConfig,
+onnx_config_cls="models.ibert.IBertOnnxConfig",
),
"layoutlm": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"token-classification",
-onnx_config_cls=LayoutLMOnnxConfig,
+onnx_config_cls="models.layoutlm.LayoutLMOnnxConfig",
),
"marian": supported_features_mapping(
"default",
@@ -308,7 +283,7 @@ class FeaturesManager:
"seq2seq-lm-with-past",
"causal-lm",
"causal-lm-with-past",
-onnx_config_cls=MarianOnnxConfig,
+onnx_config_cls="models.marian.MarianOnnxConfig",
),
"mbart": supported_features_mapping(
"default",
@@ -319,7 +294,7 @@ class FeaturesManager:
"seq2seq-lm-with-past",
"sequence-classification",
"question-answering",
-onnx_config_cls=MBartOnnxConfig,
+onnx_config_cls="models.mbart.MBartOnnxConfig",
),
"mobilebert": supported_features_mapping(
"default",
@@ -328,16 +303,20 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=MobileBertOnnxConfig,
+onnx_config_cls="models.mobilebert.MobileBertOnnxConfig",
),
"m2m-100": supported_features_mapping(
-"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
+"default",
+"default-with-past",
+"seq2seq-lm",
+"seq2seq-lm-with-past",
+onnx_config_cls="models.m2m_100.M2M100OnnxConfig",
),
"perceiver": supported_features_mapping(
"image-classification",
"masked-lm",
"sequence-classification",
-onnx_config_cls=PerceiverOnnxConfig,
+onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
),
"roberta": supported_features_mapping(
"default",
@@ -347,7 +326,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=RobertaOnnxConfig,
+onnx_config_cls="models.roberta.RobertaOnnxConfig",
),
"roformer": supported_features_mapping(
"default",
@@ -358,7 +337,7 @@ class FeaturesManager:
"multiple-choice",
"question-answering",
"token-classification",
-onnx_config_cls=RoFormerOnnxConfig,
+onnx_config_cls="models.roformer.RoFormerOnnxConfig",
),
"squeezebert": supported_features_mapping(
"default",
@@ -367,13 +346,17 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=SqueezeBertOnnxConfig,
+onnx_config_cls="models.squeezebert.SqueezeBertOnnxConfig",
),
"t5": supported_features_mapping(
-"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=T5OnnxConfig
+"default",
+"default-with-past",
+"seq2seq-lm",
+"seq2seq-lm-with-past",
+onnx_config_cls="models.t5.T5OnnxConfig",
),
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
"default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
),
"xlm": supported_features_mapping(
"default",
@@ -383,7 +366,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=XLMOnnxConfig,
+onnx_config_cls="models.xlm.XLMOnnxConfig",
),
"xlm-roberta": supported_features_mapping(
"default",
@@ -393,7 +376,7 @@ class FeaturesManager:
"multiple-choice",
"token-classification",
"question-answering",
-onnx_config_cls=XLMRobertaOnnxConfig,
+onnx_config_cls="models.xlm_roberta.XLMRobertaOnnxConfig",
),
}

utils/tests_fetcher.py

@@ -242,6 +242,67 @@ def get_test_dependencies(test_fname):
    return [f for f in [*parent_imports, *current_dir_imports] if os.path.isfile(f)]

+
+def create_reverse_dependency_tree():
+    """
+    Create a list of all edges (a, b), meaning that modifying a impacts b, with a going over all module and test files.
+    """
+    modules = [
+        str(f.relative_to(PATH_TO_TRANFORMERS))
+        for f in (Path(PATH_TO_TRANFORMERS) / "src/transformers").glob("**/*.py")
+    ]
+    module_edges = [(d, m) for m in modules for d in get_module_dependencies(m)]
+
+    tests = [str(f.relative_to(PATH_TO_TRANFORMERS)) for f in (Path(PATH_TO_TRANFORMERS) / "tests").glob("**/*.py")]
+    test_edges = [(d, t) for t in tests for d in get_test_dependencies(t)]
+
+    return module_edges + test_edges
+
+
+def get_tree_starting_at(module, edges):
+    """
+    Returns the tree starting at a given module, following all edges, in the following format: [module, [list of edges
+    starting at module], [list of edges starting at the preceding level], ...]
+    """
+    vertices_seen = [module]
+    new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module]
+    tree = [module]
+    while len(new_edges) > 0:
+        tree.append(new_edges)
+        final_vertices = list(set(edge[1] for edge in new_edges))
+        vertices_seen.extend(final_vertices)
+        new_edges = [edge for edge in edges if edge[0] in final_vertices and edge[1] not in vertices_seen]
+    return tree
+
+
+def print_tree_deps_of(module, all_edges=None):
+    """
+    Prints the tree of modules depending on a given module.
+    """
+    if all_edges is None:
+        all_edges = create_reverse_dependency_tree()
+    tree = get_tree_starting_at(module, all_edges)
+
+    # The list of lines is a list of tuples (line_to_be_printed, module).
+    # Keeping the modules lets us know where to insert each new line in the list.
+    lines = [(tree[0], tree[0])]
+    for index in range(1, len(tree)):
+        edges = tree[index]
+        start_edges = set([edge[0] for edge in edges])
+
+        for start in start_edges:
+            end_edges = set([edge[1] for edge in edges if edge[0] == start])
+            # We will insert all those edges just after the line showing start.
+            pos = 0
+            while lines[pos][1] != start:
+                pos += 1
+            lines = lines[: pos + 1] + [(" " * (2 * index) + end, end) for end in end_edges] + lines[pos + 1 :]
+
+    for line in lines:
+        # We don't print the refs that were just here to help build lines.
+        print(line[0])
+
+
def create_reverse_dependency_map():
    """
    Create the dependency map from module/test filename to the list of modules/tests that depend on it (even
@@ -585,8 +646,16 @@ if __name__ == "__main__":
        default=["tests"],
        help="Only keep the test files matching one of those filters.",
    )
+    parser.add_argument(
+        "--print_dependencies_of",
+        type=str,
+        help="Will only print the tree of modules depending on the file passed.",
+        default=None,
+    )
    args = parser.parse_args()

-    if args.sanity_check:
+    if args.print_dependencies_of is not None:
+        print_tree_deps_of(args.print_dependencies_of)
+    elif args.sanity_check:
        sanity_check()
    else:
        repo = Repo(PATH_TO_TRANFORMERS)
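
With the new flag, the dependency tree of any file can be printed directly; an illustrative invocation and its in-process equivalent (the file path is just an example):

# Shell: python utils/tests_fetcher.py --print_dependencies_of src/transformers/modeling_tf_utils.py
# In-process equivalent (assumes utils/ is on sys.path):
from tests_fetcher import create_reverse_dependency_tree, print_tree_deps_of

edges = create_reverse_dependency_tree()
print_tree_deps_of("src/transformers/modeling_tf_utils.py", all_edges=edges)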