mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Determine framework automatically before ONNX export (#18615)
* Automatic detection for framework to use when exporting to ONNX * Log message change * Incorporating PR comments, adding unit test * Adding tf for pip install for run_tests_onnxruntime CI * Restoring past changes to circleci yaml and test_onnx_v2.py, tests moved to tests/onnx/test_features.py * Fixup * Adding test to fetcher * Updating circleci config to log more * Changing test class name * Comment typo fix in tests/onnx/test_features.py Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Moving torch_str/tf_str to self.framework_pt/tf * Remove -rA flag in circleci config Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
parent
3223d49354
commit
fbf382c84d
@ -878,7 +878,7 @@ jobs:
|
||||
- v0.5-torch-{{ checksum "setup.py" }}
|
||||
- v0.5-{{ checksum "setup.py" }}
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[torch,testing,sentencepiece,onnxruntime,vision,rjieba]
|
||||
- run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision,rjieba]
|
||||
- save_cache:
|
||||
key: v0.5-onnx-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -912,7 +912,7 @@ jobs:
|
||||
- v0.5-torch-{{ checksum "setup.py" }}
|
||||
- v0.5-{{ checksum "setup.py" }}
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[torch,testing,sentencepiece,onnxruntime,vision]
|
||||
- run: pip install .[torch,tf,testing,sentencepiece,onnxruntime,vision]
|
||||
- save_cache:
|
||||
key: v0.5-onnx-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
|
@ -38,7 +38,15 @@ def main():
|
||||
"--atol", type=float, default=None, help="Absolute difference tolerence when validating the model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--framework", type=str, choices=["pt", "tf"], default="pt", help="The framework to use for the ONNX export."
|
||||
"--framework",
|
||||
type=str,
|
||||
choices=["pt", "tf"],
|
||||
default=None,
|
||||
help=(
|
||||
"The framework to use for the ONNX export."
|
||||
" If not provided, will attempt to use the local checkpoint's original framework"
|
||||
" or what is available in the environment."
|
||||
),
|
||||
)
|
||||
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
|
||||
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
|
||||
|
@ -1,10 +1,11 @@
|
||||
import os
|
||||
from functools import partial, reduce
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
import transformers
|
||||
|
||||
from .. import PretrainedConfig, is_tf_available, is_torch_available
|
||||
from ..utils import logging
|
||||
from ..utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, logging
|
||||
from .config import OnnxConfig
|
||||
|
||||
|
||||
@ -566,9 +567,59 @@ class FeaturesManager:
|
||||
)
|
||||
return task_to_automodel[task]
|
||||
|
||||
@staticmethod
|
||||
def determine_framework(model: str, framework: str = None) -> str:
|
||||
"""
|
||||
Determines the framework to use for the export.
|
||||
|
||||
The priority is in the following order:
|
||||
1. User input via `framework`.
|
||||
2. If local checkpoint is provided, use the same framework as the checkpoint.
|
||||
3. Available framework in environment, with priority given to PyTorch
|
||||
|
||||
Args:
|
||||
model (`str`):
|
||||
The name of the model to export.
|
||||
framework (`str`, *optional*, defaults to `None`):
|
||||
The framework to use for the export. See above for priority if none provided.
|
||||
|
||||
Returns:
|
||||
The framework to use for the export.
|
||||
|
||||
"""
|
||||
if framework is not None:
|
||||
return framework
|
||||
|
||||
framework_map = {"pt": "PyTorch", "tf": "TensorFlow"}
|
||||
exporter_map = {"pt": "torch", "tf": "tf2onnx"}
|
||||
|
||||
if os.path.isdir(model):
|
||||
if os.path.isfile(os.path.join(model, WEIGHTS_NAME)):
|
||||
framework = "pt"
|
||||
elif os.path.isfile(os.path.join(model, TF2_WEIGHTS_NAME)):
|
||||
framework = "tf"
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
"Cannot determine framework from given checkpoint location."
|
||||
f" There should be a {WEIGHTS_NAME} for PyTorch"
|
||||
f" or {TF2_WEIGHTS_NAME} for TensorFlow."
|
||||
)
|
||||
logger.info(f"Local {framework_map[framework]} model found.")
|
||||
else:
|
||||
if is_torch_available():
|
||||
framework = "pt"
|
||||
elif is_tf_available():
|
||||
framework = "tf"
|
||||
else:
|
||||
raise EnvironmentError("Neither PyTorch nor TensorFlow found in environment. Cannot export to ONNX.")
|
||||
|
||||
logger.info(f"Framework not requested. Using {exporter_map[framework]} to export to ONNX.")
|
||||
|
||||
return framework
|
||||
|
||||
@staticmethod
|
||||
def get_model_from_feature(
|
||||
feature: str, model: str, framework: str = "pt", cache_dir: str = None
|
||||
feature: str, model: str, framework: str = None, cache_dir: str = None
|
||||
) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
|
||||
"""
|
||||
Attempts to retrieve a model from a model's name and the feature to be enabled.
|
||||
@ -578,20 +629,24 @@ class FeaturesManager:
|
||||
The feature required.
|
||||
model (`str`):
|
||||
The name of the model to export.
|
||||
framework (`str`, *optional*, defaults to `"pt"`):
|
||||
The framework to use for the export.
|
||||
framework (`str`, *optional*, defaults to `None`):
|
||||
The framework to use for the export. See `FeaturesManager.determine_framework` for the priority should
|
||||
none be provided.
|
||||
|
||||
Returns:
|
||||
The instance of the model.
|
||||
|
||||
"""
|
||||
framework = FeaturesManager.determine_framework(model, framework)
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature, framework)
|
||||
try:
|
||||
model = model_class.from_pretrained(model, cache_dir=cache_dir)
|
||||
except OSError:
|
||||
if framework == "pt":
|
||||
logger.info("Loading TensorFlow model in PyTorch before exporting to ONNX.")
|
||||
model = model_class.from_pretrained(model, from_tf=True, cache_dir=cache_dir)
|
||||
else:
|
||||
logger.info("Loading PyTorch model in TensorFlow before exporting to ONNX.")
|
||||
model = model_class.from_pretrained(model, from_pt=True, cache_dir=cache_dir)
|
||||
return model
|
||||
|
||||
|
111
tests/onnx/test_features.py
Normal file
111
tests/onnx/test_features.py
Normal file
@ -0,0 +1,111 @@
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from transformers import AutoModel, TFAutoModel
|
||||
from transformers.onnx import FeaturesManager
|
||||
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_tf
|
||||
class DetermineFrameworkTest(TestCase):
|
||||
"""
|
||||
Test `FeaturesManager.determine_framework`
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.test_model = SMALL_MODEL_IDENTIFIER
|
||||
self.framework_pt = "pt"
|
||||
self.framework_tf = "tf"
|
||||
|
||||
def _setup_pt_ckpt(self, save_dir):
|
||||
model_pt = AutoModel.from_pretrained(self.test_model)
|
||||
model_pt.save_pretrained(save_dir)
|
||||
|
||||
def _setup_tf_ckpt(self, save_dir):
|
||||
model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
|
||||
model_tf.save_pretrained(save_dir)
|
||||
|
||||
def test_framework_provided(self):
|
||||
"""
|
||||
Ensure the that the provided framework is returned.
|
||||
"""
|
||||
mock_framework = "mock_framework"
|
||||
|
||||
# Framework provided - return whatever the user provides
|
||||
result = FeaturesManager.determine_framework(self.test_model, mock_framework)
|
||||
self.assertEqual(result, mock_framework)
|
||||
|
||||
# Local checkpoint and framework provided - return provided framework
|
||||
# PyTorch checkpoint
|
||||
with TemporaryDirectory() as local_pt_ckpt:
|
||||
self._setup_pt_ckpt(local_pt_ckpt)
|
||||
result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
|
||||
self.assertEqual(result, mock_framework)
|
||||
|
||||
# TensorFlow checkpoint
|
||||
with TemporaryDirectory() as local_tf_ckpt:
|
||||
self._setup_tf_ckpt(local_tf_ckpt)
|
||||
result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
|
||||
self.assertEqual(result, mock_framework)
|
||||
|
||||
def test_checkpoint_provided(self):
|
||||
"""
|
||||
Ensure that the determined framework is the one used for the local checkpoint.
|
||||
|
||||
For the functionality to execute, local checkpoints are provided but framework is not.
|
||||
"""
|
||||
# PyTorch checkpoint
|
||||
with TemporaryDirectory() as local_pt_ckpt:
|
||||
self._setup_pt_ckpt(local_pt_ckpt)
|
||||
result = FeaturesManager.determine_framework(local_pt_ckpt)
|
||||
self.assertEqual(result, self.framework_pt)
|
||||
|
||||
# TensorFlow checkpoint
|
||||
with TemporaryDirectory() as local_tf_ckpt:
|
||||
self._setup_tf_ckpt(local_tf_ckpt)
|
||||
result = FeaturesManager.determine_framework(local_tf_ckpt)
|
||||
self.assertEqual(result, self.framework_tf)
|
||||
|
||||
# Invalid local checkpoint
|
||||
with TemporaryDirectory() as local_invalid_ckpt:
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
result = FeaturesManager.determine_framework(local_invalid_ckpt)
|
||||
|
||||
def test_from_environment(self):
|
||||
"""
|
||||
Ensure that the determined framework is the one available in the environment.
|
||||
|
||||
For the functionality to execute, framework and local checkpoints are not provided.
|
||||
"""
|
||||
# Framework not provided, hub model is used (no local checkpoint directory)
|
||||
# TensorFlow not in environment -> use PyTorch
|
||||
mock_tf_available = MagicMock(return_value=False)
|
||||
with patch("transformers.onnx.features.is_tf_available", mock_tf_available):
|
||||
result = FeaturesManager.determine_framework(self.test_model)
|
||||
self.assertEqual(result, self.framework_pt)
|
||||
|
||||
# PyTorch not in environment -> use TensorFlow
|
||||
mock_torch_available = MagicMock(return_value=False)
|
||||
with patch("transformers.onnx.features.is_torch_available", mock_torch_available):
|
||||
result = FeaturesManager.determine_framework(self.test_model)
|
||||
self.assertEqual(result, self.framework_tf)
|
||||
|
||||
# Both in environment -> use PyTorch
|
||||
mock_tf_available = MagicMock(return_value=True)
|
||||
mock_torch_available = MagicMock(return_value=True)
|
||||
with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
|
||||
"transformers.onnx.features.is_torch_available", mock_torch_available
|
||||
):
|
||||
result = FeaturesManager.determine_framework(self.test_model)
|
||||
self.assertEqual(result, self.framework_pt)
|
||||
|
||||
# Both not in environment -> raise error
|
||||
mock_tf_available = MagicMock(return_value=False)
|
||||
mock_torch_available = MagicMock(return_value=False)
|
||||
with patch("transformers.onnx.features.is_tf_available", mock_tf_available), patch(
|
||||
"transformers.onnx.features.is_torch_available", mock_torch_available
|
||||
):
|
||||
with self.assertRaises(EnvironmentError):
|
||||
result = FeaturesManager.determine_framework(self.test_model)
|
@ -434,7 +434,7 @@ def module_to_test_file(module_fname):
|
||||
return "tests/utils/test_cli.py"
|
||||
# Special case for onnx submodules
|
||||
elif len(splits) >= 2 and splits[-2] == "onnx":
|
||||
return ["tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
|
||||
return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]
|
||||
# Special case for utils (not the one in src/transformers, the ones at the root of the repo).
|
||||
elif len(splits) > 0 and splits[0] == "utils":
|
||||
default_test_file = f"tests/utils/test_utils_{module_name}"
|
||||
|
Loading…
Reference in New Issue
Block a user