Check TF ops for ONNX compliance (#10025)

* Add check-ops script

* Finish to implement check_tf_ops and start the test

* Make the test mandatory only for BERT

* Update tf_ops folder

* Remove useless classes

* Add the ONNX test for GPT2 and BART

* Add a onnxruntime slow test + better opset flexibility

* Fix test + apply style

* fix tests

* Switch min opset from 12 to 10

* Update src/transformers/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Fix GPT2

* Remove extra shape_list usage

* Fix GPT2

* Address Morgan's comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Julien Plu 2021-02-15 13:55:10 +01:00 committed by GitHub
parent 93bd2f7099
commit c8d3fa0dfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 468 additions and 17 deletions

View File

@ -151,6 +151,16 @@ except importlib_metadata.PackageNotFoundError:
_faiss_available = False
_onnx_available = (
importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
)
try:
_onxx_version = importlib_metadata.version("onnx")
logger.debug(f"Successfully imported onnx version {_onxx_version}")
except importlib_metadata.PackageNotFoundError:
_onnx_available = False
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
_scatter_version = importlib_metadata.version("torch_scatter")
@ -230,6 +240,10 @@ def is_tf_available():
return _tf_available
def is_onnx_available():
return _onnx_available
def is_flax_available():
return _flax_available

View File

@ -1030,16 +1030,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
)
- 1
)
def get_seq_element(sequence_position, input_batch):
return tf.strided_slice(
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
)
result = tf.map_fn(
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else:
sequence_lengths = -1
logger.warning(
@ -1049,16 +1040,12 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
loss = None
if inputs["labels"] is not None:
if input_ids is not None:
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
else:
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
self.config.pad_token_id is not None or logits_shape[0] == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if not tf.is_tensor(sequence_lengths):
in_logits = logits[0:batch_size, sequence_lengths]
in_logits = logits[0 : logits_shape[0], sequence_lengths]
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
pooled_logits = in_logits if in_logits is not None else logits

View File

@ -28,6 +28,7 @@ from .file_utils import (
is_datasets_available,
is_faiss_available,
is_flax_available,
is_onnx_available,
is_pandas_available,
is_scatter_available,
is_sentencepiece_available,
@ -160,6 +161,13 @@ def require_git_lfs(test_case):
return test_case
def require_onnx(test_case):
if not is_onnx_available():
return unittest.skip("test requires ONNX")(test_case)
else:
return test_case
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch.

View File

@ -241,6 +241,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFAlbertModelTester(self)

View File

@ -178,6 +178,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = True
onnx_min_opset = 10
def setUp(self):
self.model_tester = TFBartModelTester(self)

View File

@ -274,6 +274,8 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = True
onnx_min_opset = 10
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):

View File

@ -177,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False
def setUp(self):
self.model_tester = TFBlenderbotModelTester(self)

View File

@ -179,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False
def setUp(self):
self.model_tester = TFBlenderbotSmallModelTester(self)

View File

@ -16,6 +16,7 @@
import copy
import inspect
import json
import os
import random
import tempfile
@ -24,7 +25,7 @@ from importlib import import_module
from typing import List, Tuple
from transformers import is_tf_available
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow
if is_tf_available():
@ -201,6 +202,67 @@ class TFModelTesterMixin:
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
self.assertTrue(os.path.exists(saved_model_dir))
def test_onnx_compliancy(self):
if not self.test_onnx:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
INTERNAL_OPS = [
"Assert",
"AssignVariableOp",
"EmptyTensorList",
"ReadVariableOp",
"ResourceGather",
"TruncatedNormal",
"VarHandleOp",
"VarIsInitializedOp",
]
onnx_ops = []
with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
onnx_opsets = json.load(f)["opsets"]
for i in range(1, self.onnx_min_opset + 1):
onnx_ops.extend(onnx_opsets[str(i)])
for model_class in self.all_model_classes:
model_op_names = set()
with tf.Graph().as_default() as g:
model = model_class(config)
model(model.dummy_inputs)
for op in g.get_operations():
model_op_names.add(op.node_def.op)
model_op_names = sorted(model_op_names)
incompatible_ops = []
for op in model_op_names:
if op not in onnx_ops and op not in INTERNAL_OPS:
incompatible_ops.append(op)
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
@require_onnx
@slow
def test_onnx_runtime_optimize(self):
if not self.test_onnx:
return
import keras2onnx
import onnxruntime
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model(model.dummy_inputs)
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
onnxruntime.InferenceSession(onnx_model.SerializeToString())
@slow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@ -239,6 +239,7 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
)
test_pruning = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFConvBertModelTester(self)

View File

@ -174,6 +174,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFCTRLModelTester(self)

View File

@ -184,6 +184,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
else None
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFDistilBertModelTester(self)

View File

@ -188,6 +188,7 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
test_missing_keys = False
test_pruning = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFDPRModelTester(self)

View File

@ -206,6 +206,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFElectraModelTester(self)

View File

@ -292,6 +292,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFFlaubertModelTester(self)

View File

@ -339,6 +339,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFFunnelModelTester(self)
@ -382,6 +383,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFFunnelModelTester(self, base=True)

View File

@ -333,6 +333,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
)
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = True
onnx_min_opset = 10
def setUp(self):
self.model_tester = TFGPT2ModelTester(self)

View File

@ -195,6 +195,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFLEDModelTester(self)

View File

@ -297,6 +297,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFLongformerModelTester(self)

View File

@ -362,6 +362,7 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFLxmertModelTester(self)

View File

@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False
def setUp(self):
self.model_tester = TFMarianModelTester(self)

View File

@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False
def setUp(self):
self.model_tester = TFMBartModelTester(self)

View File

@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False
class TFMobileBertModelTester(object):
def __init__(

View File

@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFMPNetModelTester(self)

View File

@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFOpenAIGPTModelTester(self)

View File

@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_onnx = False
def setUp(self):
self.model_tester = TFPegasusModelTester(self)

View File

@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFRobertaModelTester(self)

View File

@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFT5ModelTester(self)
@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFT5EncoderOnlyModelTester(self)

View File

@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFTransfoXLModelTester(self)

View File

@ -294,6 +294,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFXLMModelTester(self)

View File

@ -348,6 +348,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
(TFXLNetLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFXLNetModelTester(self)

101
utils/check_tf_ops.py Normal file
View File

@ -0,0 +1,101 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
REPO_PATH = "."
# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
INTERNAL_OPS = [
"Assert",
"AssignVariableOp",
"EmptyTensorList",
"MergeV2Checkpoints",
"ReadVariableOp",
"ResourceGather",
"RestoreV2",
"SaveV2",
"ShardedFilename",
"StatefulPartitionedCall",
"StaticRegexFullMatch",
"VarHandleOp",
]
def onnx_compliancy(saved_model_path, strict, opset):
saved_model = SavedModel()
onnx_ops = []
with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
onnx_opsets = json.load(f)["opsets"]
for i in range(1, opset + 1):
onnx_ops.extend(onnx_opsets[str(i)])
with open(saved_model_path, "rb") as f:
saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
for meta_graph in saved_model.meta_graphs:
# Add operations in the graph definition
model_op_names.update(node.op for node in meta_graph.graph_def.node)
# Go through the functions in the graph definition
for func in meta_graph.graph_def.library.function:
# Add operations in each function
model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
incompatible_ops = []
for op in model_op_names:
if op not in onnx_ops and op not in INTERNAL_OPS:
incompatible_ops.append(op)
if strict and len(incompatible_ops) > 0:
raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops)
elif len(incompatible_ops) > 0:
print(f"Found the following incompatible ops for the opset {opset}:")
print(*incompatible_ops, sep="\n")
else:
print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
parser.add_argument(
"--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
)
parser.add_argument(
"--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
)
parser.add_argument(
"--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
)
args = parser.parse_args()
if args.framework == "onnx":
onnx_compliancy(args.saved_model_path, args.strict, args.opset)

245
utils/tf_ops/onnx.json Normal file
View File

@ -0,0 +1,245 @@
{
"opsets": {
"1": [
"Abs",
"Add",
"AddV2",
"ArgMax",
"ArgMin",
"AvgPool",
"AvgPool3D",
"BatchMatMul",
"BatchMatMulV2",
"BatchToSpaceND",
"BiasAdd",
"BiasAddV1",
"Cast",
"Ceil",
"CheckNumerics",
"ComplexAbs",
"Concat",
"ConcatV2",
"Const",
"ConstV2",
"Conv1D",
"Conv2D",
"Conv2DBackpropInput",
"Conv3D",
"Conv3DBackpropInputV2",
"DepthToSpace",
"DepthwiseConv2d",
"DepthwiseConv2dNative",
"Div",
"Dropout",
"Elu",
"Equal",
"Erf",
"Exp",
"ExpandDims",
"Flatten",
"Floor",
"Gather",
"GatherNd",
"GatherV2",
"Greater",
"Identity",
"IdentityN",
"If",
"LRN",
"LSTMBlockCell",
"LeakyRelu",
"Less",
"Log",
"LogSoftmax",
"LogicalAnd",
"LogicalNot",
"LogicalOr",
"LookupTableSizeV2",
"MatMul",
"Max",
"MaxPool",
"MaxPool3D",
"MaxPoolV2",
"Maximum",
"Mean",
"Min",
"Minimum",
"MirrorPad",
"Mul",
"Neg",
"NoOp",
"NotEqual",
"OneHot",
"Pack",
"Pad",
"PadV2",
"Placeholder",
"PlaceholderV2",
"PlaceholderWithDefault",
"Pow",
"Prod",
"RFFT",
"RandomNormal",
"RandomNormalLike",
"RandomUniform",
"RandomUniformLike",
"RealDiv",
"Reciprocal",
"Relu",
"Relu6",
"Reshape",
"Rsqrt",
"Selu",
"Shape",
"Sigmoid",
"Sign",
"Size",
"Slice",
"Softmax",
"Softplus",
"Softsign",
"SpaceToBatchND",
"SpaceToDepth",
"Split",
"SplitV",
"Sqrt",
"Square",
"SquaredDifference",
"Squeeze",
"StatelessIf",
"StopGradient",
"StridedSlice",
"StringJoin",
"Sub",
"Sum",
"Tanh",
"Tile",
"TopKV2",
"Transpose",
"TruncateDiv",
"Unpack",
"ZerosLike"
],
"2": [],
"3": [],
"4": [],
"5": [],
"6": [
"AddN",
"All",
"Any",
"FloorDiv",
"FusedBatchNorm",
"FusedBatchNormV2",
"FusedBatchNormV3"
],
"7": [
"Acos",
"Asin",
"Atan",
"Cos",
"Fill",
"FloorMod",
"GreaterEqual",
"LessEqual",
"Loop",
"MatrixBandPart",
"Multinomial",
"Range",
"ResizeBilinear",
"ResizeNearestNeighbor",
"Scan",
"Select",
"SelectV2",
"Sin",
"SoftmaxCrossEntropyWithLogits",
"SparseSoftmaxCrossEntropyWithLogits",
"StatelessWhile",
"Tan",
"TensorListFromTensor",
"TensorListGetItem",
"TensorListLength",
"TensorListReserve",
"TensorListResize",
"TensorListSetItem",
"TensorListStack",
"While"
],
"8": [
"BroadcastTo",
"ClipByValue",
"FIFOQueueV2",
"HashTableV2",
"IteratorGetNext",
"IteratorV2",
"LookupTableFindV2",
"MaxPoolWithArgmax",
"QueueDequeueManyV2",
"QueueDequeueUpToV2",
"QueueDequeueV2",
"ReverseSequence"
],
"9": [
"SegmentMax",
"SegmentMean",
"SegmentMin",
"SegmentProd",
"SegmentSum",
"Sinh",
"SparseSegmentMean",
"SparseSegmentMeanWithNumSegments",
"SparseSegmentSqrtN",
"SparseSegmentSqrtNWithNumSegments",
"SparseSegmentSum",
"SparseSegmentSumWithNumSegments",
"UnsortedSegmentMax",
"UnsortedSegmentMin",
"UnsortedSegmentProd",
"UnsortedSegmentSum",
"Where"
],
"10": [
"CropAndResize",
"CudnnRNN",
"DynamicStitch",
"FakeQuantWithMinMaxArgs",
"IsFinite",
"IsInf",
"NonMaxSuppressionV2",
"NonMaxSuppressionV3",
"NonMaxSuppressionV4",
"NonMaxSuppressionV5",
"ParallelDynamicStitch",
"ReverseV2",
"Roll"
],
"11": [
"Bincount",
"Cumsum",
"InvertPermutation",
"LeftShift",
"MatrixDeterminant",
"MatrixDiagPart",
"MatrixDiagPartV2",
"MatrixDiagPartV3",
"RaggedRange",
"RightShift",
"Round",
"ScatterNd",
"SparseFillEmptyRows",
"SparseReshape",
"SparseToDense",
"TensorScatterUpdate",
"Unique"
],
"12": [
"Einsum",
"MatrixDiag",
"MatrixDiagV2",
"MatrixDiagV3",
"MatrixSetDiagV3",
"SquaredDistance"
],
"13": []
}
}