mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00
Check TF ops for ONNX compliance (#10025)
* Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
93bd2f7099
commit
c8d3fa0dfd
@ -151,6 +151,16 @@ except importlib_metadata.PackageNotFoundError:
|
|||||||
_faiss_available = False
|
_faiss_available = False
|
||||||
|
|
||||||
|
|
||||||
|
_onnx_available = (
|
||||||
|
importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
_onxx_version = importlib_metadata.version("onnx")
|
||||||
|
logger.debug(f"Successfully imported onnx version {_onxx_version}")
|
||||||
|
except importlib_metadata.PackageNotFoundError:
|
||||||
|
_onnx_available = False
|
||||||
|
|
||||||
|
|
||||||
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
||||||
try:
|
try:
|
||||||
_scatter_version = importlib_metadata.version("torch_scatter")
|
_scatter_version = importlib_metadata.version("torch_scatter")
|
||||||
@ -230,6 +240,10 @@ def is_tf_available():
|
|||||||
return _tf_available
|
return _tf_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_onnx_available():
|
||||||
|
return _onnx_available
|
||||||
|
|
||||||
|
|
||||||
def is_flax_available():
|
def is_flax_available():
|
||||||
return _flax_available
|
return _flax_available
|
||||||
|
|
||||||
|
@ -1030,16 +1030,7 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
|||||||
)
|
)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||||
def get_seq_element(sequence_position, input_batch):
|
|
||||||
return tf.strided_slice(
|
|
||||||
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = tf.map_fn(
|
|
||||||
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
|
|
||||||
)
|
|
||||||
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
|
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -1049,16 +1040,12 @@ class TFGPT2ForSequenceClassification(TFGPT2PreTrainedModel, TFSequenceClassific
|
|||||||
loss = None
|
loss = None
|
||||||
|
|
||||||
if inputs["labels"] is not None:
|
if inputs["labels"] is not None:
|
||||||
if input_ids is not None:
|
|
||||||
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
|
|
||||||
else:
|
|
||||||
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
|
|
||||||
assert (
|
assert (
|
||||||
self.config.pad_token_id is not None or batch_size == 1
|
self.config.pad_token_id is not None or logits_shape[0] == 1
|
||||||
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
||||||
|
|
||||||
if not tf.is_tensor(sequence_lengths):
|
if not tf.is_tensor(sequence_lengths):
|
||||||
in_logits = logits[0:batch_size, sequence_lengths]
|
in_logits = logits[0 : logits_shape[0], sequence_lengths]
|
||||||
|
|
||||||
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
loss = self.compute_loss(tf.reshape(inputs["labels"], [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
|
||||||
pooled_logits = in_logits if in_logits is not None else logits
|
pooled_logits = in_logits if in_logits is not None else logits
|
||||||
|
@ -28,6 +28,7 @@ from .file_utils import (
|
|||||||
is_datasets_available,
|
is_datasets_available,
|
||||||
is_faiss_available,
|
is_faiss_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
|
is_onnx_available,
|
||||||
is_pandas_available,
|
is_pandas_available,
|
||||||
is_scatter_available,
|
is_scatter_available,
|
||||||
is_sentencepiece_available,
|
is_sentencepiece_available,
|
||||||
@ -160,6 +161,13 @@ def require_git_lfs(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def require_onnx(test_case):
|
||||||
|
if not is_onnx_available():
|
||||||
|
return unittest.skip("test requires ONNX")(test_case)
|
||||||
|
else:
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
def require_torch(test_case):
|
def require_torch(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires PyTorch.
|
Decorator marking a test that requires PyTorch.
|
||||||
|
@ -241,6 +241,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFAlbertModelTester(self)
|
self.model_tester = TFAlbertModelTester(self)
|
||||||
|
@ -178,6 +178,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_onnx = True
|
||||||
|
onnx_min_opset = 10
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFBartModelTester(self)
|
self.model_tester = TFBartModelTester(self)
|
||||||
|
@ -274,6 +274,8 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = True
|
||||||
|
onnx_min_opset = 10
|
||||||
|
|
||||||
# special case for ForPreTraining model
|
# special case for ForPreTraining model
|
||||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||||
|
@ -177,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFBlenderbotModelTester(self)
|
self.model_tester = TFBlenderbotModelTester(self)
|
||||||
|
@ -179,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFBlenderbotSmallModelTester(self)
|
self.model_tester = TFBlenderbotSmallModelTester(self)
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -24,7 +25,7 @@ from importlib import import_module
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
|
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@ -201,6 +202,67 @@ class TFModelTesterMixin:
|
|||||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||||
self.assertTrue(os.path.exists(saved_model_dir))
|
self.assertTrue(os.path.exists(saved_model_dir))
|
||||||
|
|
||||||
|
def test_onnx_compliancy(self):
|
||||||
|
if not self.test_onnx:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
INTERNAL_OPS = [
|
||||||
|
"Assert",
|
||||||
|
"AssignVariableOp",
|
||||||
|
"EmptyTensorList",
|
||||||
|
"ReadVariableOp",
|
||||||
|
"ResourceGather",
|
||||||
|
"TruncatedNormal",
|
||||||
|
"VarHandleOp",
|
||||||
|
"VarIsInitializedOp",
|
||||||
|
]
|
||||||
|
onnx_ops = []
|
||||||
|
|
||||||
|
with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
|
||||||
|
onnx_opsets = json.load(f)["opsets"]
|
||||||
|
|
||||||
|
for i in range(1, self.onnx_min_opset + 1):
|
||||||
|
onnx_ops.extend(onnx_opsets[str(i)])
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model_op_names = set()
|
||||||
|
|
||||||
|
with tf.Graph().as_default() as g:
|
||||||
|
model = model_class(config)
|
||||||
|
model(model.dummy_inputs)
|
||||||
|
|
||||||
|
for op in g.get_operations():
|
||||||
|
model_op_names.add(op.node_def.op)
|
||||||
|
|
||||||
|
model_op_names = sorted(model_op_names)
|
||||||
|
incompatible_ops = []
|
||||||
|
|
||||||
|
for op in model_op_names:
|
||||||
|
if op not in onnx_ops and op not in INTERNAL_OPS:
|
||||||
|
incompatible_ops.append(op)
|
||||||
|
|
||||||
|
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
|
||||||
|
|
||||||
|
@require_onnx
|
||||||
|
@slow
|
||||||
|
def test_onnx_runtime_optimize(self):
|
||||||
|
if not self.test_onnx:
|
||||||
|
return
|
||||||
|
|
||||||
|
import keras2onnx
|
||||||
|
import onnxruntime
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model(model.dummy_inputs)
|
||||||
|
|
||||||
|
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
|
||||||
|
|
||||||
|
onnxruntime.InferenceSession(onnx_model.SerializeToString())
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_saved_model_creation_extended(self):
|
def test_saved_model_creation_extended(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
@ -239,6 +239,7 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFConvBertModelTester(self)
|
self.model_tester = TFConvBertModelTester(self)
|
||||||
|
@ -174,6 +174,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
|
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
|
||||||
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
|
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFCTRLModelTester(self)
|
self.model_tester = TFCTRLModelTester(self)
|
||||||
|
@ -184,6 +184,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFDistilBertModelTester(self)
|
self.model_tester = TFDistilBertModelTester(self)
|
||||||
|
@ -188,6 +188,7 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFDPRModelTester(self)
|
self.model_tester = TFDPRModelTester(self)
|
||||||
|
@ -206,6 +206,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFElectraModelTester(self)
|
self.model_tester = TFElectraModelTester(self)
|
||||||
|
@ -292,6 +292,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
|
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
|
||||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFFlaubertModelTester(self)
|
self.model_tester = TFFlaubertModelTester(self)
|
||||||
|
@ -339,6 +339,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFFunnelModelTester(self)
|
self.model_tester = TFFunnelModelTester(self)
|
||||||
@ -382,6 +383,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
|
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFFunnelModelTester(self, base=True)
|
self.model_tester = TFFunnelModelTester(self, base=True)
|
||||||
|
@ -333,6 +333,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = True
|
||||||
|
onnx_min_opset = 10
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFGPT2ModelTester(self)
|
self.model_tester = TFGPT2ModelTester(self)
|
||||||
|
@ -195,6 +195,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFLEDModelTester(self)
|
self.model_tester = TFLEDModelTester(self)
|
||||||
|
@ -297,6 +297,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
if is_tf_available()
|
if is_tf_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFLongformerModelTester(self)
|
self.model_tester = TFLongformerModelTester(self)
|
||||||
|
@ -362,6 +362,7 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
|
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFLxmertModelTester(self)
|
self.model_tester = TFLxmertModelTester(self)
|
||||||
|
@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFMarianModelTester(self)
|
self.model_tester = TFMarianModelTester(self)
|
||||||
|
@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFMBartModelTester(self)
|
self.model_tester = TFMBartModelTester(self)
|
||||||
|
@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
class TFMobileBertModelTester(object):
|
class TFMobileBertModelTester(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFMPNetModelTester(self)
|
self.model_tester = TFMPNetModelTester(self)
|
||||||
|
@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
|
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
|
||||||
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFOpenAIGPTModelTester(self)
|
self.model_tester = TFOpenAIGPTModelTester(self)
|
||||||
|
@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFPegasusModelTester(self)
|
self.model_tester = TFPegasusModelTester(self)
|
||||||
|
@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFRobertaModelTester(self)
|
self.model_tester = TFRobertaModelTester(self)
|
||||||
|
@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
|
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
|
||||||
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFT5ModelTester(self)
|
self.model_tester = TFT5ModelTester(self)
|
||||||
@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFT5EncoderOnlyModelTester(self)
|
self.model_tester = TFT5EncoderOnlyModelTester(self)
|
||||||
|
@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
|
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFTransfoXLModelTester(self)
|
self.model_tester = TFTransfoXLModelTester(self)
|
||||||
|
@ -294,6 +294,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
|
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
|
||||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFXLMModelTester(self)
|
self.model_tester = TFXLMModelTester(self)
|
||||||
|
@ -348,6 +348,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||||||
(TFXLNetLMHeadModel,) if is_tf_available() else ()
|
(TFXLNetLMHeadModel,) if is_tf_available() else ()
|
||||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
test_onnx = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = TFXLNetModelTester(self)
|
self.model_tester = TFXLNetModelTester(self)
|
||||||
|
101
utils/check_tf_ops.py
Normal file
101
utils/check_tf_ops.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2020 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
|
||||||
|
|
||||||
|
|
||||||
|
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||||
|
# python utils/check_copies.py
|
||||||
|
REPO_PATH = "."
|
||||||
|
|
||||||
|
# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
|
||||||
|
INTERNAL_OPS = [
|
||||||
|
"Assert",
|
||||||
|
"AssignVariableOp",
|
||||||
|
"EmptyTensorList",
|
||||||
|
"MergeV2Checkpoints",
|
||||||
|
"ReadVariableOp",
|
||||||
|
"ResourceGather",
|
||||||
|
"RestoreV2",
|
||||||
|
"SaveV2",
|
||||||
|
"ShardedFilename",
|
||||||
|
"StatefulPartitionedCall",
|
||||||
|
"StaticRegexFullMatch",
|
||||||
|
"VarHandleOp",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def onnx_compliancy(saved_model_path, strict, opset):
|
||||||
|
saved_model = SavedModel()
|
||||||
|
onnx_ops = []
|
||||||
|
|
||||||
|
with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
|
||||||
|
onnx_opsets = json.load(f)["opsets"]
|
||||||
|
|
||||||
|
for i in range(1, opset + 1):
|
||||||
|
onnx_ops.extend(onnx_opsets[str(i)])
|
||||||
|
|
||||||
|
with open(saved_model_path, "rb") as f:
|
||||||
|
saved_model.ParseFromString(f.read())
|
||||||
|
|
||||||
|
model_op_names = set()
|
||||||
|
|
||||||
|
# Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
|
||||||
|
for meta_graph in saved_model.meta_graphs:
|
||||||
|
# Add operations in the graph definition
|
||||||
|
model_op_names.update(node.op for node in meta_graph.graph_def.node)
|
||||||
|
|
||||||
|
# Go through the functions in the graph definition
|
||||||
|
for func in meta_graph.graph_def.library.function:
|
||||||
|
# Add operations in each function
|
||||||
|
model_op_names.update(node.op for node in func.node_def)
|
||||||
|
|
||||||
|
# Convert to list, sorted if you want
|
||||||
|
model_op_names = sorted(model_op_names)
|
||||||
|
incompatible_ops = []
|
||||||
|
|
||||||
|
for op in model_op_names:
|
||||||
|
if op not in onnx_ops and op not in INTERNAL_OPS:
|
||||||
|
incompatible_ops.append(op)
|
||||||
|
|
||||||
|
if strict and len(incompatible_ops) > 0:
|
||||||
|
raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops)
|
||||||
|
elif len(incompatible_ops) > 0:
|
||||||
|
print(f"Found the following incompatible ops for the opset {opset}:")
|
||||||
|
print(*incompatible_ops, sep="\n")
|
||||||
|
else:
|
||||||
|
print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
|
||||||
|
parser.add_argument(
|
||||||
|
"--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.framework == "onnx":
|
||||||
|
onnx_compliancy(args.saved_model_path, args.strict, args.opset)
|
245
utils/tf_ops/onnx.json
Normal file
245
utils/tf_ops/onnx.json
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
{
|
||||||
|
"opsets": {
|
||||||
|
"1": [
|
||||||
|
"Abs",
|
||||||
|
"Add",
|
||||||
|
"AddV2",
|
||||||
|
"ArgMax",
|
||||||
|
"ArgMin",
|
||||||
|
"AvgPool",
|
||||||
|
"AvgPool3D",
|
||||||
|
"BatchMatMul",
|
||||||
|
"BatchMatMulV2",
|
||||||
|
"BatchToSpaceND",
|
||||||
|
"BiasAdd",
|
||||||
|
"BiasAddV1",
|
||||||
|
"Cast",
|
||||||
|
"Ceil",
|
||||||
|
"CheckNumerics",
|
||||||
|
"ComplexAbs",
|
||||||
|
"Concat",
|
||||||
|
"ConcatV2",
|
||||||
|
"Const",
|
||||||
|
"ConstV2",
|
||||||
|
"Conv1D",
|
||||||
|
"Conv2D",
|
||||||
|
"Conv2DBackpropInput",
|
||||||
|
"Conv3D",
|
||||||
|
"Conv3DBackpropInputV2",
|
||||||
|
"DepthToSpace",
|
||||||
|
"DepthwiseConv2d",
|
||||||
|
"DepthwiseConv2dNative",
|
||||||
|
"Div",
|
||||||
|
"Dropout",
|
||||||
|
"Elu",
|
||||||
|
"Equal",
|
||||||
|
"Erf",
|
||||||
|
"Exp",
|
||||||
|
"ExpandDims",
|
||||||
|
"Flatten",
|
||||||
|
"Floor",
|
||||||
|
"Gather",
|
||||||
|
"GatherNd",
|
||||||
|
"GatherV2",
|
||||||
|
"Greater",
|
||||||
|
"Identity",
|
||||||
|
"IdentityN",
|
||||||
|
"If",
|
||||||
|
"LRN",
|
||||||
|
"LSTMBlockCell",
|
||||||
|
"LeakyRelu",
|
||||||
|
"Less",
|
||||||
|
"Log",
|
||||||
|
"LogSoftmax",
|
||||||
|
"LogicalAnd",
|
||||||
|
"LogicalNot",
|
||||||
|
"LogicalOr",
|
||||||
|
"LookupTableSizeV2",
|
||||||
|
"MatMul",
|
||||||
|
"Max",
|
||||||
|
"MaxPool",
|
||||||
|
"MaxPool3D",
|
||||||
|
"MaxPoolV2",
|
||||||
|
"Maximum",
|
||||||
|
"Mean",
|
||||||
|
"Min",
|
||||||
|
"Minimum",
|
||||||
|
"MirrorPad",
|
||||||
|
"Mul",
|
||||||
|
"Neg",
|
||||||
|
"NoOp",
|
||||||
|
"NotEqual",
|
||||||
|
"OneHot",
|
||||||
|
"Pack",
|
||||||
|
"Pad",
|
||||||
|
"PadV2",
|
||||||
|
"Placeholder",
|
||||||
|
"PlaceholderV2",
|
||||||
|
"PlaceholderWithDefault",
|
||||||
|
"Pow",
|
||||||
|
"Prod",
|
||||||
|
"RFFT",
|
||||||
|
"RandomNormal",
|
||||||
|
"RandomNormalLike",
|
||||||
|
"RandomUniform",
|
||||||
|
"RandomUniformLike",
|
||||||
|
"RealDiv",
|
||||||
|
"Reciprocal",
|
||||||
|
"Relu",
|
||||||
|
"Relu6",
|
||||||
|
"Reshape",
|
||||||
|
"Rsqrt",
|
||||||
|
"Selu",
|
||||||
|
"Shape",
|
||||||
|
"Sigmoid",
|
||||||
|
"Sign",
|
||||||
|
"Size",
|
||||||
|
"Slice",
|
||||||
|
"Softmax",
|
||||||
|
"Softplus",
|
||||||
|
"Softsign",
|
||||||
|
"SpaceToBatchND",
|
||||||
|
"SpaceToDepth",
|
||||||
|
"Split",
|
||||||
|
"SplitV",
|
||||||
|
"Sqrt",
|
||||||
|
"Square",
|
||||||
|
"SquaredDifference",
|
||||||
|
"Squeeze",
|
||||||
|
"StatelessIf",
|
||||||
|
"StopGradient",
|
||||||
|
"StridedSlice",
|
||||||
|
"StringJoin",
|
||||||
|
"Sub",
|
||||||
|
"Sum",
|
||||||
|
"Tanh",
|
||||||
|
"Tile",
|
||||||
|
"TopKV2",
|
||||||
|
"Transpose",
|
||||||
|
"TruncateDiv",
|
||||||
|
"Unpack",
|
||||||
|
"ZerosLike"
|
||||||
|
],
|
||||||
|
"2": [],
|
||||||
|
"3": [],
|
||||||
|
"4": [],
|
||||||
|
"5": [],
|
||||||
|
"6": [
|
||||||
|
"AddN",
|
||||||
|
"All",
|
||||||
|
"Any",
|
||||||
|
"FloorDiv",
|
||||||
|
"FusedBatchNorm",
|
||||||
|
"FusedBatchNormV2",
|
||||||
|
"FusedBatchNormV3"
|
||||||
|
],
|
||||||
|
"7": [
|
||||||
|
"Acos",
|
||||||
|
"Asin",
|
||||||
|
"Atan",
|
||||||
|
"Cos",
|
||||||
|
"Fill",
|
||||||
|
"FloorMod",
|
||||||
|
"GreaterEqual",
|
||||||
|
"LessEqual",
|
||||||
|
"Loop",
|
||||||
|
"MatrixBandPart",
|
||||||
|
"Multinomial",
|
||||||
|
"Range",
|
||||||
|
"ResizeBilinear",
|
||||||
|
"ResizeNearestNeighbor",
|
||||||
|
"Scan",
|
||||||
|
"Select",
|
||||||
|
"SelectV2",
|
||||||
|
"Sin",
|
||||||
|
"SoftmaxCrossEntropyWithLogits",
|
||||||
|
"SparseSoftmaxCrossEntropyWithLogits",
|
||||||
|
"StatelessWhile",
|
||||||
|
"Tan",
|
||||||
|
"TensorListFromTensor",
|
||||||
|
"TensorListGetItem",
|
||||||
|
"TensorListLength",
|
||||||
|
"TensorListReserve",
|
||||||
|
"TensorListResize",
|
||||||
|
"TensorListSetItem",
|
||||||
|
"TensorListStack",
|
||||||
|
"While"
|
||||||
|
],
|
||||||
|
"8": [
|
||||||
|
"BroadcastTo",
|
||||||
|
"ClipByValue",
|
||||||
|
"FIFOQueueV2",
|
||||||
|
"HashTableV2",
|
||||||
|
"IteratorGetNext",
|
||||||
|
"IteratorV2",
|
||||||
|
"LookupTableFindV2",
|
||||||
|
"MaxPoolWithArgmax",
|
||||||
|
"QueueDequeueManyV2",
|
||||||
|
"QueueDequeueUpToV2",
|
||||||
|
"QueueDequeueV2",
|
||||||
|
"ReverseSequence"
|
||||||
|
],
|
||||||
|
"9": [
|
||||||
|
"SegmentMax",
|
||||||
|
"SegmentMean",
|
||||||
|
"SegmentMin",
|
||||||
|
"SegmentProd",
|
||||||
|
"SegmentSum",
|
||||||
|
"Sinh",
|
||||||
|
"SparseSegmentMean",
|
||||||
|
"SparseSegmentMeanWithNumSegments",
|
||||||
|
"SparseSegmentSqrtN",
|
||||||
|
"SparseSegmentSqrtNWithNumSegments",
|
||||||
|
"SparseSegmentSum",
|
||||||
|
"SparseSegmentSumWithNumSegments",
|
||||||
|
"UnsortedSegmentMax",
|
||||||
|
"UnsortedSegmentMin",
|
||||||
|
"UnsortedSegmentProd",
|
||||||
|
"UnsortedSegmentSum",
|
||||||
|
"Where"
|
||||||
|
],
|
||||||
|
"10": [
|
||||||
|
"CropAndResize",
|
||||||
|
"CudnnRNN",
|
||||||
|
"DynamicStitch",
|
||||||
|
"FakeQuantWithMinMaxArgs",
|
||||||
|
"IsFinite",
|
||||||
|
"IsInf",
|
||||||
|
"NonMaxSuppressionV2",
|
||||||
|
"NonMaxSuppressionV3",
|
||||||
|
"NonMaxSuppressionV4",
|
||||||
|
"NonMaxSuppressionV5",
|
||||||
|
"ParallelDynamicStitch",
|
||||||
|
"ReverseV2",
|
||||||
|
"Roll"
|
||||||
|
],
|
||||||
|
"11": [
|
||||||
|
"Bincount",
|
||||||
|
"Cumsum",
|
||||||
|
"InvertPermutation",
|
||||||
|
"LeftShift",
|
||||||
|
"MatrixDeterminant",
|
||||||
|
"MatrixDiagPart",
|
||||||
|
"MatrixDiagPartV2",
|
||||||
|
"MatrixDiagPartV3",
|
||||||
|
"RaggedRange",
|
||||||
|
"RightShift",
|
||||||
|
"Round",
|
||||||
|
"ScatterNd",
|
||||||
|
"SparseFillEmptyRows",
|
||||||
|
"SparseReshape",
|
||||||
|
"SparseToDense",
|
||||||
|
"TensorScatterUpdate",
|
||||||
|
"Unique"
|
||||||
|
],
|
||||||
|
"12": [
|
||||||
|
"Einsum",
|
||||||
|
"MatrixDiag",
|
||||||
|
"MatrixDiagV2",
|
||||||
|
"MatrixDiagV3",
|
||||||
|
"MatrixSetDiagV3",
|
||||||
|
"SquaredDistance"
|
||||||
|
],
|
||||||
|
"13": []
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user