mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-13 09:40:06 +06:00

* Automatic detection for framework to use when exporting to ONNX * Log message change * Incorporating PR comments, adding unit test * Adding tf for pip install for run_tests_onnxruntime CI * Restoring past changes to circleci yaml and test_onnx_v2.py, tests moved to tests/onnx/test_features.py * Fixup * Adding test to fetcher * Updating circleci config to log more * Changing test class name * Comment typo fix in tests/onnx/test_features.py Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Moving torch_str/tf_str to self.framework_pt/tf * Remove -rA flag in circleci config Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
112 lines
4.7 KiB
Python
112 lines
4.7 KiB
Python
from tempfile import TemporaryDirectory
|
|
from unittest import TestCase
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from transformers import AutoModel, TFAutoModel
|
|
from transformers.onnx import FeaturesManager
|
|
from transformers.testing_utils import SMALL_MODEL_IDENTIFIER, require_tf, require_torch
|
|
|
|
|
|
@require_torch
@require_tf
class DetermineFrameworkTest(TestCase):
    """
    Test `FeaturesManager.determine_framework`.

    Both PyTorch and TensorFlow are required because the helpers below build a local
    checkpoint for each framework from the same small hub model.
    """

    def setUp(self):
        # Small hub model, used directly and as the source for the local checkpoints.
        self.test_model = SMALL_MODEL_IDENTIFIER
        # Framework identifiers the tests expect `determine_framework` to return.
        self.framework_pt = "pt"
        self.framework_tf = "tf"

    def _setup_pt_ckpt(self, save_dir):
        """Save a PyTorch checkpoint of `self.test_model` into `save_dir`."""
        model_pt = AutoModel.from_pretrained(self.test_model)
        model_pt.save_pretrained(save_dir)

    def _setup_tf_ckpt(self, save_dir):
        """
        Save a TensorFlow checkpoint of `self.test_model` into `save_dir`.

        The model is loaded with `from_pt=True`, i.e. from the PyTorch weights.
        """
        model_tf = TFAutoModel.from_pretrained(self.test_model, from_pt=True)
        model_tf.save_pretrained(save_dir)

    def test_framework_provided(self):
        """
        Ensure that the provided framework is returned.
        """
        mock_framework = "mock_framework"

        # Framework provided - return whatever the user provides
        result = FeaturesManager.determine_framework(self.test_model, mock_framework)
        self.assertEqual(result, mock_framework)

        # Local checkpoint and framework provided - return provided framework
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt, mock_framework)
            self.assertEqual(result, mock_framework)

    def test_checkpoint_provided(self):
        """
        Ensure that the determined framework is the one used for the local checkpoint.

        For the functionality to execute, local checkpoints are provided but framework is not.
        """
        # PyTorch checkpoint
        with TemporaryDirectory() as local_pt_ckpt:
            self._setup_pt_ckpt(local_pt_ckpt)
            result = FeaturesManager.determine_framework(local_pt_ckpt)
            self.assertEqual(result, self.framework_pt)

        # TensorFlow checkpoint
        with TemporaryDirectory() as local_tf_ckpt:
            self._setup_tf_ckpt(local_tf_ckpt)
            result = FeaturesManager.determine_framework(local_tf_ckpt)
            self.assertEqual(result, self.framework_tf)

        # Invalid local checkpoint: an empty directory holds no weights for either framework
        with TemporaryDirectory() as local_invalid_ckpt:
            with self.assertRaises(FileNotFoundError):
                FeaturesManager.determine_framework(local_invalid_ckpt)

    def test_from_environment(self):
        """
        Ensure that the determined framework is the one available in the environment.

        For the functionality to execute, framework and local checkpoints are not provided.
        Both availability checks are patched in every scenario so that no scenario depends on
        what is actually installed.
        """
        torch_available_target = "transformers.onnx.features.is_torch_available"
        tf_available_target = "transformers.onnx.features.is_tf_available"

        # Framework not provided, hub model is used (no local checkpoint directory)
        # TensorFlow not in environment -> use PyTorch
        with patch(torch_available_target, MagicMock(return_value=True)), patch(
            tf_available_target, MagicMock(return_value=False)
        ):
            result = FeaturesManager.determine_framework(self.test_model)
        self.assertEqual(result, self.framework_pt)

        # PyTorch not in environment -> use TensorFlow
        with patch(torch_available_target, MagicMock(return_value=False)), patch(
            tf_available_target, MagicMock(return_value=True)
        ):
            result = FeaturesManager.determine_framework(self.test_model)
        self.assertEqual(result, self.framework_tf)

        # Both in environment -> use PyTorch
        with patch(torch_available_target, MagicMock(return_value=True)), patch(
            tf_available_target, MagicMock(return_value=True)
        ):
            result = FeaturesManager.determine_framework(self.test_model)
        self.assertEqual(result, self.framework_pt)

        # Both not in environment -> raise error
        with patch(torch_available_target, MagicMock(return_value=False)), patch(
            tf_available_target, MagicMock(return_value=False)
        ):
            with self.assertRaises(EnvironmentError):
                FeaturesManager.determine_framework(self.test_model)