Enhance IPEX integration in Trainer (#18072)

* enhance ipex import

* refine codes

* refine style

* add link

* style

Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
jianan-gu 2022-07-12 12:34:09 +08:00 committed by GitHub
parent a462fc9232
commit b7d8bd378c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 9 deletions

View File

@ -292,10 +292,15 @@ def require_intel_extension_for_pytorch(test_case):
"""
Decorator marking a test that requires Intel Extension for PyTorch.
These tests are skipped when Intel Extension for PyTorch isn't installed.
These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
version.
"""
return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case)
return unittest.skipUnless(
is_ipex_available(),
"test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
" https://github.com/intel/intel-extension-for-pytorch",
)(test_case)
def require_torch_scatter(test_case):

View File

@ -1211,8 +1211,8 @@ class Trainer:
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
if not is_ipex_available():
raise ImportError(
"Using IPEX but IPEX is not installed, please refer to"
" https://github.com/intel/intel-extension-for-pytorch."
"Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
" to https://github.com/intel/intel-extension-for-pytorch."
)
import intel_extension_for_pytorch as ipex
@ -1223,7 +1223,9 @@ class Trainer:
else:
if not model.training:
model.train()
model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")
model, self.optimizer = ipex.optimize(
model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
)
return model

View File

@ -443,7 +443,25 @@ def is_apex_available():
def is_ipex_available():
return importlib.util.find_spec("intel_extension_for_pytorch") is not None
def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None:
return False
_ipex_version = "N/A"
try:
_ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
except importlib_metadata.PackageNotFoundError:
return False
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
if torch_major_and_minor != ipex_major_and_minor:
logger.warning(
f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
)
return False
return True
def is_bitsandbytes_available():

View File

@ -642,7 +642,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_number_of_steps_in_training_with_ipex(self):
@ -887,7 +886,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_evaluate_with_ipex(self):
@ -1008,7 +1006,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
@require_torch_bf16_cpu
@require_intel_extension_for_pytorch
def test_predict_with_ipex(self):