enable cpu offloading for Bark on xpu (#37599)

* enable cpu offloading of bark modeling on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* remove debug print

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix review comments

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* enhance test

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* update

* add deprecate message

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* update

* update

* trigger CI

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Yao Matrix <matrix.yao@intel.com>, 2025-04-23 17:37:15 +08:00, committed by GitHub
parent 4f9893cbbc
commit 12f65ee752
4 changed files with 55 additions and 17 deletions

src/transformers/models/bark/modeling_bark.py

@@ -15,6 +15,7 @@
 """PyTorch BARK model."""

 import math
+import warnings
 from typing import Dict, Optional, Tuple, Union

 import numpy as np
@@ -36,6 +37,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_accelerate_available,
+    is_torch_accelerator_available,
     logging,
 )
 from ..auto import AutoModel
@@ -1598,26 +1600,45 @@ class BarkModel(BarkPreTrainedModel):
         ):
             return torch.device(module._hf_hook.execution_device)

-    def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
+    def enable_cpu_offload(
+        self,
+        accelerator_id: Optional[int] = 0,
+        **kwargs,
+    ):
         r"""
         Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
-        method moves one whole sub-model at a time to the GPU when it is used, and the sub-model remains in GPU until
-        the next sub-model runs.
+        method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains in
+        accelerator until the next sub-model runs.

         Args:
-            gpu_id (`int`, *optional*, defaults to 0):
-                GPU id on which the sub-models will be loaded and offloaded.
+            accelerator_id (`int`, *optional*, defaults to 0):
+                accelerator id on which the sub-models will be loaded and offloaded. This argument is deprecated.
+            kwargs (`dict`, *optional*):
+                additional keyword arguments:
+                    `gpu_id`: accelerator id on which the sub-models will be loaded and offloaded.
         """
         if is_accelerate_available():
            from accelerate import cpu_offload_with_hook
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate`.")

-        device = torch.device(f"cuda:{gpu_id}")
+        gpu_id = kwargs.get("gpu_id", 0)
+        if gpu_id != 0:
+            warnings.warn(
+                "The argument `gpu_id` is deprecated and will be removed in version 4.54.0 of Transformers. Please use `accelerator_id` instead.",
+                FutureWarning,
+            )
+            accelerator_id = gpu_id
+
+        device_type = "cuda"
+        if is_torch_accelerator_available():
+            device_type = torch.accelerator.current_accelerator().type
+        device = torch.device(f"{device_type}:{accelerator_id}")
+
+        torch_accelerator_module = getattr(torch, device_type)

         if self.device.type != "cpu":
             self.to("cpu")
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            torch_accelerator_module.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

         # this layer is used outside the first foward pass of semantic so need to be loaded before semantic
         self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
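For illustration, a minimal usage sketch of the updated offload API on a machine with a CUDA or XPU accelerator; the checkpoint name is an assumption for the example and not part of this change, and `accelerate` must be installed:

import torch
from transformers import BarkModel

# Load Bark and enable sub-model CPU offload; each sub-model is moved to the
# current accelerator (CUDA or XPU) only while it runs, then returned to CPU.
model = BarkModel.from_pretrained("suno/bark-small")  # hypothetical checkpoint choice
model.enable_cpu_offload(accelerator_id=0)  # new argument name; device type is auto-detected

# The old spelling still works, but a non-zero gpu_id now emits a FutureWarning:
# model.enable_cpu_offload(gpu_id=1)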

src/transformers/utils/__init__.py

@@ -211,6 +211,7 @@ from .import_utils import (
     is_tiktoken_available,
     is_timm_available,
     is_tokenizers_available,
+    is_torch_accelerator_available,
     is_torch_available,
     is_torch_bf16_available,
     is_torch_bf16_available_on_device,

src/transformers/utils/import_utils.py

@@ -346,16 +346,28 @@ def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
     return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)


+def is_torch_accelerator_available():
+    if is_torch_available():
+        import torch
+
+        return hasattr(torch, "accelerator")
+
+    return False
+
+
 def is_torch_deterministic():
     """
     Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2"
     """
-    import torch
+    if is_torch_available():
+        import torch

-    if torch.get_deterministic_debug_mode() == 0:
-        return False
-    else:
-        return True
+        if torch.get_deterministic_debug_mode() == 0:
+            return False
+        else:
+            return True
+
+    return False


 def is_hadamard_available():
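As a side note, a small sketch of how this check pairs with `torch.accelerator` to resolve the active device module, mirroring the pattern used in the Bark change above; the helper name `resolve_device_type` is illustrative and not part of this commit:

import torch

def resolve_device_type() -> str:
    # torch.accelerator only exists on recent PyTorch builds; fall back to "cuda"
    # otherwise, matching the default used in modeling_bark.py above.
    if hasattr(torch, "accelerator"):
        accelerator = torch.accelerator.current_accelerator()
        if accelerator is not None:
            return accelerator.type  # e.g. "cuda" or "xpu"
    return "cuda"

device_type = resolve_device_type()
torch_accelerator_module = getattr(torch, device_type)  # e.g. torch.cuda or torch.xpu
torch_accelerator_module.empty_cache()  # frees cached blocks on the active backend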

tests/models/bark/test_modeling_bark.py

@@ -36,6 +36,7 @@ from transformers.models.bark.generation_configuration_bark import (
 from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
+    require_torch_accelerator,
     require_torch_fp16,
     require_torch_gpu,
     slow,
@@ -1056,7 +1057,8 @@ class BarkModelIntegrationTests(unittest.TestCase):
     def inputs(self):
         input_ids = self.processor("In the light of the moon, a little egg lay on a leaf", voice_preset="en_speaker_6")
-        input_ids = input_ids.to(torch_device)
+        for k, v in input_ids.items():
+            input_ids[k] = v.to(torch_device)

         return input_ids
@@ -1295,7 +1297,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
             len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
         )

-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_generate_end_to_end_with_offload(self):
         input_ids = self.inputs
@@ -1304,15 +1306,17 @@ class BarkModelIntegrationTests(unittest.TestCase):
         # standard generation
         output_with_no_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)

-        torch.cuda.empty_cache()
+        torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
+        torch_accelerator_module.empty_cache()

-        memory_before_offload = torch.cuda.memory_allocated()
+        memory_before_offload = torch_accelerator_module.memory_allocated()
         model_memory_footprint = self.model.get_memory_footprint()

         # activate cpu offload
         self.model.enable_cpu_offload()

-        memory_after_offload = torch.cuda.memory_allocated()
+        memory_after_offload = torch_accelerator_module.memory_allocated()

         # checks if the model have been offloaded