mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
enable cpu offloading for Bark on xpu (#37599)
* enable cpu offloading of bark modeling on XPU
* remove debug print
* fix style
* fix review comments
* enhance test
* update
* add deprecate message
* update
* update
* trigger CI

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 4f9893cbbc
commit 12f65ee752
@@ -15,6 +15,7 @@
 """PyTorch BARK model."""
 
 import math
+import warnings
 from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
@@ -36,6 +37,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_accelerate_available,
+    is_torch_accelerator_available,
     logging,
 )
 from ..auto import AutoModel
@@ -1598,26 +1600,45 @@ class BarkModel(BarkPreTrainedModel):
             ):
                 return torch.device(module._hf_hook.execution_device)
 
-    def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
+    def enable_cpu_offload(
+        self,
+        accelerator_id: Optional[int] = 0,
+        **kwargs,
+    ):
         r"""
         Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
-        method moves one whole sub-model at a time to the GPU when it is used, and the sub-model remains in GPU until
-        the next sub-model runs.
+        method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains in accelerator until the next sub-model runs.
 
         Args:
-            gpu_id (`int`, *optional*, defaults to 0):
-                GPU id on which the sub-models will be loaded and offloaded.
+            accelerator_id (`int`, *optional*, defaults to 0):
+                accelerator id on which the sub-models will be loaded and offloaded. This argument is deprecated.
+            kwargs (`dict`, *optional*):
+                additional keyword arguments:
+                    `gpu_id`: accelerator id on which the sub-models will be loaded and offloaded.
         """
         if is_accelerate_available():
             from accelerate import cpu_offload_with_hook
         else:
             raise ImportError("`enable_model_cpu_offload` requires `accelerate`.")
 
-        device = torch.device(f"cuda:{gpu_id}")
+        gpu_id = kwargs.get("gpu_id", 0)
+
+        if gpu_id != 0:
+            warnings.warn(
+                "The argument `gpu_id` is deprecated and will be removed in version 4.54.0 of Transformers. Please use `accelerator_id` instead.",
+                FutureWarning,
+            )
+            accelerator_id = gpu_id
+
+        device_type = "cuda"
+        if is_torch_accelerator_available():
+            device_type = torch.accelerator.current_accelerator().type
+        device = torch.device(f"{device_type}:{accelerator_id}")
+
+        torch_accelerator_module = getattr(torch, device_type)
         if self.device.type != "cpu":
             self.to("cpu")
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            torch_accelerator_module.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         # this layer is used outside the first foward pass of semantic so need to be loaded before semantic
         self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
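For context, a minimal usage sketch of the updated API follows (not part of the diff). The checkpoint name, prompt, and voice preset are illustrative assumptions; the sketch assumes a machine where torch.accelerator reports a CUDA or XPU device.

# Hedged usage sketch: generate speech with Bark while idle sub-models stay on CPU.
import torch
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")  # assumed checkpoint
model = BarkModel.from_pretrained("suno/bark-small")

# Sub-models remain on CPU and are moved to accelerator 0 only while they run.
model.enable_cpu_offload(accelerator_id=0)

inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")
# model.device reports the hooks' execution device once offloading is enabled.
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
    audio = model.generate(**inputs)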
@@ -211,6 +211,7 @@ from .import_utils import (
     is_tiktoken_available,
     is_timm_available,
     is_tokenizers_available,
+    is_torch_accelerator_available,
     is_torch_available,
     is_torch_bf16_available,
     is_torch_bf16_available_on_device,
@@ -346,16 +346,28 @@ def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
     return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
 
 
+def is_torch_accelerator_available():
+    if is_torch_available():
+        import torch
+
+        return hasattr(torch, "accelerator")
+
+    return False
+
+
 def is_torch_deterministic():
     """
     Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2"
     """
-    import torch
+    if is_torch_available():
+        import torch
 
-    if torch.get_deterministic_debug_mode() == 0:
-        return False
-    else:
-        return True
+        if torch.get_deterministic_debug_mode() == 0:
+            return False
+        else:
+            return True
+
+    return False
 
 
 def is_hadamard_available():
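The new helper only checks for the torch.accelerator namespace; the calling code in the Bark model then derives the device type and the matching backend module from it. A standalone sketch of that pattern follows, assuming a PyTorch build recent enough to ship torch.accelerator and falling back to CUDA otherwise.

# Hedged sketch of the device-agnostic pattern (assumes torch.accelerator exists on
# recent PyTorch; falls back to "cuda" when it does not or no accelerator is present).
import torch

device_type = "cuda"
if hasattr(torch, "accelerator") and torch.accelerator.current_accelerator() is not None:
    device_type = torch.accelerator.current_accelerator().type  # e.g. "cuda" or "xpu"

device = torch.device(f"{device_type}:0")
backend = getattr(torch, device_type)  # torch.cuda and torch.xpu expose the same cache helpers

backend.empty_cache()
print("using", device)

Resolving the backend module with getattr(torch, device_type) is what lets the same call sites (empty_cache, memory_allocated) work unchanged on CUDA and XPU.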
@@ -36,6 +36,7 @@ from transformers.models.bark.generation_configuration_bark import (
 from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
+    require_torch_accelerator,
     require_torch_fp16,
     require_torch_gpu,
     slow,
@@ -1056,7 +1057,8 @@ class BarkModelIntegrationTests(unittest.TestCase):
     def inputs(self):
         input_ids = self.processor("In the light of the moon, a little egg lay on a leaf", voice_preset="en_speaker_6")
 
-        input_ids = input_ids.to(torch_device)
+        for k, v in input_ids.items():
+            input_ids[k] = v.to(torch_device)
 
         return input_ids
 
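The fixture now moves the processor output to the device entry by entry instead of calling .to() on the container as a whole. A generic, hedged version of that pattern is sketched below; the helper name and the hasattr guard are illustrative and not taken from the test.

# Illustrative helper (not from the test): move each entry of a processor output
# to the target device, leaving anything without a .to() method untouched.
import torch

def move_to_device(batch: dict, device) -> dict:
    # tensors and nested feature containers expose .to(); plain Python values do not
    return {k: v.to(device) if hasattr(v, "to") else v for k, v in batch.items()}

# Example: move_to_device(dict(processor_output), "xpu:0")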
@@ -1295,7 +1297,7 @@ class BarkModelIntegrationTests(unittest.TestCase):
             len(output_ids_with_min_eos_p[0, :].tolist()), len(output_ids_without_min_eos_p[0, :].tolist())
         )
 
-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
     def test_generate_end_to_end_with_offload(self):
         input_ids = self.inputs
@@ -1304,15 +1306,17 @@ class BarkModelIntegrationTests(unittest.TestCase):
         # standard generation
         output_with_no_offload = self.model.generate(**input_ids, do_sample=False, temperature=1.0)
 
-        torch.cuda.empty_cache()
+        torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
 
-        memory_before_offload = torch.cuda.memory_allocated()
+        torch_accelerator_module.empty_cache()
+
+        memory_before_offload = torch_accelerator_module.memory_allocated()
         model_memory_footprint = self.model.get_memory_footprint()
 
         # activate cpu offload
         self.model.enable_cpu_offload()
 
-        memory_after_offload = torch.cuda.memory_allocated()
+        memory_after_offload = torch_accelerator_module.memory_allocated()
 
         # checks if the model have been offloaded
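A standalone sketch of the memory check this test performs is given below. It assumes a BarkModel instance that currently lives on the accelerator and a torch_device string such as "cuda" or "xpu"; the exact assertion the test makes after the final comment is not shown in this hunk, so none is reproduced here.

# Hedged sketch of the test's memory measurement around enable_cpu_offload().
import torch
from transformers import BarkModel

def accelerator_memory_freed_by_offload(model: BarkModel, torch_device: str = "cuda") -> int:
    # assumes `model` is already on the accelerator named by `torch_device`
    backend = getattr(torch, torch_device, torch.cuda)  # torch.cuda / torch.xpu share this API
    backend.empty_cache()

    memory_before_offload = backend.memory_allocated()
    model.enable_cpu_offload()  # sub-models move back to CPU until each one is actually used
    memory_after_offload = backend.memory_allocated()

    # the difference should roughly match the model's accelerator footprint
    return memory_before_offload - memory_after_offload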