mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Enhance IPEX integration in Trainer (#18072)
* enhance ipex import * refine codes * refine style * add link * style Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
parent
a462fc9232
commit
b7d8bd378c
@ -292,10 +292,15 @@ def require_intel_extension_for_pytorch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Intel Extension for PyTorch.
|
||||
|
||||
These tests are skipped when Intel Extension for PyTorch isn't installed.
|
||||
These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
|
||||
version.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case)
|
||||
return unittest.skipUnless(
|
||||
is_ipex_available(),
|
||||
"test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
|
||||
" https://github.com/intel/intel-extension-for-pytorch",
|
||||
)(test_case)
|
||||
|
||||
|
||||
def require_torch_scatter(test_case):
|
||||
|
@ -1211,8 +1211,8 @@ class Trainer:
|
||||
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
|
||||
if not is_ipex_available():
|
||||
raise ImportError(
|
||||
"Using IPEX but IPEX is not installed, please refer to"
|
||||
" https://github.com/intel/intel-extension-for-pytorch."
|
||||
"Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
|
||||
" to https://github.com/intel/intel-extension-for-pytorch."
|
||||
)
|
||||
|
||||
import intel_extension_for_pytorch as ipex
|
||||
@ -1223,7 +1223,9 @@ class Trainer:
|
||||
else:
|
||||
if not model.training:
|
||||
model.train()
|
||||
model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")
|
||||
model, self.optimizer = ipex.optimize(
|
||||
model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -443,7 +443,25 @@ def is_apex_available():
|
||||
|
||||
|
||||
def is_ipex_available():
|
||||
return importlib.util.find_spec("intel_extension_for_pytorch") is not None
|
||||
def get_major_and_minor_from_version(full_version):
|
||||
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
|
||||
|
||||
if not is_torch_available() or importlib.util.find_spec("intel_extension_for_pytorch") is None:
|
||||
return False
|
||||
_ipex_version = "N/A"
|
||||
try:
|
||||
_ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
return False
|
||||
torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
|
||||
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
|
||||
if torch_major_and_minor != ipex_major_and_minor:
|
||||
logger.warning(
|
||||
f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
|
||||
f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_bitsandbytes_available():
|
||||
|
@ -642,7 +642,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, 10)
|
||||
|
||||
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
|
||||
@require_torch_bf16_cpu
|
||||
@require_intel_extension_for_pytorch
|
||||
def test_number_of_steps_in_training_with_ipex(self):
|
||||
@ -887,7 +886,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
|
||||
@require_torch_bf16_cpu
|
||||
@require_intel_extension_for_pytorch
|
||||
def test_evaluate_with_ipex(self):
|
||||
@ -1008,7 +1006,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||
|
||||
@unittest.skip(reason="skip temporarily until intel_extension_for_pytorch works with torch 1.12")
|
||||
@require_torch_bf16_cpu
|
||||
@require_intel_extension_for_pytorch
|
||||
def test_predict_with_ipex(self):
|
||||
|
Loading…
Reference in New Issue
Block a user