mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
Check TF ops for ONNX compliance (#10025)
* Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
93bd2f7099
commit
c8d3fa0dfd
@ -151,6 +151,16 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_faiss_available = False
|
||||
|
||||
|
||||
_onnx_available = (
|
||||
importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
|
||||
)
|
||||
try:
|
||||
_onxx_version = importlib_metadata.version("onnx")
|
||||
logger.debug(f"Successfully imported onnx version {_onxx_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_onnx_available = False
|
||||
|
||||
|
||||
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
||||
try:
|
||||
_scatter_version = importlib_metadata.version("torch_scatter")
|
||||
@ -230,6 +240,10 @@ def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
||||
def is_onnx_available():
|
||||
return _onnx_available
|
||||
|
||||
|
||||
def is_flax_available():
|
||||
return _flax_available
|
||||
|
||||
|
@ -1030,16 +1030,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def get_seq_element(sequence_position, input_batch):
|
||||
return tf.strided_slice(
|
||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
||||
)
|
||||
|
||||
result = tf.map_fn(
|
||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
||||
)
|
||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
logger.warning(
|
||||
@ -1049,16 +1040,12 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
||||
loss = None
|
||||
|
||||
if inputs["labels"] is not None:
|
||||
if input_ids is not None:
|
||||
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
|
||||
else:
|
||||
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
|
||||
assert (
|
||||
self.config.pad_token_id is not None or batch_size == 1
|
||||
self.config.pad_token_id is not None or logits_shape[0] == 1
|
||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||
|
||||
if not tf.is_tensor(sequence_lengths):
|
||||
in_logits = logits[0:batch_size, sequence_lengths]
|
||||
in_logits = logits[0 : logits_shape[0], sequence_lengths]
|
||||
|
||||
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||
pooled_logits = in_logits if in_logits is not None else logits
|
||||
|
@ -28,6 +28,7 @@ from .file_utils import (
|
||||
is_datasets_available,
|
||||
is_faiss_available,
|
||||
is_flax_available,
|
||||
is_onnx_available,
|
||||
is_pandas_available,
|
||||
is_scatter_available,
|
||||
is_sentencepiece_available,
|
||||
@ -160,6 +161,13 @@ def require_git_lfs(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_onnx(test_case):
|
||||
if not is_onnx_available():
|
||||
return unittest.skip("test requires ONNX")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch.
|
||||
|
@ -241,6 +241,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFAlbertModelTester(self)
|
||||
|
@ -178,6 +178,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = True
|
||||
onnx_min_opset = 10
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBartModelTester(self)
|
||||
|
@ -274,6 +274,8 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = True
|
||||
onnx_min_opset = 10
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
@ -177,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotModelTester(self)
|
||||
|
@ -179,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotSmallModelTester(self)
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
@ -24,7 +25,7 @@ from importlib import import_module
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
@ -201,6 +202,67 @@ class TFModelTesterMixin:
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
self.assertTrue(os.path.exists(saved_model_dir))
|
||||
|
||||
def test_onnx_compliancy(self):
|
||||
if not self.test_onnx:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
INTERNAL_OPS = [
|
||||
"Assert",
|
||||
"AssignVariableOp",
|
||||
"EmptyTensorList",
|
||||
"ReadVariableOp",
|
||||
"ResourceGather",
|
||||
"TruncatedNormal",
|
||||
"VarHandleOp",
|
||||
"VarIsInitializedOp",
|
||||
]
|
||||
onnx_ops = []
|
||||
|
||||
with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
|
||||
onnx_opsets = json.load(f)["opsets"]
|
||||
|
||||
for i in range(1, self.onnx_min_opset + 1):
|
||||
onnx_ops.extend(onnx_opsets[str(i)])
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model_op_names = set()
|
||||
|
||||
with tf.Graph().as_default() as g:
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
|
||||
for op in g.get_operations():
|
||||
model_op_names.add(op.node_def.op)
|
||||
|
||||
model_op_names = sorted(model_op_names)
|
||||
incompatible_ops = []
|
||||
|
||||
for op in model_op_names:
|
||||
if op not in onnx_ops and op not in INTERNAL_OPS:
|
||||
incompatible_ops.append(op)
|
||||
|
||||
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
|
||||
|
||||
@require_onnx
|
||||
@slow
|
||||
def test_onnx_runtime_optimize(self):
|
||||
if not self.test_onnx:
|
||||
return
|
||||
|
||||
import keras2onnx
|
||||
import onnxruntime
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
|
||||
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
|
||||
|
||||
onnxruntime.InferenceSession(onnx_model.SerializeToString())
|
||||
|
||||
@slow
|
||||
def test_saved_model_creation_extended(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
@ -239,6 +239,7 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFConvBertModelTester(self)
|
||||
|
@ -174,6 +174,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFCTRLModelTester(self)
|
||||
|
@ -184,6 +184,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else None
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFDistilBertModelTester(self)
|
||||
|
@ -188,6 +188,7 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
test_missing_keys = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFDPRModelTester(self)
|
||||
|
@ -206,6 +206,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFElectraModelTester(self)
|
||||
|
@ -292,6 +292,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFlaubertModelTester(self)
|
||||
|
@ -339,6 +339,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFunnelModelTester(self)
|
||||
@ -382,6 +383,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFunnelModelTester(self, base=True)
|
||||
|
@ -333,6 +333,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = True
|
||||
onnx_min_opset = 10
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFGPT2ModelTester(self)
|
||||
|
@ -195,6 +195,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLEDModelTester(self)
|
||||
|
@ -297,6 +297,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLongformerModelTester(self)
|
||||
|
@ -362,6 +362,7 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLxmertModelTester(self)
|
||||
|
@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMarianModelTester(self)
|
||||
|
@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMBartModelTester(self)
|
||||
|
@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
class TFMobileBertModelTester(object):
|
||||
def __init__(
|
||||
|
@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMPNetModelTester(self)
|
||||
|
@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFOpenAIGPTModelTester(self)
|
||||
|
@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFPegasusModelTester(self)
|
||||
|
@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFRobertaModelTester(self)
|
||||
|
@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5ModelTester(self)
|
||||
@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5EncoderOnlyModelTester(self)
|
||||
|
@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFTransfoXLModelTester(self)
|
||||
|
@ -294,6 +294,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFXLMModelTester(self)
|
||||
|
@ -348,6 +348,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFXLNetLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFXLNetModelTester(self)
|
||||
|
101
utils/check_tf_ops.py
Normal file
101
utils/check_tf_ops.py
Normal file
@ -0,0 +1,101 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_copies.py
|
||||
REPO_PATH = "."
|
||||
|
||||
# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
|
||||
INTERNAL_OPS = [
|
||||
"Assert",
|
||||
"AssignVariableOp",
|
||||
"EmptyTensorList",
|
||||
"MergeV2Checkpoints",
|
||||
"ReadVariableOp",
|
||||
"ResourceGather",
|
||||
"RestoreV2",
|
||||
"SaveV2",
|
||||
"ShardedFilename",
|
||||
"StatefulPartitionedCall",
|
||||
"StaticRegexFullMatch",
|
||||
"VarHandleOp",
|
||||
]
|
||||
|
||||
|
||||
def onnx_compliancy(saved_model_path, strict, opset):
|
||||
saved_model = SavedModel()
|
||||
onnx_ops = []
|
||||
|
||||
with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
|
||||
onnx_opsets = json.load(f)["opsets"]
|
||||
|
||||
for i in range(1, opset + 1):
|
||||
onnx_ops.extend(onnx_opsets[str(i)])
|
||||
|
||||
with open(saved_model_path, "rb") as f:
|
||||
saved_model.ParseFromString(f.read())
|
||||
|
||||
model_op_names = set()
|
||||
|
||||
# Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
|
||||
for meta_graph in saved_model.meta_graphs:
|
||||
# Add operations in the graph definition
|
||||
model_op_names.update(node.op for node in meta_graph.graph_def.node)
|
||||
|
||||
# Go through the functions in the graph definition
|
||||
for func in meta_graph.graph_def.library.function:
|
||||
# Add operations in each function
|
||||
model_op_names.update(node.op for node in func.node_def)
|
||||
|
||||
# Convert to list, sorted if you want
|
||||
model_op_names = sorted(model_op_names)
|
||||
incompatible_ops = []
|
||||
|
||||
for op in model_op_names:
|
||||
if op not in onnx_ops and op not in INTERNAL_OPS:
|
||||
incompatible_ops.append(op)
|
||||
|
||||
if strict and len(incompatible_ops) > 0:
|
||||
raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops)
|
||||
elif len(incompatible_ops) > 0:
|
||||
print(f"Found the following incompatible ops for the opset {opset}:")
|
||||
print(*incompatible_ops, sep="\n")
|
||||
else:
|
||||
print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
|
||||
parser.add_argument(
|
||||
"--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.framework == "onnx":
|
||||
onnx_compliancy(args.saved_model_path, args.strict, args.opset)
|
245
utils/tf_ops/onnx.json
Normal file
245
utils/tf_ops/onnx.json
Normal file
@ -0,0 +1,245 @@
|
||||
{
|
||||
"opsets": {
|
||||
"1": [
|
||||
"Abs",
|
||||
"Add",
|
||||
"AddV2",
|
||||
"ArgMax",
|
||||
"ArgMin",
|
||||
"AvgPool",
|
||||
"AvgPool3D",
|
||||
"BatchMatMul",
|
||||
"BatchMatMulV2",
|
||||
"BatchToSpaceND",
|
||||
"BiasAdd",
|
||||
"BiasAddV1",
|
||||
"Cast",
|
||||
"Ceil",
|
||||
"CheckNumerics",
|
||||
"ComplexAbs",
|
||||
"Concat",
|
||||
"ConcatV2",
|
||||
"Const",
|
||||
"ConstV2",
|
||||
"Conv1D",
|
||||
"Conv2D",
|
||||
"Conv2DBackpropInput",
|
||||
"Conv3D",
|
||||
"Conv3DBackpropInputV2",
|
||||
"DepthToSpace",
|
||||
"DepthwiseConv2d",
|
||||
"DepthwiseConv2dNative",
|
||||
"Div",
|
||||
"Dropout",
|
||||
"Elu",
|
||||
"Equal",
|
||||
"Erf",
|
||||
"Exp",
|
||||
"ExpandDims",
|
||||
"Flatten",
|
||||
"Floor",
|
||||
"Gather",
|
||||
"GatherNd",
|
||||
"GatherV2",
|
||||
"Greater",
|
||||
"Identity",
|
||||
"IdentityN",
|
||||
"If",
|
||||
"LRN",
|
||||
"LSTMBlockCell",
|
||||
"LeakyRelu",
|
||||
"Less",
|
||||
"Log",
|
||||
"LogSoftmax",
|
||||
"LogicalAnd",
|
||||
"LogicalNot",
|
||||
"LogicalOr",
|
||||
"LookupTableSizeV2",
|
||||
"MatMul",
|
||||
"Max",
|
||||
"MaxPool",
|
||||
"MaxPool3D",
|
||||
"MaxPoolV2",
|
||||
"Maximum",
|
||||
"Mean",
|
||||
"Min",
|
||||
"Minimum",
|
||||
"MirrorPad",
|
||||
"Mul",
|
||||
"Neg",
|
||||
"NoOp",
|
||||
"NotEqual",
|
||||
"OneHot",
|
||||
"Pack",
|
||||
"Pad",
|
||||
"PadV2",
|
||||
"Placeholder",
|
||||
"PlaceholderV2",
|
||||
"PlaceholderWithDefault",
|
||||
"Pow",
|
||||
"Prod",
|
||||
"RFFT",
|
||||
"RandomNormal",
|
||||
"RandomNormalLike",
|
||||
"RandomUniform",
|
||||
"RandomUniformLike",
|
||||
"RealDiv",
|
||||
"Reciprocal",
|
||||
"Relu",
|
||||
"Relu6",
|
||||
"Reshape",
|
||||
"Rsqrt",
|
||||
"Selu",
|
||||
"Shape",
|
||||
"Sigmoid",
|
||||
"Sign",
|
||||
"Size",
|
||||
"Slice",
|
||||
"Softmax",
|
||||
"Softplus",
|
||||
"Softsign",
|
||||
"SpaceToBatchND",
|
||||
"SpaceToDepth",
|
||||
"Split",
|
||||
"SplitV",
|
||||
"Sqrt",
|
||||
"Square",
|
||||
"SquaredDifference",
|
||||
"Squeeze",
|
||||
"StatelessIf",
|
||||
"StopGradient",
|
||||
"StridedSlice",
|
||||
"StringJoin",
|
||||
"Sub",
|
||||
"Sum",
|
||||
"Tanh",
|
||||
"Tile",
|
||||
"TopKV2",
|
||||
"Transpose",
|
||||
"TruncateDiv",
|
||||
"Unpack",
|
||||
"ZerosLike"
|
||||
],
|
||||
"2": [],
|
||||
"3": [],
|
||||
"4": [],
|
||||
"5": [],
|
||||
"6": [
|
||||
"AddN",
|
||||
"All",
|
||||
"Any",
|
||||
"FloorDiv",
|
||||
"FusedBatchNorm",
|
||||
"FusedBatchNormV2",
|
||||
"FusedBatchNormV3"
|
||||
],
|
||||
"7": [
|
||||
"Acos",
|
||||
"Asin",
|
||||
"Atan",
|
||||
"Cos",
|
||||
"Fill",
|
||||
"FloorMod",
|
||||
"GreaterEqual",
|
||||
"LessEqual",
|
||||
"Loop",
|
||||
"MatrixBandPart",
|
||||
"Multinomial",
|
||||
"Range",
|
||||
"ResizeBilinear",
|
||||
"ResizeNearestNeighbor",
|
||||
"Scan",
|
||||
"Select",
|
||||
"SelectV2",
|
||||
"Sin",
|
||||
"SoftmaxCrossEntropyWithLogits",
|
||||
"SparseSoftmaxCrossEntropyWithLogits",
|
||||
"StatelessWhile",
|
||||
"Tan",
|
||||
"TensorListFromTensor",
|
||||
"TensorListGetItem",
|
||||
"TensorListLength",
|
||||
"TensorListReserve",
|
||||
"TensorListResize",
|
||||
"TensorListSetItem",
|
||||
"TensorListStack",
|
||||
"While"
|
||||
],
|
||||
"8": [
|
||||
"BroadcastTo",
|
||||
"ClipByValue",
|
||||
"FIFOQueueV2",
|
||||
"HashTableV2",
|
||||
"IteratorGetNext",
|
||||
"IteratorV2",
|
||||
"LookupTableFindV2",
|
||||
"MaxPoolWithArgmax",
|
||||
"QueueDequeueManyV2",
|
||||
"QueueDequeueUpToV2",
|
||||
"QueueDequeueV2",
|
||||
"ReverseSequence"
|
||||
],
|
||||
"9": [
|
||||
"SegmentMax",
|
||||
"SegmentMean",
|
||||
"SegmentMin",
|
||||
"SegmentProd",
|
||||
"SegmentSum",
|
||||
"Sinh",
|
||||
"SparseSegmentMean",
|
||||
"SparseSegmentMeanWithNumSegments",
|
||||
"SparseSegmentSqrtN",
|
||||
"SparseSegmentSqrtNWithNumSegments",
|
||||
"SparseSegmentSum",
|
||||
"SparseSegmentSumWithNumSegments",
|
||||
"UnsortedSegmentMax",
|
||||
"UnsortedSegmentMin",
|
||||
"UnsortedSegmentProd",
|
||||
"UnsortedSegmentSum",
|
||||
"Where"
|
||||
],
|
||||
"10": [
|
||||
"CropAndResize",
|
||||
"CudnnRNN",
|
||||
"DynamicStitch",
|
||||
"FakeQuantWithMinMaxArgs",
|
||||
"IsFinite",
|
||||
"IsInf",
|
||||
"NonMaxSuppressionV2",
|
||||
"NonMaxSuppressionV3",
|
||||
"NonMaxSuppressionV4",
|
||||
"NonMaxSuppressionV5",
|
||||
"ParallelDynamicStitch",
|
||||
"ReverseV2",
|
||||
"Roll"
|
||||
],
|
||||
"11": [
|
||||
"Bincount",
|
||||
"Cumsum",
|
||||
"InvertPermutation",
|
||||
"LeftShift",
|
||||
"MatrixDeterminant",
|
||||
"MatrixDiagPart",
|
||||
"MatrixDiagPartV2",
|
||||
"MatrixDiagPartV3",
|
||||
"RaggedRange",
|
||||
"RightShift",
|
||||
"Round",
|
||||
"ScatterNd",
|
||||
"SparseFillEmptyRows",
|
||||
"SparseReshape",
|
||||
"SparseToDense",
|
||||
"TensorScatterUpdate",
|
||||
"Unique"
|
||||
],
|
||||
"12": [
|
||||
"Einsum",
|
||||
"MatrixDiag",
|
||||
"MatrixDiagV2",
|
||||
"MatrixDiagV3",
|
||||
"MatrixSetDiagV3",
|
||||
"SquaredDistance"
|
||||
],
|
||||
"13": []
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user