Extend Transformers Trainer Class to Enable CPU AMP and Integrate Intel Extension for PyTorch (#17138)

* init PR

* fix import ipex

* minor fix on bf16

* refine optimizer

* refine args notes

* refine code

* refine ipex optimize args

* refine half_precision_backend

* black format

* isort format

* isort format files

* flake8 format

* doc builder format

* refine codes

* remove jit and optim bits

* black preview format

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refine code

* refine notes

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* code refine

* add ipex ut

* add performance cpu doc

* link to the cpu doc from main perf doc

* install ipex into CI's docker

* Update perf_train_cpu.mdx

* Update docs/source/en/perf_train_cpu.mdx

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update perf_train_cpu.mdx

* Update perf_train_cpu.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
jianan-gu 2022-06-08 21:41:57 +08:00 committed by GitHub
parent ae7bae8fe7
commit 34097b3304
10 changed files with 336 additions and 42 deletions

View File

@ -85,6 +85,8 @@
title: Training on one GPU
- local: perf_train_gpu_many
title: Training on many GPUs
- local: perf_train_cpu
title: Training on CPU
- local: perf_hardware
title: Custom hardware for training
- local: testing

View File

@ -0,0 +1,61 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
-->
# Efficient Training on CPU
This guide focuses on training large models efficiently on CPU.
## Mixed precision with IPEX
IPEX is optimized for CPUs with AVX-512 or above, and it still works functionally on CPUs with only AVX2. It is therefore expected to bring a performance benefit on Intel CPU generations with AVX-512 or above, while CPUs with only AVX2 (e.g., AMD CPUs or older Intel CPUs) may or may not see better performance under IPEX. IPEX provides performance optimizations for CPU training with both Float32 and BFloat16. The usage of BFloat16 is the main focus of the following sections.
The low-precision data type BFloat16 is natively supported on 3rd Generation Intel® Xeon® Scalable processors (aka Cooper Lake) with the AVX-512 instruction set, and will be supported on the next generation of Intel® Xeon® Scalable processors with the Intel® Advanced Matrix Extensions (Intel® AMX) instruction set, bringing further performance gains. Auto Mixed Precision for the CPU backend has been enabled since PyTorch 1.10. In addition, Auto Mixed Precision with BFloat16 for CPU and BFloat16 optimization of operators have been broadly enabled in Intel® Extension for PyTorch and partially upstreamed to the PyTorch master branch. Users can get better performance and user experience with IPEX Auto Mixed Precision.
See [Auto Mixed Precision](https://intel.github.io/intel-extension-for-pytorch/1.11.200/tutorials/features/amp.html) for more detailed information.
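At the PyTorch level, IPEX Auto Mixed Precision comes down to `ipex.optimize` plus `torch.cpu.amp.autocast`. Below is a minimal sketch of that pattern; the toy linear model, the SGD optimizer, and the `level="O1"` argument are illustrative choices, not requirements. See the next sections for installation and for the `Trainer` integration.
```
import torch
import intel_extension_for_pytorch as ipex

# Toy model and data, only to illustrate the IPEX + CPU AMP pattern
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(8, 16)
targets = torch.randn(8, 1)

# Let IPEX rewrite the model and optimizer for BFloat16 CPU training
model.train()
model, optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer, level="O1")

optimizer.zero_grad()
# Run the forward pass under CPU Auto Mixed Precision (available since PyTorch 1.10)
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    outputs = model(inputs)
# Compute the loss in float32 to avoid dtype mismatches in this toy example
loss = torch.nn.functional.mse_loss(outputs.float(), targets)
loss.backward()
optimizer.step()
```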
### IPEX installation:
IPEX releases follow PyTorch releases. To install via pip:
For PyTorch-1.10:
```
pip install intel_extension_for_pytorch==1.10.100+cpu -f https://software.intel.com/ipex-whl-stable
```
For PyTorch-1.11:
```
pip install intel_extension_for_pytorch==1.11.200+cpu -f https://software.intel.com/ipex-whl-stable
```
See [IPEX installation](https://intel.github.io/intel-extension-for-pytorch/1.11.200/tutorials/installation.html) for more installation approaches.
### Usage in Trainer
To enable auto mixed precision with IPEX in Trainer, users should add `use_ipex`, `bf16` and `no_cuda` to the training command arguments.
Take the [Transformers question-answering](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) example as a use case:
- Training with IPEX using BF16 auto mixed precision on CPU:
<pre> python run_qa.py \
--model_name_or_path bert-base-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/debug_squad/ \
<b>--use_ipex \</b>
<b>--bf16 --no_cuda</b></pre>
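Equivalently, when building the training arguments in Python rather than on the command line, the same three flags can be set directly. A minimal sketch follows; the output directory simply mirrors the command above, and the resulting arguments are passed into your usual `Trainer` setup:
```
from transformers import TrainingArguments

# The flags that enable IPEX BF16 auto mixed precision training on CPU
training_args = TrainingArguments(
    output_dir="/tmp/debug_squad/",
    use_ipex=True,  # apply IPEX optimizations to the model (and optimizer during training)
    bf16=True,      # use BFloat16 auto mixed precision
    no_cuda=True,   # force CPU training
)
```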

View File

@ -38,6 +38,12 @@ In some cases training on a single GPU is still too slow or won't fit the large
[Go to multi-GPU training section](perf_train_gpu_many)
### CPU
[Go to CPU training section](perf_train_cpu)
### TPU
_Coming soon_

View File

@ -92,6 +92,7 @@ from .utils import (
is_flax_available,
is_ftfy_available,
is_in_notebook,
is_ipex_available,
is_librosa_available,
is_local_clone,
is_offline_mode,

View File

@ -47,6 +47,7 @@ from .utils import (
is_faiss_available,
is_flax_available,
is_ftfy_available,
is_ipex_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,
@ -282,6 +283,16 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
def require_intel_extension_for_pytorch(test_case):
"""
Decorator marking a test that requires Intel Extension for PyTorch.
These tests are skipped when Intel Extension for PyTorch isn't installed.
"""
return unittest.skipUnless(is_ipex_available(), "test requires Intel Extension for PyTorch")(test_case)
def require_torch_scatter(test_case):
"""
Decorator marking a test that requires PyTorch scatter.
@ -476,9 +487,10 @@ def require_torch_gpu(test_case):
def require_torch_bf16(test_case):
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU."""
return unittest.skipUnless(
is_torch_bf16_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10"
is_torch_bf16_available(),
"test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0 or using CPU",
)(test_case)

View File

@ -136,6 +136,7 @@ from .utils import (
is_apex_available,
is_datasets_available,
is_in_notebook,
is_ipex_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_tpu_available,
@ -146,7 +147,8 @@ from .utils.generic import ContextManagers
_is_torch_generator_available = False
_is_native_amp_available = False
_is_native_cuda_amp_available = False
_is_native_cpu_amp_available = False
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@ -161,8 +163,10 @@ if is_apex_available():
if version.parse(torch.__version__) >= version.parse("1.6"):
_is_torch_generator_available = True
_is_native_amp_available = True
from torch.cuda.amp import autocast
_is_native_cuda_amp_available = True
if version.parse(torch.__version__) >= version.parse("1.10"):
_is_native_cpu_amp_available = True
if is_datasets_available():
import datasets
@ -487,7 +491,8 @@ class Trainer:
# Mixed precision setup
self.use_apex = False
self.use_amp = False
self.use_cuda_amp = False
self.use_cpu_amp = False
if args.fp16 or args.bf16:
if self.fsdp is not None:
@ -496,19 +501,27 @@ class Trainer:
"Please do not set arguments related to `mixed_precision`"
)
if args.half_precision_backend == "auto":
if _is_native_amp_available:
args.half_precision_backend = "amp"
if args.device == torch.device("cpu"):
if args.fp16:
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
elif _is_native_cpu_amp_available:
args.half_precision_backend = "cpu_amp"
else:
raise ValueError("Tried to use cpu amp but native cpu amp is not available")
else:
if args.bf16:
if _is_native_cuda_amp_available:
args.half_precision_backend = "cuda_amp"
elif args.bf16:
raise ValueError("Tried to use `bf16` but native amp is not available")
else:
args.half_precision_backend = "apex"
logger.info(f"Using {args.half_precision_backend} half precision backend")
self.do_grad_scaling = False
if (args.fp16 or args.bf16) and not args.deepspeed: # deepspeed manages its own half precision
if args.half_precision_backend == "amp":
self.use_amp = True
if args.half_precision_backend == "cuda_amp":
self.use_cuda_amp = True
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
self.do_grad_scaling = True
if is_sagemaker_mp_enabled():
@ -521,6 +534,9 @@ class Trainer:
self.scaler = GradScaler()
else:
self.scaler = torch.cuda.amp.GradScaler()
elif args.half_precision_backend == "cpu_amp":
self.use_cpu_amp = True
self.amp_dtype = torch.bfloat16
else:
if not is_apex_available():
raise ImportError(
@ -1142,7 +1158,30 @@ class Trainer:
return model
def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
if not is_ipex_available():
raise ImportError(
"Using IPEX but IPEX is not installed, please refer to"
" https://github.com/intel/intel-extension-for-pytorch."
)
import intel_extension_for_pytorch as ipex
if not training:
model.eval()
model = ipex.optimize(model, dtype=dtype, level="O1")
else:
if not model.training:
model.train()
model, self.optimizer = ipex.optimize(model, dtype=dtype, optimizer=self.optimizer, level="O1")
return model
def _wrap_model(self, model, training=True):
if self.args.use_ipex:
dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
model = self.ipex_optimize_model(model, training, dtype=dtype)
if is_sagemaker_mp_enabled():
# Wrapping the base model twice in a DistributedModel will raise an error.
if isinstance(self.model_wrapped, smp.model.DistributedModel):
@ -2212,11 +2251,15 @@ class Trainer:
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
"""
if self.use_amp:
if self.use_cuda_amp or self.use_cpu_amp:
if version.parse(torch.__version__) >= version.parse("1.10"):
ctx_manager = autocast(dtype=self.amp_dtype)
ctx_manager = (
torch.cpu.amp.autocast(dtype=self.amp_dtype)
if self.use_cpu_amp
else torch.cuda.amp.autocast(dtype=self.amp_dtype)
)
else:
ctx_manager = autocast()
ctx_manager = torch.cuda.amp.autocast()
else:
ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

View File

@ -236,9 +236,12 @@ class TrainingArguments:
Random seed to be used with data samplers. If not set, random generators for data sampling will use the
same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
seed.
use_ipex (`bool`, *optional*, defaults to `False`):
Use Intel extension for PyTorch when it is available. [IPEX
installation](https://github.com/intel/intel-extension-for-pytorch).
bf16 (`bool`, *optional*, defaults to `False`):
Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
NVIDIA architecture. This is an experimental API and it may change.
NVIDIA architecture or using CPU (no_cuda). This is an experimental API and it may change.
fp16 (`bool`, *optional*, defaults to `False`):
Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
fp16_opt_level (`str`, *optional*, defaults to 'O1'):
@ -247,9 +250,9 @@ class TrainingArguments:
fp16_backend (`str`, *optional*, defaults to `"auto"`):
This argument is deprecated. Use `half_precision_backend` instead.
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
The backend to use for mixed precision training. Must be one of `"auto"`, `"amp"` or `"apex"`. `"auto"`
will use AMP or APEX depending on the PyTorch version detected, while the other choices will force the
requested backend.
The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
`"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
will force the requested backend.
bf16_full_eval (`bool`, *optional*, defaults to `False`):
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
metric values. This is an experimental API and it may change.
@ -607,12 +610,21 @@ class TrainingArguments:
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
use_ipex: bool = field(
default=False,
metadata={
"help": (
"Use Intel extension for PyTorch when it is available, installation:"
" 'https://github.com/intel/intel-extension-for-pytorch'"
)
},
)
bf16: bool = field(
default=False,
metadata={
"help": (
"Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
" architecture. This is an experimental API and it may change."
" architecture or using CPU (no_cuda). This is an experimental API and it may change."
)
},
)
@ -631,7 +643,10 @@ class TrainingArguments:
)
half_precision_backend: str = field(
default="auto",
metadata={"help": "The backend to be used for half precision.", "choices": ["auto", "amp", "apex"]},
metadata={
"help": "The backend to be used for half precision.",
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
},
)
bf16_full_eval: bool = field(
default=False,
@ -849,7 +864,10 @@ class TrainingArguments:
# Deprecated arguments
fp16_backend: str = field(
default="auto",
metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
metadata={
"help": "Deprecated. Use half_precision_backend instead",
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
},
)
push_to_hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
@ -984,16 +1002,19 @@ class TrainingArguments:
)
self.half_precision_backend = self.fp16_backend
if (self.bf16 or self.bf16_full_eval) and not is_torch_bf16_available():
raise ValueError("Your setup doesn't support bf16. You need Ampere GPU, torch>=1.10, cuda>=11.0")
if (self.bf16 or self.bf16_full_eval) and not is_torch_bf16_available() and not self.no_cuda:
raise ValueError(
"Your setup doesn't support bf16. You need torch>=1.10, using Ampere GPU with cuda>=11.0 or using CPU"
" (no_cuda)"
)
if self.fp16 and self.bf16:
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
if self.bf16:
if self.half_precision_backend == "apex":
raise ValueError(
" `--half_precision_backend apex`: bf16 is not supported by apex. Use `--half_precision_backend"
" amp` instead"
" `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
" `--half_precision_backend cuda_amp` instead"
)
if not (self.sharded_ddp == "" or not self.sharded_ddp):
raise ValueError("sharded_ddp is not supported with bf16")
@ -1011,11 +1032,23 @@ class TrainingArguments:
is_torch_available()
and (self.device.type != "cuda")
and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval)
and (self.fp16 or self.fp16_full_eval)
):
raise ValueError(
"Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation"
" (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
"FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
" (`--fp16_full_eval`) can only be used on CUDA devices."
)
if (
is_torch_available()
and (self.device.type != "cuda")
and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
and (self.device.type != "cpu")
and (self.bf16 or self.bf16_full_eval)
):
raise ValueError(
"BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
" (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
)
if is_torch_available() and self.tf32 is not None:

View File

@ -97,6 +97,7 @@ from .import_utils import (
is_flax_available,
is_ftfy_available,
is_in_notebook,
is_ipex_available,
is_librosa_available,
is_onnx_available,
is_pandas_available,

View File

@ -282,25 +282,33 @@ def is_torch_bf16_available():
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
# with additional check for torch version
# to succeed:
# 1. the hardware needs to support bf16 (arch >= Ampere)
# 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
# 3. CUDA >= 11
# 1. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
# 2. the hardware needs to support bf16 (GPU arch >= Ampere, or CPU)
# 3. if using gpu, CUDA >= 11
# 4. torch.autocast exists
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
# really only correct for the 0th gpu (or currently set default device if different from 0)
if not torch.cuda.is_available() or torch.version.cuda is None:
return False
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
return False
if int(torch.version.cuda.split(".")[0]) < 11:
return False
is_torch_gpu_bf16_available = True
is_torch_cpu_bf16_available = True
if version.parse(torch.__version__) < version.parse("1.10"):
return False
if not hasattr(torch, "autocast"):
return False
is_torch_gpu_bf16_available = False
is_torch_cpu_bf16_available = False
return True
if torch.cuda.is_available() and torch.version.cuda is not None:
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
is_torch_gpu_bf16_available = False
if int(torch.version.cuda.split(".")[0]) < 11:
is_torch_gpu_bf16_available = False
if not hasattr(torch.cuda.amp, "autocast"):
is_torch_gpu_bf16_available = False
else:
is_torch_gpu_bf16_available = False
# checking CPU
if not hasattr(torch.cpu.amp, "autocast"):
is_torch_cpu_bf16_available = False
return is_torch_cpu_bf16_available or is_torch_gpu_bf16_available
def is_torch_tf32_available():
@ -404,6 +412,10 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None
def is_ipex_available():
return importlib.util.find_spec("intel_extension_for_pytorch") is not None
def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None

View File

@ -50,6 +50,7 @@ from transformers.testing_utils import (
get_gpu_count,
get_tests_dir,
is_staging_test,
require_intel_extension_for_pytorch,
require_optuna,
require_ray,
require_sentencepiece,
@ -640,6 +641,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_number_of_steps_in_training_with_ipex(self):
for mix_bf16 in [True, False]:
# Regular training has n_epochs * len(train_dl) steps
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
train_output = trainer.train()
self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size)
# Check passing num_train_epochs works (and a float version too):
trainer = get_regression_trainer(
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size))
# If we pass a max_steps, num_train_epochs is ignored
trainer = get_regression_trainer(
learning_rate=0.1, max_steps=10, use_ipex=True, bf16=mix_bf16, no_cuda=True
)
train_output = trainer.train()
self.assertEqual(train_output.global_step, 10)
def test_logging_inf_nan_filter(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
@ -820,6 +844,60 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_evaluate_with_ipex(self):
for mix_bf16 in [True, False]:
trainer = get_regression_trainer(
a=1.5, b=2.5, use_ipex=True, compute_metrics=AlmostAccuracy(), bf16=mix_bf16, no_cuda=True
)
results = trainer.evaluate()
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
# With a number of elements not a round multiple of the batch size
trainer = get_regression_trainer(
a=1.5,
b=2.5,
use_ipex=True,
eval_len=66,
compute_metrics=AlmostAccuracy(),
bf16=mix_bf16,
no_cuda=True,
)
results = trainer.evaluate()
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
# With logits preprocess
trainer = get_regression_trainer(
a=1.5,
b=2.5,
use_ipex=True,
compute_metrics=AlmostAccuracy(),
preprocess_logits_for_metrics=lambda logits, labels: logits + 1,
bf16=mix_bf16,
no_cuda=True,
)
results = trainer.evaluate()
x, y = trainer.eval_dataset.x, trainer.eval_dataset.ys[0]
pred = 1.5 * x + 2.5
expected_loss = ((pred - y) ** 2).mean()
self.assertAlmostEqual(results["eval_loss"], expected_loss)
expected_acc = AlmostAccuracy()((pred + 1, y))["accuracy"]
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
def test_predict(self):
trainer = get_regression_trainer(a=1.5, b=2.5)
preds = trainer.predict(trainer.eval_dataset).predictions
@ -852,6 +930,51 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
@require_torch_bf16
@require_intel_extension_for_pytorch
def test_predict_with_ipex(self):
for mix_bf16 in [True, False]:
trainer = get_regression_trainer(a=1.5, b=2.5, use_ipex=True, bf16=mix_bf16, no_cuda=True)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With a number of elements not a round multiple of the batch size
trainer = get_regression_trainer(a=1.5, b=2.5, eval_len=66, use_ipex=True, bf16=mix_bf16, no_cuda=True)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
# With more than one output of the model
trainer = get_regression_trainer(
a=1.5, b=2.5, double_output=True, use_ipex=True, bf16=mix_bf16, no_cuda=True
)
preds = trainer.predict(trainer.eval_dataset).predictions
x = trainer.eval_dataset.x
self.assertEqual(len(preds), 2)
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
# With more than one output/label of the model
trainer = get_regression_trainer(
a=1.5,
b=2.5,
double_output=True,
label_names=["labels", "labels_2"],
use_ipex=True,
bf16=mix_bf16,
no_cuda=True,
)
outputs = trainer.predict(trainer.eval_dataset)
preds = outputs.predictions
labels = outputs.label_ids
x = trainer.eval_dataset.x
self.assertEqual(len(preds), 2)
self.assertTrue(np.allclose(preds[0], 1.5 * x + 2.5))
self.assertTrue(np.allclose(preds[1], 1.5 * x + 2.5))
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
def test_dynamic_shapes(self):
eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
model = RegressionModel(a=2, b=1)