transformers/tests/test_onnx_v2.py
Funtowicz Morgan 2aa3cd935d
[RFC] Laying down building stone for more flexible ONNX export capabilities (#11786)
* Laying down building stone for more flexible ONNX export capabilities

* Ability to provide a map of config key to override before exporting.

* Makes it possible to export BART with/without past keys.

* Supports simple mathematical syntax for OnnxVariable.repeated

* Effectively apply value override from onnx config for model

* Supports export with additional features such as with-past for seq2seq

* Store the output path directly in the args for uniform usage across.

* Make BART_ONNX_CONFIG_* constants and fix imports.

* Support BERT model.

* Use tokenizer for more flexibility in defining the inputs of a model.

* Add TODO as reminder to provide the batch/sequence_length as CLI args

* Enable optimizations to be done on the model.

* Enable GPT2 + past

* Improve model validation with outputs containing nested structures

* Enable Roberta

* Enable Albert

* Albert requires opset >= 12

* BERT-like models require opset >= 12

* Remove double printing.

* Enable XLM-Roberta

* Enable DistilBERT

* Disable optimization by default

* Fix missing setattr when applying optimizer_features

* Add value field to OnnxVariable to define constant input (not from tokenizers)

* Add T5 support.

* Simplify model type retrieval

* Example exporting token_classification pipeline for DistilBERT.

* Refactoring to package `transformers.onnx`

* Solve circular dependency & __main__

* Remove unnecessary imports in `__init__`

* Licences

* Use @Narsil's suggestion to forward the model's configuration to the ONNXConfig to avoid interpolation.

* Onnx export v2 fixes (#12388)

* Tiny fixes
Remove `convert_pytorch` from onnxruntime-less runtimes
Correct reference to model

* Style

* Fix Copied from

* LongFormer ONNX config.

* Removed optimizations

* Remove bad merge relics.

* Remove unused constants.

* Remove some deleted constants from imports.

* Fix unittest to remove usage of PyTorch model for onnx.utils.

* Fix distilbert export

* Enable ONNX export test for supported model.

* Style.

* Fix lint.

* Enable all supported default models.

* GPT2 only has one output

* Fix bad property name when overriding config.

* Added unittests and docstrings.

* Disable with_past tests for now.

* Enable outputs validation for default export.

* Remove graph opt lvls.

* Last commit with on-going past commented.

* Style.

* Disabled `with_past` for now

* Remove unused imports.

* Remove framework argument

* Remove TFPreTrainedModel reference

* Add documentation

* Add onnxruntime tests to CircleCI

* Add test

* Rename `convert_pytorch` to `export`

* Use OrderedDict for dummy inputs

* WIP Wav2Vec2

* Revert "WIP Wav2Vec2"

This reverts commit f665efb04c92525c3530e589029f0ae7afdf603e.

* Style

* Use OrderedDict for I/O

* Style.

* Specify OrderedDict documentation.

* Style :)

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
2021-07-08 10:54:42 -04:00

252 lines
10 KiB
Python

from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import TestCase
from unittest.mock import patch
from transformers import ( # LongformerConfig,
AlbertConfig,
AutoTokenizer,
BartConfig,
DistilBertConfig,
GPT2Config,
RobertaConfig,
T5Config,
XLMRobertaConfig,
is_torch_available,
)
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import (
compute_effective_axis_dimension,
compute_serialized_parameters_size,
flatten_output_collection_property,
)
from transformers.testing_utils import require_onnx, require_torch, slow
@require_onnx
class OnnxUtilsTestCaseV2(TestCase):
    """
    Cover all the utilities involved to export ONNX models
    """

    def test_compute_effective_axis_dimension(self):
        """
        When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
        We cannot generate an effective tensor with axis dim == -1, so we trick by using some "fixed" values
        (> 1 to avoid ONNX squeezing the axis).

        This test ensures we are correctly replacing generated batch / sequence tensor with axis > 1
        """
        # Dynamic axis (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=2, num_token_to_add=0), 2)

        # Static axis (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=2, num_token_to_add=0), 2)

        # Sequence axis, 2 tokens added by the tokenizer (no pair).
        # Both the dynamic marker (-1) and 0 are exercised instead of asserting the same call twice.
        self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=8, num_token_to_add=2), 6)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=2), 6)

        # Sequence axis, 3 tokens added by the tokenizer (pair)
        self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=8, num_token_to_add=3), 5)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=3), 5)

    def test_compute_parameters_serialized_size(self):
        """
        This test ensures we compute a "correct" approximation of the underlying storage requirement (size) for all the
        parameters for the specified parameter's dtype.
        """
        self.assertEqual(compute_serialized_parameters_size(2, ParameterFormat.Float), 2 * ParameterFormat.Float.size)

    def test_flatten_output_collection_property(self):
        """
        This test ensures we correctly flatten nested collection such as the one we use when returning past_keys.
        past_keys = Tuple[Tuple]

        ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
        """
        self.assertEqual(
            flatten_output_collection_property("past_key", [[0], [1], [2]]),
            {
                "past_key.0": 0,
                "past_key.1": 1,
                "past_key.2": 2,
            },
        )
class OnnxConfigTestCaseV2(TestCase):
    """
    Cover the test for models default.

    Default means no specific features is being enabled on the model.
    """

    @patch.multiple(OnnxConfig, __abstractmethods__=set())
    def test_use_external_data_format(self):
        """
        External data format is required only if the serialized size of the parameters reaches the 2Gb limit.

        `use_external_data_format` takes a number of parameters, so byte sizes are converted to parameter counts
        by dividing by the size of a float parameter.
        """
        TWO_GB_LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT
        float_size = ParameterFormat.Float.size

        # Smallest parameter count whose serialized size reaches the limit (ceiling division)
        params_at_limit = (TWO_GB_LIMIT + float_size - 1) // float_size

        # No parameters
        self.assertFalse(OnnxConfig.use_external_data_format(0))

        # Some parameters
        self.assertFalse(OnnxConfig.use_external_data_format(1))

        # Almost 2Gb parameters
        self.assertFalse(OnnxConfig.use_external_data_format((TWO_GB_LIMIT - 1) // float_size))

        # Exactly 2Gb parameters
        self.assertTrue(OnnxConfig.use_external_data_format(params_at_limit))

        # More than 2Gb parameters
        self.assertTrue(OnnxConfig.use_external_data_format(params_at_limit + 1))

        # Far above the limit: as many parameters as there are bytes in the limit
        self.assertTrue(OnnxConfig.use_external_data_format(TWO_GB_LIMIT))
class OnnxConfigWithPastTestCaseV2(TestCase):
    """
    Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX)
    """

    # (display name, model configuration class) pairs exercised by every test below
    SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)}

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_use_past(self):
        """
        Ensure the use_past variable is correctly being set
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                self.assertFalse(
                    OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
                )

                self.assertTrue(
                    OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.with_past() should use_past"
                )

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_values_override(self):
        """
        Ensure the use_past variable correctly set the `use_cache` value in model's configuration
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                # without past
                onnx_config_default = OnnxConfigWithPast.default(config())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                self.assertFalse(
                    onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past"
                )

                # with past
                onnx_config_with_past = OnnxConfigWithPast.with_past(config())
                self.assertIsNotNone(onnx_config_with_past.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_with_past.values_override, "use_cache should be present")
                self.assertTrue(
                    onnx_config_with_past.values_override["use_cache"], "use_cache should be True if using past"
                )
if is_torch_available():
    from transformers import (
        AlbertModel,
        BartModel,
        BertModel,
        DistilBertModel,
        GPT2Model,
        RobertaModel,
        T5Model,
        XLMRobertaModel,
    )

    # Each entry is:
    #   (display name, checkpoint id given to AutoTokenizer.from_pretrained, model class, config class,
    #    ONNX config class)
    # The export tests build the model from a default `config_class()`; only the tokenizer is loaded
    # from the checkpoint id.
    PYTORCH_EXPORT_DEFAULT_MODELS = {
        ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
        ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
        ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
        ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
        ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
        # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
        ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
        # NOTE(review): XLM-Roberta reuses the "roberta-base" tokenizer checkpoint - confirm this is intended.
        ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
        ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
    }

    # Export "with past" is disabled for now, so the collection is empty. It is spelled `set()` because a
    # brace literal containing only comments evaluates to an empty dict, not a set like the table above.
    # To re-enable, turn this back into a set literal holding the entries listed below.
    PYTORCH_EXPORT_WITH_PAST_MODELS = set()
    # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
    # ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
    # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
class OnnxExportTestCaseV2(TestCase):
    """
    Integration tests ensuring supported models are correctly exported
    """

    def _export_and_validate(self, name, tokenizer, model, onnx_config):
        """
        Export `model` to a temporary ONNX file and validate the exported outputs against the original model.

        Args:
            name: Display name of the model, used to prefix the failure message.
            tokenizer: Tokenizer forwarded to `export` and `validate_model_outputs`.
            model: Instantiated model to export.
            onnx_config: ONNX configuration describing how the model is exported.

        A `ValueError` raised by the validation is reported as a test failure for `name`.
        """
        # Imported here, as in the original tests, rather than at module import time.
        from transformers.onnx import export

        with NamedTemporaryFile("w") as output:
            output_path = Path(output.name)
            # `export` also returns the input names; only the output names are needed for validation
            _, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output_path)

            try:
                validate_model_outputs(onnx_config, tokenizer, model, output_path, onnx_outputs, 1e-5)
            except ValueError as ve:
                self.fail(f"{name} -> {ve}")

    @slow
    @require_torch
    def test_pytorch_export_default(self):
        """
        Export every model of PYTORCH_EXPORT_DEFAULT_MODELS with its default ONNX config and validate the outputs
        """
        for name, checkpoint, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
            with self.subTest(name):
                self.assertTrue(hasattr(onnx_config_class, "default"))

                tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                # The model is built from a default configuration, not from the checkpoint's weights
                model = model_class(config_class())
                onnx_config = onnx_config_class.default(model.config)

                self._export_and_validate(name, tokenizer, model, onnx_config)

    @slow
    @require_torch
    def test_pytorch_export_with_past(self):
        """
        Export every model of PYTORCH_EXPORT_WITH_PAST_MODELS with its `with_past` ONNX config and validate the outputs
        """
        for name, checkpoint, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
            with self.subTest(name):
                self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")

                tokenizer = AutoTokenizer.from_pretrained(checkpoint)
                # The model is built from a default configuration, not from the checkpoint's weights
                model = model_class(config_class())
                onnx_config = onnx_config_class.with_past(model.config)

                self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
                self.assertTrue(
                    onnx_config.use_past, "OnnxConfigWithPast.use_past should be True if called with with_past()"
                )

                self._export_and_validate(name, tokenizer, model, onnx_config)