Determine framework automatically before ONNX export (#18615)

* Automatic detection for framework to use when exporting to ONNX

* Log message change

* Incorporating PR comments, adding unit test

* Adding tf for pip install for run_tests_onnxruntime CI

* Restoring past changes to circleci yaml and test_onnx_v2.py, tests moved to tests/onnx/test_features.py

* Fixup

* Adding test to fetcher

* Updating circleci config to log more

* Changing test class name

* Comment typo fix in tests/onnx/test_features.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Moving torch_str/tf_str to self.framework_pt/tf

* Remove -rA flag in circleci config

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Craig Chan 2022-08-25 07:31:34 -07:00 committed by GitHub
parent 3223d49354
commit fbf382c84d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 182 additions and 8 deletions

View File

@ -878,7 +878,7 @@ jobs:
- v0.5-torch-{{ checksum "setup.py" }}
- v0.5-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[torch,testing,sentencepiece,onnxruntime,vision,rjieba]
- run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision,rjieba]
- save_cache:
key: v0.5-onnx-{{ checksum "setup.py" }}
paths:
@ -912,7 +912,7 @@ jobs:
- v0.5-torch-{{ checksum "setup.py" }}
- v0.5-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[torch,testing,sentencepiece,onnxruntime,vision]
- run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision]
- save_cache:
key: v0.5-onnx-{{ checksum "setup.py" }}
paths:

View File

@ -38,7 +38,15 @@ def main():
"--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
)
parser.add_argument(
"--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export."
"--framework",
type=str,
choices=["pt", "tf"],
default=None,
help=(
"The framework to use for the ONNX export."
" If not provided, will attempt to use the local checkpoint's original framework"
" or what is available in the environment."
),
)
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")

View File

@ -1,10 +1,11 @@
import os
from functools import partial, reduce
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union
import transformers
from .. import PretrainedConfig, is_tf_available, is_torch_available
from ..utils import logging
from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
from .config import OnnxConfig
@ -566,9 +567,59 @@ class FeaturesManager:
)
return task_to_automodel[task]
@staticmethod
def determine_framework(model: str, framework: str = None) -> str:
    """
    Pick the framework ("pt" or "tf") that will be used for the ONNX export.

    Resolution order:

    1. An explicit `framework` argument is returned unchanged.
    2. If `model` is a local directory, the framework is inferred from the weights file found inside it
       (PyTorch weights are checked before TensorFlow weights).
    3. Otherwise, whichever framework is available in the environment is used, PyTorch first.

    Args:
        model (`str`):
            The name of the model to export, or the path to a local checkpoint directory.
        framework (`str`, *optional*, defaults to `None`):
            The framework to use for the export. When `None`, it is resolved as described above.

    Returns:
        `str`: The framework to use for the export.

    Raises:
        FileNotFoundError: If `model` is a local directory that contains neither a PyTorch nor a TensorFlow
            weights file.
        EnvironmentError: If `model` is not a local directory and neither PyTorch nor TensorFlow is available.
    """
    # An explicit request always wins; it is returned without any validation.
    if framework is not None:
        return framework

    display_names = {"pt": "PyTorch", "tf": "TensorFlow"}
    exporter_names = {"pt": "torch", "tf": "tf2onnx"}

    if os.path.isdir(model):
        # Local checkpoint: probe for the weights files in priority order (PyTorch, then TensorFlow).
        for candidate, weights_file in (("pt", WEIGHTS_NAME), ("tf", TF2_WEIGHTS_NAME)):
            if os.path.isfile(os.path.join(model, weights_file)):
                framework = candidate
                break
        else:
            # The loop finished without finding any known weights file.
            raise FileNotFoundError(
                "Cannot determine framework from given checkpoint location."
                f" There should be a {WEIGHTS_NAME} for PyTorch"
                f" or {TF2_WEIGHTS_NAME} for TensorFlow."
            )
        logger.info(f"Local {display_names[framework]} model found.")
    elif is_torch_available():
        # Not a local directory: fall back to what is installed, preferring PyTorch.
        framework = "pt"
    elif is_tf_available():
        framework = "tf"
    else:
        raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")

    logger.info(f"Framework not requested. Using {exporter_names[framework]} to export to ONNX.")
    return framework
@staticmethod
def get_model_from_feature(
feature: str, model: str, framework: str = "pt", cache_dir: str = None
feature: str, model: str, framework: str = None, cache_dir: str = None
) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
"""
Attempts to retrieve a model from a model's name and the feature to be enabled.
@ -578,20 +629,24 @@ class FeaturesManager:
The feature required.
model (`str`):
The name of the model to export.
framework (`str`, *optional*, defaults to `"pt"`):
The framework to use for the export.
framework (`str`, *optional*, defaults to `None`):
The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should
none be provided.
Returns:
The instance of the model.
"""
framework = FeaturesManager.determine_framework(model, framework)
model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
try:
model = model_class.from_pretrained(model, cache_dir=cache_dir)
except OSError:
if framework == "pt":
logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
else:
logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
return model

111
tests/onnx/test_features.py Normal file
View File

@ -0,0 +1,111 @@
from tempfile import TemporaryDirectory
from unittest import TestCase
from unittest.mock import MagicMock, patch
from transformers import AutoModel, TFAutoModel
from transformers.onnx import FeaturesManager
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch
@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
    """
    Test `FeaturesManager.determine_framework`
    """

    def setUp(self):
        self.test_model = SMALL_MODEL_IDENTIFIER
        self.framework_pt = "pt"
        self.framework_tf = "tf"

    def _setup_pt_ckpt(self, save_dir):
        """Load the test model with PyTorch and save it as a checkpoint in `save_dir`."""
        model_pt = AutoModel.from_pretrained(self.test_model)
        model_pt.save_pretrained(save_dir)

    def _setup_tf_ckpt(self, save_dir):
        """Load the test model in TensorFlow (converted from PyTorch weights) and save it in `save_dir`."""
        model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
        model_tf.save_pretrained(save_dir)

    def test_framework_provided(self):
        """
        Ensure that the provided framework is returned.
        """
        mock_framework = "mock_framework"

        # Framework provided - return whatever the user provides
        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
        self.assertEqual(result, mock_framework)

        # Local checkpoint and framework provided - return provided framework
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

    def test_checkpoint_provided(self):
        """
        Ensure that the determined framework is the one used for the local checkpoint.

        For the functionality to execute, local checkpoints are provided but framework is not.
        """
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt)
            self.assertEqual(result, self.framework_pt)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt)
            self.assertEqual(result, self.framework_tf)

        # Invalid local checkpoint (empty directory, no weights file)
        with TemporaryDirectory() as local_invalid_ckpt:
            with self.assertRaises(FileNotFoundError):
                FeaturesManager.determine_framework(local_invalid_ckpt)

    def test_from_environment(self):
        """
        Ensure that the determined framework is the one available in the environment.

        For the functionality to execute, framework and local checkpoints are not provided.
        """
        # Framework not provided, hub model is used (no local checkpoint directory)
        # TensorFlow not in environment -> use PyTorch
        # (only the TensorFlow check is patched here; presumably PyTorch is really installed
        # because of `@require_torch`)
        mock_tf_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # PyTorch not in environment -> use TensorFlow
        # (only the PyTorch check is patched here; presumably TensorFlow is really installed
        # because of `@require_tf`)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_tf)

        # Both in environment -> use PyTorch
        mock_tf_available = MagicMock(return_value=True)
        mock_torch_available = MagicMock(return_value=True)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            result = FeaturesManager.determine_framework(self.test_model)
            self.assertEqual(result, self.framework_pt)

        # Both not in environment -> raise error
        mock_tf_available = MagicMock(return_value=False)
        mock_torch_available = MagicMock(return_value=False)
        with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
            "transformers.onnx.features.is_torch_available", mock_torch_available
        ):
            with self.assertRaises(EnvironmentError):
                FeaturesManager.determine_framework(self.test_model)

View File

@ -434,7 +434,7 @@ def module_to_test_file(module_fname):
return "tests/utils/test_cli.py"
# Special case for onnx submodules
elif len(splits) >= 2 and splits[-2] == "onnx":
return ["tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
# Special case for utils (not the one in src/transformers, the ones at the root of the repo).
elif len(splits) > 0 and splits[0] == "utils":
default_test_file = f"tests/utils/test_utils_{module_name}"