remove ipex_optimize_model usage (#38632)

* remove ipex_optimize_model usage

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* update Dockerfile

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
Co-authored-by: root <root@a4bf01945cfe.jf.intel.com>
Yao Matrix 2025-06-07 02:04:44 +08:00 committed by GitHub
parent 5009252a05
commit dc76eff12b
7 changed files with 10 additions and 231 deletions

View File

@@ -10,8 +10,6 @@ SHELL ["sh", "-lc"]
# to be used as arguments for docker build (so far).
ARG PYTORCH='2.6.0'
# (not always a valid torch version)
ARG INTEL_TORCH_EXT='2.3.0'
# Example: `cu102`, `cu113`, etc.
ARG CUDA='cu121'
# Disable kernel mapping for now until all tests pass
@@ -32,8 +30,6 @@ RUN python3 -m pip install --no-cache-dir -e ./transformers[dev,onnxruntime] &&
RUN python3 -m pip uninstall -y flax jax
RUN python3 -m pip install --no-cache-dir intel_extension_for_pytorch==$INTEL_TORCH_EXT -f https://developer.intel.com/ipex-whl-stable-cpu
RUN python3 -m pip install --no-cache-dir git+https://github.com/facebookresearch/detectron2.git pytesseract
RUN python3 -m pip install -U "itsdangerous<2.1.0"

View File

@@ -78,26 +78,3 @@ python examples/pytorch/question-answering/run_qa.py \
--no_cuda \
--jit_mode_eval
```
## IPEX
[Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/getting_started.html) (IPEX) offers additional optimizations for PyTorch on Intel CPUs. IPEX further optimizes TorchScript with [graph optimization](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/features/graph_optimization.html) which fuses operations like Multi-head attention, Concat Linear, Linear + Add, Linear + Gelu, Add + LayerNorm, and more, into single kernels for faster execution.
Make sure IPEX is installed, and set the `--use_ipex` and `--jit_mode_eval` flags in [`Trainer`] to enable IPEX graph optimization and TorchScript.
```bash
!pip install intel_extension_for_pytorch
```
```bash
python examples/pytorch/question-answering/run_qa.py \
--model_name_or_path csarron/bert-base-uncased-squad-v1 \
--dataset_name squad \
--do_eval \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/ \
--no_cuda \
--use_ipex \
--jit_mode_eval
```
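
For readers who still want the removed IPEX inference path outside of `Trainer`, here is a rough sketch of what `--use_ipex` combined with `--jit_mode_eval` did under the hood: `ipex.optimize` followed by TorchScript tracing. It assumes an `intel_extension_for_pytorch` build matching your torch version is installed; the checkpoint name is reused from the example above.

```python
import torch
import intel_extension_for_pytorch as ipex  # must match your installed torch version
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

name = "csarron/bert-base-uncased-squad-v1"
tokenizer = AutoTokenizer.from_pretrained(name)
# torchscript=True makes the model return tuples so it can be traced
model = AutoModelForQuestionAnswering.from_pretrained(name, torchscript=True).eval()

# graph-level fusions described in the removed section above
model = ipex.optimize(model, dtype=torch.bfloat16, level="O1", conv_bn_folding=False)

inputs = tokenizer("What does IPEX fuse?", "IPEX fuses ops such as Linear + GELU.", return_tensors="pt")
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # tracing under autocast is roughly what --jit_mode_eval enabled inside Trainer
    traced = torch.jit.trace(model, (inputs["input_ids"], inputs["attention_mask"]), strict=False)
    traced = torch.jit.freeze(traced)
    start_logits, end_logits = traced(inputs["input_ids"], inputs["attention_mask"])
```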

View File

@@ -17,30 +17,9 @@ rendered properly in your Markdown viewer.
A modern CPU is capable of efficiently training large models by leveraging the underlying optimizations built into the hardware and training on fp16 or bf16 data types.
This guide focuses on how to train large models on an Intel CPU using mixed precision and the [Intel Extension for PyTorch (IPEX)](https://intel.github.io/intel-extension-for-pytorch/index.html) library.
This guide focuses on how to train large models on an Intel CPU using mixed precision. AMP is enabled for the CPU backend when training with PyTorch.
You can find your PyTorch version by running the command below.
```bash
pip list | grep torch
```
Install IPEX with the PyTorch version from above.
```bash
pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
```
> [!TIP]
> Refer to the IPEX [installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation) guide for more details.
IPEX provides additional performance optimizations for Intel CPUs. These include additional CPU instruction level architecture (ISA) support such as [Intel AVX512-VNNI](https://en.wikichip.org/wiki/x86/avx512_vnni) and [Intel AMX](https://www.intel.com/content/www/us/en/products/docs/accelerator-engines/what-is-intel-amx.html). Both of these features are designed to accelerate matrix multiplication. Older AMD and Intel CPUs with only Intel AVX2, however, aren't guaranteed better performance with IPEX.
IPEX also supports [Auto Mixed Precision (AMP)](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/features/amp.html) training with the fp16 and bf16 data types. Reducing precision speeds up training and reduces memory usage because it requires less computation. The loss in accuracy compared to full precision is minimal. 3rd, 4th, and 5th generation Intel Xeon Scalable processors natively support bf16, and the 6th generation processor also natively supports fp16 in addition to bf16.
AMP is enabled for the CPU backend when training with PyTorch.
[`Trainer`] supports AMP training with a CPU by adding the `--use_cpu`, `--use_ipex`, and `--bf16` parameters. The example below demonstrates the [run_qa.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) script.
[`Trainer`] supports AMP training with a CPU by adding the `--use_cpu` and `--bf16` parameters. The example below demonstrates the [run_qa.py](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) script.
```bash
python run_qa.py \
@@ -54,7 +33,6 @@ python run_qa.py \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_squad/ \
--use_ipex \
--bf16 \
--use_cpu
```
@@ -65,7 +43,6 @@ These parameters can also be added to [`TrainingArguments`] as shown below.
training_args = TrainingArguments(
output_dir="./outputs",
bf16=True,
use_ipex=True,
use_cpu=True,
)
```
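
The replacement path the updated guide relies on is PyTorch's native CPU autocast. A minimal, self-contained sketch (toy model, not taken from the docs) of bf16 AMP training on the CPU backend:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 512), torch.randint(0, 2, (8,))

# bf16 autocast on the CPU backend; no GradScaler is needed for bf16
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = nn.functional.cross_entropy(model(x), y)

loss.backward()
optimizer.step()
optimizer.zero_grad()
```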

View File

@@ -75,8 +75,7 @@ python3 run_qa.py \
--doc_stride 128 \
--output_dir /tmp/debug_squad/ \
--no_cuda \
--ddp_backend ccl \
--use_ipex
--ddp_backend ccl
```
</hfoption>
@@ -115,7 +114,6 @@ python3 run_qa.py \
--output_dir /tmp/debug_squad/ \
--no_cuda \
--ddp_backend ccl \
--use_ipex \
--bf16
```
@@ -201,8 +199,7 @@ spec:
--output_dir /tmp/pvc-mount/output_$(date +%Y%m%d_%H%M%S) \
--no_cuda \
--ddp_backend ccl \
--bf16 \
--use_ipex;
--bf16;
env:
- name: LD_PRELOAD
value: "/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4.5.9:/usr/local/lib/libiomp5.so"
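
For context on the `--ddp_backend ccl` commands kept above, a minimal sketch of how the oneCCL backend is registered and initialized directly. It assumes the `oneccl_bindings_for_pytorch` package is installed and that a launcher sets the usual `RANK`/`WORLD_SIZE` variables:

```python
import os

import torch.distributed as dist
import oneccl_bindings_for_pytorch  # noqa: F401  importing registers the "ccl" backend

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group(
    backend="ccl",
    rank=int(os.environ.get("RANK", "0")),
    world_size=int(os.environ.get("WORLD_SIZE", "1")),
)
```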

View File

@@ -157,7 +157,6 @@ from .utils import (
is_galore_torch_available,
is_grokadamw_available,
is_in_notebook,
is_ipex_available,
is_liger_kernel_available,
is_lomo_available,
is_peft_available,
@@ -1916,29 +1915,6 @@ class Trainer:
return model
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
if not is_ipex_available():
raise ImportError(
"Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
" to https://github.com/intel/intel-extension-for-pytorch."
)
import intel_extension_for_pytorch as ipex
if not training:
model.eval()
dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
# conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
else:
if not model.training:
model.train()
model, self.optimizer = ipex.optimize(
model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
)
return model
def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
attributes_map = {
"logging_steps": "logging_steps",
@@ -1968,10 +1944,6 @@ class Trainer:
logger.warning_once(warning_str)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.use_ipex:
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
model = self.ipex_optimize_model(model, training, dtype=dtype)
if is_sagemaker_mp_enabled():
# Wrapping the base model twice in a DistributedModel will raise an error.
if isinstance(self.model_wrapped, smp.model.DistributedModel):
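
The removed `ipex_optimize_model` helper was a thin wrapper around `ipex.optimize`; users who still want that behavior can apply it to the model themselves before training. A rough sketch of what the removed training branch did, using a toy model and again assuming an IPEX build matching your torch version:

```python
import torch
import intel_extension_for_pytorch as ipex  # must match your installed torch version

model = torch.nn.Linear(128, 128)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# roughly the removed training branch: optimize model and optimizer together, in place
model.train()
model, optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer, level="O1", inplace=True)

x = torch.randn(4, 128)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
```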

View File

@@ -1581,6 +1581,12 @@ class TrainingArguments:
FutureWarning,
)
self.use_cpu = self.no_cuda
if self.use_ipex:
warnings.warn(
"using `use_ipex` is deprecated and will be removed in version 4.54 of 🤗 Transformers. "
"You only need PyTorch for the needed optimizations on Intel CPU and XPU.",
FutureWarning,
)
self.eval_strategy = IntervalStrategy(self.eval_strategy)
self.logging_strategy = IntervalStrategy(self.logging_strategy)
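
A quick, hypothetical way to confirm the new deprecation path: constructing `TrainingArguments` with `use_ipex=True` still works but now emits the `FutureWarning` added above (the output path is a placeholder):

```python
import warnings

from transformers import TrainingArguments

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    args = TrainingArguments(output_dir="/tmp/out", use_cpu=True, use_ipex=True)

# expect the FutureWarning about `use_ipex` being removed in v4.54
print([str(w.message) for w in caught if issubclass(w.category, FutureWarning)])
```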

View File

@@ -79,7 +79,6 @@ from transformers.testing_utils import (
require_deepspeed,
require_galore_torch,
require_grokadamw,
require_intel_extension_for_pytorch,
require_liger_kernel,
require_lomo,
require_non_hpu,
@@ -1325,37 +1324,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_number_of_steps_in_training_with_ipex(self):
for mix_bf16 in [True, False]:
tmp_dir = self.get_auto_remove_tmp_dir()
# Regular training has n_epochs * len(train_dl) steps
trainer = get_regression_trainer(
learning_rate=0.1, use_ipex=True, bf16=mix_bf16, use_cpu=True, output_dir=tmp_dir
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
# Check passing num_train_epochs works (and a float version too):
trainer = get_regression_trainer(
learning_rate=0.1,
num_train_epochs=1.5,
use_ipex=True,
bf16=mix_bf16,
use_cpu=True,
output_dir=tmp_dir,
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
# If we pass a max_steps, num_train_epochs is ignored
trainer = get_regression_trainer(
learning_rate=0.1, max_steps=10, use_ipex=True, bf16=mix_bf16, use_cpu=True, output_dir=tmp_dir
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
def test_torch_compile_loss_func_compatibility(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2628,69 +2596,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_evaluate_with_ipex(self):
for mix_bf16 in [True, False]:
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
a=1.5,
b=2.5,
use_ipex=True,
compute_metrics=AlmostAccuracy(),
bf16=mix_bf16,
use_cpu=True,
output_dir=tmp_dir,
)
results = trainer.evaluate()
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
# With a number of elements not a round multiple of the batch size
trainer = get_regression_trainer(
a=1.5,
b=2.5,
use_ipex=True,
eval_len=66,
compute_metrics=AlmostAccuracy(),
bf16=mix_bf16,
use_cpu=True,
output_dir=tmp_dir,
)
results = trainer.evaluate()
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
# With logits preprocess
trainer = get_regression_trainer(
a=1.5,
b=2.5,
use_ipex=True,
compute_metrics=AlmostAccuracy(),
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
bf16=mix_bf16,
use_cpu=True,
output_dir=tmp_dir,
)
results = trainer.evaluate()
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
def test_predict(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(a=1.5, b=2.5, output_dir=tmp_dir)
@@ -2830,57 +2735,6 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_predict_with_ipex(self):
for mix_bf16 in [True, False]:
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
a=1.5, b=2.5, use_ipex=True, bf16=mix_bf16, use_cpu=True, output_dir=tmp_dir
)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With a number of elements not a round multiple of the batch size
trainer = get_regression_trainer(
a=1.5, b=2.5, eval_len=66, use_ipex=True, bf16=mix_bf16, use_cpu=True, output_dir=tmp_dir
)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With more than one output of the model
trainer = get_regression_trainer(
a=1.5, b=2.5, double_output=True, use_ipex=True, bf16=mix_bf16, use_cpu=True, output_dir=tmp_dir
)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertEqual(len(preds), 2)
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
# With more than one output/label of the model
trainer = get_regression_trainer(
a=1.5,
b=2.5,
double_output=True,
label_names=["labels", "labels_2"],
use_ipex=True,
bf16=mix_bf16,
use_cpu=True,
output_dir=tmp_dir,
)
outputs = trainer.predict(trainer.eval_dataset)
preds = outputs.predictions
labels = outputs.label_ids
x = trainer.eval_dataset.x
self.assertEqual(len(preds), 2)
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
def test_dynamic_shapes(self):
eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
model = RegressionModel(a=2, b=1)