Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
HPU support (#36424)
* test
* fix
* fix
* skip some and run some first
* test fsdp
* fix
* patches for generate
* test distributed
* copy
* don't test distributed loss for hpu
* require fp16 and run first
* changes from marc's PR fixing zero3
* better alternative
* return True when fp16 support on gaudi without creating bridge
* fix
* fix tested dtype in deepspeed inference test
* test
* fix
* test
* fix
* skip
* require fp16
* run first fsdp
* Apply suggestions from code review
* address comments
* address comments and refactor test
* reduce precision
* avoid doing gaudi1 specific stuff in the generation loop
* document test_gradient_accumulation_loss_alignment_with_model_loss test a bit more
This commit is contained in:
parent 50d3530aa0
commit 89f6956015

setup.py
@@ -152,6 +152,7 @@ _deps = [
 "pytest-asyncio",
 "pytest-timeout",
 "pytest-xdist",
+"pytest-order",
 "python>=3.9.0",
 "ray[tune]>=2.7.0",
 "regex!=2019.12.17",
@@ -324,6 +325,7 @@ extras["testing"] = (
 "pytest-asyncio",
 "pytest-rich",
 "pytest-xdist",
+"pytest-order",
 "timeout-decorator",
 "parameterized",
 "psutil",
@@ -1016,6 +1016,7 @@ _import_structure = {
 "is_timm_available",
 "is_tokenizers_available",
 "is_torch_available",
+"is_torch_hpu_available",
 "is_torch_mlu_available",
 "is_torch_musa_available",
 "is_torch_neuroncore_available",
@@ -6243,6 +6244,7 @@ if TYPE_CHECKING:
 is_timm_available,
 is_tokenizers_available,
 is_torch_available,
+is_torch_hpu_available,
 is_torch_mlu_available,
 is_torch_musa_available,
 is_torch_neuroncore_available,
@@ -30,6 +30,7 @@ from ..utils import (
 is_safetensors_available,
 is_tf_available,
 is_torch_available,
+is_torch_hpu_available,
 is_torch_npu_available,
 )
 from . import BaseTransformersCLICommand
@@ -94,6 +95,7 @@ class EnvironmentCommand(BaseTransformersCLICommand):
 pt_version = torch.__version__
 pt_cuda_available = torch.cuda.is_available()
 pt_npu_available = is_torch_npu_available()
+pt_hpu_available = is_torch_hpu_available()
 
 tf_version = "not installed"
 tf_cuda_available = "NA"
@@ -149,6 +151,9 @@ class EnvironmentCommand(BaseTransformersCLICommand):
 if pt_cuda_available:
 info["Using GPU in script?"] = "<fill in>"
 info["GPU type"] = torch.cuda.get_device_name()
+elif pt_hpu_available:
+info["Using HPU in script?"] = "<fill in>"
+info["HPU type"] = torch.hpu.get_device_name()
 elif pt_npu_available:
 info["Using NPU in script?"] = "<fill in>"
 info["NPU type"] = torch.npu.get_device_name()
@@ -58,6 +58,7 @@ deps = {
 "pytest-asyncio": "pytest-asyncio",
 "pytest-timeout": "pytest-timeout",
 "pytest-xdist": "pytest-xdist",
+"pytest-order": "pytest-order",
 "python": "python>=3.9.0",
 "ray[tune]": "ray[tune]>=2.7.0",
 "regex": "regex!=2019.12.17",
@@ -598,6 +598,8 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
 kwargs_decoder = {
 argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
 }
+if "num_items_in_batch" in kwargs_encoder:
+kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
 
 if encoder_outputs is None:
 encoder_outputs = self.encoder(
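A minimal sketch (not part of the diff) of the kwarg split-and-forward logic in the hunk above; the dictionary contents are made-up values for illustration.

    # Split generic kwargs into encoder/decoder kwargs the way the forward pass does, then
    # move num_items_in_batch to the decoder side, since the loss (and its normalization
    # under gradient accumulation) is computed there.
    kwargs = {"decoder_attention_mask": None, "num_items_in_batch": 16}

    kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
    kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}

    if "num_items_in_batch" in kwargs_encoder:
        kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)

    print(kwargs_encoder)  # {}
    print(kwargs_decoder)  # {'attention_mask': None, 'num_items_in_batch': 16}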
@@ -45,6 +45,7 @@ from ..utils import (
 is_tf_available,
 is_torch_available,
 is_torch_cuda_available,
+is_torch_hpu_available,
 is_torch_mlu_available,
 is_torch_mps_available,
 is_torch_musa_available,
@@ -963,6 +964,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
 self.device = torch.device(f"cuda:{device}")
 elif is_torch_npu_available():
 self.device = torch.device(f"npu:{device}")
+elif is_torch_hpu_available():
+self.device = torch.device(f"hpu:{device}")
 elif is_torch_xpu_available(check_device=True):
 self.device = torch.device(f"xpu:{device}")
 elif is_torch_mps_available():
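A hedged usage sketch, not from the diff: with habana_frameworks installed and a Gaudi device visible, the branch added above lets an integer device argument resolve to an hpu torch.device; the model name is just an illustrative checkpoint.

    from transformers import pipeline

    # On a Gaudi machine this resolves to device(type="hpu", index=0); on CUDA/NPU/XPU/MPS
    # machines the earlier branches win instead.
    pipe = pipeline(
        "text-classification",
        model="distilbert-base-uncased-finetuned-sst-2-english",
        device=0,
    )
    print(pipe.device)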
@@ -29,6 +29,7 @@ from ..utils import (
 is_accelerate_available,
 is_bitsandbytes_available,
 is_torch_available,
+is_torch_hpu_available,
 is_torch_npu_available,
 is_torch_xpu_available,
 logging,
@@ -269,6 +270,8 @@ class Bnb4BitHfQuantizer(HfQuantizer):
 device_map = {"": torch.cuda.current_device()}
 elif is_torch_npu_available():
 device_map = {"": f"npu:{torch.npu.current_device()}"}
+elif is_torch_hpu_available():
+device_map = {"": f"hpu:{torch.hpu.current_device()}"}
 elif is_torch_xpu_available():
 device_map = {"": f"xpu:{torch.xpu.current_device()}"}
 else:
@@ -141,6 +141,7 @@ from .utils import (
 is_torch_deterministic,
 is_torch_fp16_available_on_device,
 is_torch_greater_or_equal,
+is_torch_hpu_available,
 is_torch_neuroncore_available,
 is_torch_npu_available,
 is_torch_sdpa_available,
@@ -858,6 +859,13 @@ def require_torch_multi_npu(test_case):
 return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
+
+
+def require_non_hpu(test_case):
+"""
+Decorator marking a test that should be skipped for HPU.
+"""
+return unittest.skipUnless(torch_device != "hpu", "test requires a non-HPU")(test_case)
 
 
 def require_torch_xpu(test_case):
 """
 Decorator marking a test that requires XPU (in PyTorch).
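An illustrative (assumed) usage of the new decorator in a test module; it simply wraps unittest.skipUnless as defined above, so tests are skipped whenever the test device resolves to "hpu".

    import unittest

    from transformers.testing_utils import require_non_hpu, torch_device


    class ExampleSuite(unittest.TestCase):
        @require_non_hpu
        def test_feature_not_supported_on_gaudi(self):
            # Skipped automatically when TRANSFORMERS_TEST_DEVICE resolves to "hpu"
            self.assertNotEqual(torch_device, "hpu")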
@@ -889,6 +897,19 @@ def require_torch_multi_xpu(test_case):
 return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
+
+
+def require_torch_multi_hpu(test_case):
+"""
+Decorator marking a test that requires a multi-HPU setup (in PyTorch). These tests are skipped on a machine without
+multiple HPUs.
+
+To run *only* the multi_hpu tests, assuming all test names contain multi_hpu: $ pytest -sv ./tests -k "multi_hpu"
+"""
+if not is_torch_hpu_available():
+return unittest.skip(reason="test requires PyTorch HPU")(test_case)
+
+return unittest.skipUnless(torch.hpu.device_count() > 1, "test requires multiple HPUs")(test_case)
 
 
 if is_torch_available():
 # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
 import torch
@@ -917,6 +938,10 @@ if is_torch_available():
 raise ValueError(
 f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
 )
+if torch_device == "hpu" and not is_torch_hpu_available():
+raise ValueError(
+f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment."
+)
 
 try:
 # try creating device to see if provided device is valid
@@ -929,6 +954,8 @@ if is_torch_available():
 torch_device = "cuda"
 elif _run_third_party_device_tests and is_torch_npu_available():
 torch_device = "npu"
+elif _run_third_party_device_tests and is_torch_hpu_available():
+torch_device = "hpu"
 elif _run_third_party_device_tests and is_torch_xpu_available():
 torch_device = "xpu"
 else:
@@ -2565,6 +2592,20 @@ def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
 return decorator
+
+
+def run_first(test_case):
+"""
+Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator
+are guaranteed to run first.
+
+This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
+single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
+allocation conflicts.
+"""
+import pytest
+
+return pytest.mark.order(1)(test_case)
 
 
 def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
 """
 To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
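A short sketch of how run_first is meant to compose with the pytest-order plugin this commit adds to the test dependencies; the test name is illustrative, not from the diff.

    from transformers.testing_utils import run_first, slow


    @run_first
    @slow
    def test_launcher_based_training():
        # pytest.mark.order(1) makes pytest-order schedule this before unordered tests,
        # so subprocess-launched jobs grab the Gaudi device before in-process tests do.
        ...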
@@ -2853,6 +2894,25 @@ else:
 BACKEND_EMPTY_CACHE = {"default": None}
 BACKEND_DEVICE_COUNT = {"default": lambda: 0}
+
+if is_torch_hpu_available():
+BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
+BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count
+
+if is_torch_npu_available():
+BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
+BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
+BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count
+
+if is_torch_xpu_available():
+BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
+BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
+BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
+
+if is_torch_xla_available():
+BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
+BACKEND_MANUAL_SEED["xla"] = torch.cuda.manual_seed
+BACKEND_DEVICE_COUNT["xla"] = torch.cuda.device_count
 
 
 def backend_manual_seed(device: str, seed: int):
 return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
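A minimal sketch of the device-to-callable dispatch pattern the tables above implement; the helper name mirrors backend_manual_seed, but the fallback logic here is a simplified assumption rather than the exact _device_agnostic_dispatch implementation.

    import torch

    BACKEND_MANUAL_SEED = {"default": torch.manual_seed, "cuda": torch.cuda.manual_seed}
    if hasattr(torch, "hpu"):  # only present when habana_frameworks.torch is loaded
        BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed


    def backend_manual_seed(device: str, seed: int):
        # Fall back to the CPU seeding function for unknown device strings
        fn = BACKEND_MANUAL_SEED.get(device, BACKEND_MANUAL_SEED["default"])
        return fn(seed)


    backend_manual_seed("cuda" if torch.cuda.is_available() else "default", 42)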
@@ -166,6 +166,7 @@ from .utils import (
 is_sagemaker_mp_enabled,
 is_schedulefree_available,
 is_torch_compile_available,
+is_torch_hpu_available,
 is_torch_mlu_available,
 is_torch_mps_available,
 is_torch_musa_available,
@@ -3141,9 +3142,10 @@ class Trainer:
 set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed)
 if is_torch_npu_available():
 set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
+if is_torch_hpu_available():
+set_rng_state_for_device("HPU", torch.hpu, checkpoint_rng_state, is_distributed)
 if is_torch_mlu_available():
 set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed)
 
 if is_torch_musa_available():
 set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed)
@@ -3255,6 +3257,12 @@ class Trainer:
 else:
 rng_states["npu"] = torch.npu.random.get_rng_state()
+
+if is_torch_hpu_available():
+if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+rng_states["hpu"] = torch.hpu.random.get_rng_state_all()
+else:
+rng_states["hpu"] = torch.hpu.random.get_rng_state()
 
 if is_torch_mlu_available():
 if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
 rng_states["mlu"] = torch.mlu.random.get_rng_state_all()
@@ -3725,6 +3733,10 @@ class Trainer:
 torch.npu.empty_cache()
 elif is_torch_mps_available(min_version="2.0"):
 torch.mps.empty_cache()
+elif is_torch_hpu_available():
+logger.warning(
+"`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
+)
 else:
 torch.cuda.empty_cache()
@@ -35,6 +35,7 @@ from .utils import (
 is_tf_available,
 is_torch_available,
 is_torch_cuda_available,
+is_torch_hpu_available,
 is_torch_mlu_available,
 is_torch_mps_available,
 is_torch_musa_available,
@@ -113,6 +114,8 @@ def set_seed(seed: int, deterministic: bool = False):
 torch.musa.manual_seed_all(seed)
 if is_torch_npu_available():
 torch.npu.manual_seed_all(seed)
+if is_torch_hpu_available():
+torch.hpu.manual_seed_all(seed)
 if is_torch_xpu_available():
 torch.xpu.manual_seed_all(seed)
 if is_tf_available():
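A brief usage note, not in the diff: after this change, transformers.set_seed also seeds the HPU generators when they are present, so a single call keeps runs reproducible across backends.

    from transformers import set_seed

    # Seeds python, numpy and torch; the per-backend branches above additionally seed
    # cuda/mlu/musa/npu/hpu/xpu generators when the corresponding backend is available.
    set_seed(42)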
@@ -506,6 +509,11 @@ class TrainerMemoryTracker:
 elif is_torch_npu_available():
 import torch
 
+self.torch = torch
+self.gpu = {}
+elif is_torch_hpu_available():
+import torch
+
 self.torch = torch
 self.gpu = {}
 else:
@@ -573,6 +581,10 @@ class TrainerMemoryTracker:
 elif is_torch_npu_available():
 self.torch.npu.reset_peak_memory_stats()
 self.torch.npu.empty_cache()
+elif is_torch_hpu_available():
+self.torch.hpu.reset_peak_memory_stats()
+# not available on hpu as it reserves all device memory for the current process
+# self.torch.hpu.empty_cache()
 elif is_torch_mps_available():
 self.torch.mps.empty_cache()
@@ -588,6 +600,8 @@ class TrainerMemoryTracker:
 self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
 elif is_torch_npu_available():
 self.gpu_mem_used_at_start = self.torch.npu.memory_allocated()
+elif is_torch_hpu_available():
+self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated()
 elif is_torch_mps_available():
 self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory()
@@ -623,6 +637,10 @@ class TrainerMemoryTracker:
 self.torch.xpu.empty_cache()
 elif is_torch_npu_available():
 self.torch.npu.empty_cache()
+elif is_torch_hpu_available():
+# not available on hpu as it reserves all device memory for the current process
+# self.torch.npu.empty_cache()
+pass
 elif is_torch_mps_available():
 self.torch.mps.empty_cache()
@@ -648,6 +666,9 @@ class TrainerMemoryTracker:
 elif is_torch_npu_available():
 self.gpu_mem_used_now = self.torch.npu.memory_allocated()
 self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated()
+elif is_torch_hpu_available():
+self.gpu_mem_used_now = self.torch.hpu.memory_allocated()
+self.gpu_mem_used_peak = self.torch.hpu.max_memory_allocated()
 elif is_torch_mps_available():
 self.gpu_mem_used_now = self.torch.mps.current_allocated_memory()
 # self.torch.mps.max_memory_allocated() does not exist yet
@@ -48,6 +48,7 @@ from .utils import (
 is_torch_available,
 is_torch_bf16_cpu_available,
 is_torch_bf16_gpu_available,
+is_torch_hpu_available,
 is_torch_mlu_available,
 is_torch_mps_available,
 is_torch_musa_available,
@@ -260,9 +261,9 @@ class TrainingArguments:
 prediction_loss_only (`bool`, *optional*, defaults to `False`):
 When performing evaluation and generating predictions, only returns the loss.
 per_device_train_batch_size (`int`, *optional*, defaults to 8):
-The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.
+The batch size per device accelerator core/CPU for training.
 per_device_eval_batch_size (`int`, *optional*, defaults to 8):
-The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.
+The batch size per device accelerator core/CPU for evaluation.
 gradient_accumulation_steps (`int`, *optional*, defaults to 1):
 Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
@@ -275,7 +276,7 @@ class TrainingArguments:
 
 eval_accumulation_steps (`int`, *optional*):
 Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If
-left unset, the whole predictions are accumulated on GPU/NPU/TPU before being moved to the CPU (faster but
+left unset, the whole predictions are accumulated on the device accelerator before being moved to the CPU (faster but
 requires more memory).
 eval_delay (`float`, *optional*):
 Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
@@ -853,10 +854,10 @@ class TrainingArguments:
 )
 
 per_device_train_batch_size: int = field(
-default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."}
+default=8, metadata={"help": "Batch size per device accelerator core/CPU for training."}
 )
 per_device_eval_batch_size: int = field(
-default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for evaluation."}
+default=8, metadata={"help": "Batch size per device accelerator core/CPU for evaluation."}
 )
 
 per_gpu_train_batch_size: Optional[int] = field(
@@ -1044,7 +1045,7 @@ class TrainingArguments:
 use_cpu: bool = field(
 default=False,
 metadata={
-"help": "Whether or not to use cpu. If set to False, we will use cuda/tpu/mps/npu device if available."
+"help": "Whether or not to use cpu. If left to False, we will use the available torch device/backend (cuda/mps/xpu/hpu etc.)"
 },
 )
 use_mps_device: bool = field(
@@ -1830,7 +1831,10 @@ class TrainingArguments:
 if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
 self.torch_compile = True
 if self.torch_compile and self.torch_compile_backend is None:
-self.torch_compile_backend = "inductor"
+if not self.use_cpu and is_torch_hpu_available():
+self.torch_compile_backend = "hpu_backend"
+else:
+self.torch_compile_backend = "inductor"
 
 # accelerate integration for torch compile
 if self.torch_compile:
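A hedged sketch of what the new branch means in practice: requesting torch.compile without an explicit backend now resolves to the Gaudi compile backend on HPU and to inductor elsewhere; output_dir is a placeholder path.

    from transformers import TrainingArguments

    args = TrainingArguments(output_dir="tmp_out", torch_compile=True)
    # "hpu_backend" when use_cpu is False and an HPU is detected, otherwise "inductor"
    print(args.torch_compile_backend)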
@@ -2312,6 +2316,9 @@ class TrainingArguments:
 elif is_torch_npu_available():
 device = torch.device("npu:0")
 torch.npu.set_device(device)
+elif is_torch_hpu_available():
+device = torch.device("hpu:0")
+torch.hpu.set_device(device)
 else:
 # if n_gpu is > 1 we'll use nn.DataParallel.
 # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -148,6 +148,7 @@ from .import_utils import (
 is_gguf_available,
 is_gptqmodel_available,
 is_grokadamw_available,
+is_habana_gaudi1,
 is_hadamard_available,
 is_hqq_available,
 is_in_notebook,
@@ -218,6 +219,7 @@ from .import_utils import (
 is_torch_fx_available,
 is_torch_fx_proxy,
 is_torch_greater_or_equal,
+is_torch_hpu_available,
 is_torch_mlu_available,
 is_torch_mps_available,
 is_torch_musa_available,
@@ -316,6 +318,9 @@ def get_available_devices() -> FrozenSet[str]:
 if is_torch_npu_available():
 devices.add("npu")
 
+if is_torch_hpu_available():
+devices.add("hpu")
+
 if is_torch_mlu_available():
 devices.add("mlu")
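An assumed usage sketch: get_available_devices() (the function extended above) now reports "hpu" together with the other backends when habana_frameworks and a Gaudi device are present; the import path is assumed from the hunk's location.

    from transformers.utils import get_available_devices

    # e.g. frozenset({"cpu", "hpu"}) on a Gaudi host, frozenset({"cpu", "cuda"}) on a GPU host
    print(get_available_devices())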
@@ -542,6 +542,12 @@ def is_torch_fp16_available_on_device(device):
 if not is_torch_available():
 return False
+
+if is_torch_hpu_available():
+if is_habana_gaudi1():
+return False
+else:
+return True
 
 import torch
 
 try:
@@ -573,6 +579,9 @@ def is_torch_bf16_available_on_device(device):
 if device == "cuda":
 return is_torch_bf16_gpu_available()
 
+if device == "hpu":
+return True
+
 try:
 x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
 _ = x @ x
@@ -773,6 +782,61 @@ def is_torch_musa_available(check_device=False):
 return hasattr(torch, "musa") and torch.musa.is_available()
 
 
+@lru_cache
+def is_torch_hpu_available():
+"Checks if `torch.hpu` is available and potentially if a HPU is in the environment"
+if (
+not _torch_available
+or importlib.util.find_spec("habana_frameworks") is None
+or importlib.util.find_spec("habana_frameworks.torch") is None
+):
+return False
+
+torch_hpu_min_version = "1.5.0"
+if _accelerate_available and version.parse(_accelerate_version) < version.parse(torch_hpu_min_version):
+return False
+
+import torch
+
+if not hasattr(torch, "hpu") or not torch.hpu.is_available():
+return False
+
+import habana_frameworks.torch.utils.experimental as htexp  # noqa: F401
+
+# IlyasMoutawwakil: We patch masked_fill_ for int64 tensors to avoid a bug on Gaudi1
+# synNodeCreateWithId failed for node: masked_fill_fwd_i64 with synStatus 26 [Generic failure]
+# This can be removed once Gaudi1 support is discontinued but for now we need it to keep using
+# dl1.24xlarge Gaudi1 instances on AWS for testing.
+# check if the device is Gaudi1 (vs Gaudi2, Gaudi3).
+if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi:
+original_masked_fill_ = torch.Tensor.masked_fill_
+
+def patched_masked_fill_(self, mask, value):
+if self.dtype == torch.int64:
+logger.warning(
+"In-place tensor.masked_fill_(mask, value) is not supported for int64 tensors on Gaudi1. "
+"This operation will be performed out-of-place using tensor[mask] = value."
+)
+self[mask] = value
+else:
+original_masked_fill_(self, mask, value)
+
+torch.Tensor.masked_fill_ = patched_masked_fill_
+
+return True
+
+
+@lru_cache
+def is_habana_gaudi1():
+if not is_torch_hpu_available():
+return False
+
+import habana_frameworks.torch.utils.experimental as htexp  # noqa: F401
+
+# Check if the device is Gaudi1 (vs Gaudi2, Gaudi3)
+return htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi
+
+
 def is_torchdynamo_available():
 if not is_torch_available():
 return False
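To see why the Gaudi1 workaround above is safe, here is a small self-contained check (pure PyTorch, runs on CPU) that the out-of-place fallback tensor[mask] = value produces the same result as the in-place masked_fill_ it replaces for int64 tensors.

    import torch

    x = torch.arange(6, dtype=torch.int64)
    mask = x % 2 == 0

    patched = x.clone()
    patched[mask] = -1               # what the patch does on Gaudi1 for int64 tensors

    original = x.clone()
    original.masked_fill_(mask, -1)  # the in-place op that fails on Gaudi1 for int64

    assert torch.equal(patched, original)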
@@ -45,12 +45,14 @@ from transformers.testing_utils import (
 require_deepspeed,
 require_optuna,
 require_torch_accelerator,
+require_torch_fp16,
 require_torch_multi_accelerator,
+run_first,
 slow,
 torch_device,
 )
 from transformers.trainer_utils import get_last_checkpoint, set_seed
-from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_available_on_device
+from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
 
 
 if is_torch_available():
@@ -150,10 +152,12 @@ optims = [HF_OPTIM, DS_OPTIM]
 schedulers = [HF_SCHEDULER, DS_SCHEDULER]
 
 stages = [ZERO2, ZERO3]
 
+dtypes = []
 if is_torch_bf16_available_on_device(torch_device):
-dtypes = [FP16, BF16]
+dtypes.append(BF16)
-else:
+if is_torch_fp16_available_on_device(torch_device):
-dtypes = [FP16]
+dtypes.append(FP16)
 
 
 def parameterized_custom_name_func(func, param_num, param):
@@ -228,6 +232,7 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
 AutoModel.from_pretrained(T5_TINY)
 self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
 
+@require_torch_fp16
 @require_torch_accelerator
 def test_init_zero3_fp16(self):
 # test that zero.Init() works correctly under zero3/fp16
@@ -456,6 +461,7 @@ class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
 
 
 @require_deepspeed
+@require_torch_fp16
 @require_torch_accelerator
 class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, TrainerIntegrationCommon):
 """
@@ -714,7 +720,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
 # dynamic loss scale value set to:
 # "fp16.initial_scale_power": 1
 # plus having the same WarmupLR's warmup_min_lr == warmup_max_lr in the config file
-# but for some reason going to train_len=64 the weights, weights start to mismatch with this setup.
+# but for some reason going to train_len=64, the weights start to mismatch with this setup.
 # the culprit seems to be `initial_scale_power` - putting it back to its default 32 keeps the weights identical
 
 train_len = 64
@@ -757,8 +763,12 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
 
 # training with half the batch size but accumulation steps as 2 should give the same
 # weights, but sometimes get a slight difference still of 1e-6
-self.assertAlmostEqual(no_grad_accum_a, yes_grad_accum_a, places=5)
-self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5)
+if torch_device == "hpu":
+self.assertAlmostEqual(no_grad_accum_a, yes_grad_accum_a, delta=1e-4)
+self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, delta=1e-4)
+else:
+self.assertAlmostEqual(no_grad_accum_a, yes_grad_accum_a, places=5)
+self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5)
 
 # Relative difference. See the note above how to get identical loss on a small bs
 self.assertTrue((no_grad_accum_loss - yes_grad_accum_loss) / (no_grad_accum_loss + 1e-15) <= 1e-3)
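A quick illustration, not part of the diff, of why the HPU branch switches from places=5 to delta=1e-4: places checks that the rounded difference is zero (roughly 5e-6 of slack), while delta allows an absolute difference up to the given value.

    import unittest

    tc = unittest.TestCase()
    tc.assertAlmostEqual(1.0, 1.000004, places=5)   # ok: round(-4e-06, 5) == 0
    tc.assertAlmostEqual(1.0, 1.00008, delta=1e-4)  # ok only with the looser HPU tolerance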
@@ -1100,6 +1110,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
 
 
 @slow
+@run_first
 @require_deepspeed
 @require_torch_accelerator
 class TestDeepSpeedWithLauncher(TestCasePlus):
@@ -1126,6 +1137,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
 def test_basic_distributed(self, stage, dtype):
 self.run_and_check(stage=stage, dtype=dtype, distributed=True)
 
+@require_torch_fp16
 def test_do_eval_no_train(self):
 # testing only zero3 since zero2 makes no sense with inference
 self.run_and_check(
@@ -1199,12 +1211,15 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
 if dtype == "bf16" and not is_torch_bf16_available_on_device(torch_device):
 self.skipTest(reason="test requires bfloat16 hardware support")
 
+if dtype == "fp16" and not is_torch_fp16_available_on_device(torch_device):
+self.skipTest(reason="test requires fp16 hardware support")
+
 # this is just inference, so no optimizer should be loaded
 # it only works for z3 (makes no sense with z1-z2)
 fp32 = True if dtype == "fp32" else False
 self.run_and_check(
 stage=ZERO3,
-dtype=FP16,
+dtype=dtype,
 model_name=T5_TINY,
 distributed=True,
 do_train=False,
@@ -1381,6 +1396,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
 # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
 execute_subprocess_async(cmd, env=self.get_env())
 
+@require_torch_fp16
 def test_clm_from_config_zero3_fp16(self):
 # this test exercises AutoModel.from_config(config) - to ensure zero.Init is called
@@ -33,12 +33,17 @@ from transformers.testing_utils import (
 require_fsdp,
 require_torch_accelerator,
 require_torch_multi_accelerator,
+run_first,
 slow,
 torch_device,
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import FSDPOption, set_seed
-from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device
+from transformers.utils import (
+is_accelerate_available,
+is_torch_bf16_available_on_device,
+is_torch_fp16_available_on_device,
+)
 
 
 if is_torch_available():
@@ -49,14 +54,19 @@ else:
 
 # default torch.distributed port
 DEFAULT_MASTER_PORT = "10999"
-dtypes = ["fp16"]
+
+dtypes = []
 if is_torch_bf16_available_on_device(torch_device):
 dtypes += ["bf16"]
+if is_torch_fp16_available_on_device(torch_device):
+dtypes += ["fp16"]
 
 sharding_strategies = ["full_shard", "shard_grad_op"]
 state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
-set_seed(42)
+
 params = list(itertools.product(sharding_strategies, dtypes))
 
+set_seed(42)
+
 
 def get_master_port(real_launcher=False):
 """
@@ -140,13 +150,13 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
 }
 
 self.fsdp_config = {
-"backward_prefetch": "backward_pre",
+"backward_prefetch": "BACKWARD_PRE",
-"forward_prefetch": "False",
+"forward_prefetch": "false",
-"limit_all_gathers": "False",
+"limit_all_gathers": "false",
-"use_orig_params": "True",
+"use_orig_params": "true",
-"sync_module_states": "True",
+"sync_module_states": "true",
-"cpu_ram_efficient_loading": "True",
+"cpu_ram_efficient_loading": "true",
-"activation_checkpointing": "False",
+"activation_checkpointing": "false",
 "min_num_params": 1,
 }
 
@@ -202,7 +212,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
 self.assertEqual(
 os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"], ",".join(fsdp_config["transformer_layer_cls_to_wrap"])
 )
-self.assertEqual(os.environ[f"{prefix}BACKWARD_PREFETCH"], fsdp_config["backward_prefetch"].upper())
+self.assertEqual(os.environ[f"{prefix}BACKWARD_PREFETCH"], fsdp_config["backward_prefetch"])
 self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"])
 self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"])
 self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"])
@@ -213,6 +223,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
 
 @parameterized.expand(params, name_func=_parameterized_custom_name_func)
 @require_torch_multi_accelerator
+@run_first
 @slow
 def test_basic_run(self, sharding_strategy, dtype):
 launcher = get_launcher(distributed=True, use_accelerate=False)
@@ -225,6 +236,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
 
 @parameterized.expand(params, name_func=_parameterized_custom_name_func)
 @require_torch_multi_accelerator
+@run_first
 @slow
 def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype):
 launcher = get_launcher(distributed=True, use_accelerate=False)
@@ -237,6 +249,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
 
 @parameterized.expand(dtypes)
 @require_torch_multi_accelerator
+@run_first
 @slow
 @unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
 def test_basic_run_with_cpu_offload(self, dtype):
@@ -250,6 +263,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
 
 @parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
 @require_torch_multi_accelerator
+@run_first
 @slow
 def test_training_and_can_resume_normally(self, state_dict_type):
 output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
@@ -286,10 +300,13 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
 self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)
 
 @require_torch_multi_accelerator
+@run_first
 @slow
-@require_torch_accelerator
-@require_fsdp
 def test_fsdp_cpu_offloading(self):
+# TODO: This file is missing and should be added or the test should be removed
+if not os.path.exists("utils/testing_scripts/fsdp_cpu_offloading.py"):
+raise unittest.SkipTest("FSDP CPU offloading script not found!")
+
 try:
 subprocess.run(
 "accelerate launch utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
@@ -2770,7 +2770,7 @@ class ModelTesterMixin:
 elif param_device in ["mps"]:
 self.assertEqual(param.device, torch.device("mps"))
 else:
-# when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu
+# when loaded with device_map, `param_device` are integer values for cuda/xpu/hpu/npu/mlu
 self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))
 
 @require_accelerate
@@ -75,6 +75,7 @@ from transformers.testing_utils import (
 require_intel_extension_for_pytorch,
 require_liger_kernel,
 require_lomo,
+require_non_hpu,
 require_non_xpu,
 require_optuna,
 require_peft,
@@ -88,6 +89,7 @@ from transformers.testing_utils import (
 require_torch,
 require_torch_accelerator,
 require_torch_bf16,
+require_torch_fp16,
 require_torch_gpu,
 require_torch_multi_accelerator,
 require_torch_non_multi_accelerator,
@@ -98,6 +100,7 @@ from transformers.testing_utils import (
 require_torchdynamo,
 require_vision,
 require_wandb,
+run_first,
 run_test_using_subprocess,
 slow,
 torch_device,
@@ -119,6 +122,13 @@ from transformers.utils import (
 from transformers.utils.hp_naming import TrialShortNamer
 
 
+if torch_device == "hpu":
+RTOL = 1e-3
+ATOL = 1e-3
+else:
+RTOL = 1e-5
+ATOL = 1e-5
+
 if is_torch_available():
 import torch
 from torch import nn
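A small sketch of what the RTOL/ATOL constants feed into: torch.testing.assert_close accepts rtol/atol overrides, which check_trained_model forwards via **kwargs in the next hunk. The tensors here are made-up values.

    import torch

    a = torch.tensor([1.0000, 2.0000])
    b = torch.tensor([1.0005, 2.0005])

    # Passes with the relaxed HPU tolerances; the float32 defaults (rtol=1.3e-6, atol=1e-5)
    # would reject a 5e-4 drift.
    torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)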
@@ -726,11 +736,11 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
 trainer.train()
 self.alternate_trained_model = (trainer.model.a, trainer.model.b)
 
-def check_trained_model(self, model, alternate_seed=False):
+def check_trained_model(self, model, alternate_seed=False, **kwargs):
 # Checks a training seeded with learning_rate = 0.1
 (a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model
-torch.testing.assert_close(model.a, a)
+torch.testing.assert_close(model.a, a, **kwargs)
-torch.testing.assert_close(model.b, b)
+torch.testing.assert_close(model.b, b, **kwargs)
 
 def test_reproducible_training(self):
 # Checks that training worked, model trained and seed made a reproducible training.
@@ -812,11 +822,6 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
 
 data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
-model = AutoModelForCausalLM.from_pretrained(model_name)
-state_dict = model.state_dict()
-
-base_loss_callback = StoreLossCallback()
-
 args_kwargs = {
 "report_to": "none",
 "logging_steps": 1,
@@ -830,6 +835,10 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
 tmp_dir,
 **args_kwargs,
 )
+# train with base loss
+set_seed(42)
+model = AutoModelForCausalLM.from_pretrained(model_name)
+base_loss_callback = StoreLossCallback()
 trainer = Trainer(
 model,
 args,
@@ -840,16 +849,17 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
 assert trainer.model_accepts_loss_kwargs
 trainer.train()
 
-grad_accum_loss_callback = StoreLossCallback()
-with tempfile.TemporaryDirectory() as tmp_dir:
 args = TrainingArguments(
 tmp_dir,
 **args_kwargs,
 gradient_accumulation_steps=2,
 per_device_train_batch_size=4,
 )
 
+# train with gradient accumulation
 set_seed(42)
 model = AutoModelForCausalLM.from_pretrained(model_name)
+grad_accum_loss_callback = StoreLossCallback()
 trainer = Trainer(
 model,
 args,
@@ -857,10 +867,12 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
 callbacks=[grad_accum_loss_callback],
 data_collator=data_collator,
 )
+assert trainer.model_accepts_loss_kwargs
 trainer.train()
 
+# train with broken loss
 set_seed(42)
-model.load_state_dict(state_dict)
+model = AutoModelForCausalLM.from_pretrained(model_name)
 broken_loss_callback = StoreLossCallback()
 trainer = Trainer(
 model,
@@ -869,30 +881,28 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
 callbacks=[broken_loss_callback],
 data_collator=data_collator,
 )
-# disable model_accepts_loss_kwargs
+# disable model_accepts_loss_kwargs so that "num_items_in_batch" is not passed to the model
 trainer.model_accepts_loss_kwargs = False
 trainer.train()
 
 # Calculate the difference between the base loss and the grad_accum loss
 diff_truth = [
 abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
 ]
-diff_broken = [
-abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)
-]
+diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
 
 # all diff truth should be quite close
 self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
 
 # max diff broken should be very off
-self.assertGreater(max(diff_broken), 1.5, f"Difference {max(diff_broken)} is not greater than 2")
+self.assertGreater(max(diff_broken), 1.3, f"Difference {max(diff_broken)} is not greater than 1.3")
 
 loss_base = sum(base_loss_callback.losses)
 loss_broken = sum(broken_loss_callback.losses)
 
 # mean/sum loss should not vary too much.
 relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken)
 self.assertLess(relative_diff, 0.2, f"Relative difference {relative_diff} is not within 0.2")
 
 def test_gradient_accumulation_loss_alignment_with_loss_func(self):
 set_seed(42)
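A toy numerical illustration (assumed values, not from the test) of what this test checks: per-micro-batch losses must be re-normalized by the total item count of the full batch (num_items_in_batch) for gradient accumulation to match no-accumulation training; averaging per-micro-batch means, the "broken" path, drifts away.

    import torch

    summed_losses = torch.tensor([2.0, 4.0])   # summed token losses of two micro-batches
    token_counts = torch.tensor([10.0, 30.0])  # tokens per micro-batch

    no_accum = summed_losses.sum() / token_counts.sum()   # 0.1500, what full-batch training sees
    aligned = (summed_losses / token_counts.sum()).sum()  # 0.1500, normalized per accumulation step
    broken = (summed_losses / token_counts).mean()        # 0.1667, mean of per-micro-batch means

    print(no_accum.item(), aligned.item(), broken.item())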
@@ -1214,14 +1224,14 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertFalse(torch.allclose(trainer.model.b, b))
         self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
 
-    @require_torch_accelerator
     @require_torch_bf16
+    @require_torch_accelerator
     def test_mixed_bf16(self):
         # very basic test
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = get_regression_trainer(learning_rate=0.1, bf16=True, output_dir=tmp_dir)
             trainer.train()
-            self.check_trained_model(trainer.model)
+            self.check_trained_model(trainer.model, atol=ATOL, rtol=RTOL)
 
         # --bf16 --half_precision_backend apex can't be used together
         with tempfile.TemporaryDirectory() as tmp_dir:
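check_trained_model now takes explicit tolerances. ATOL and RTOL are presumably module-level constants chosen so the bf16 regression check also passes on accelerators whose low-precision numerics differ slightly from CUDA. A minimal sketch of a tolerance-aware weight check, with hypothetical expected_a/expected_b reference tensors; the real helper in the test suite may be written differently:

import torch

ATOL, RTOL = 1e-5, 1e-5  # assumed defaults; a bf16/HPU run would use looser values

def check_trained_model(model, expected_a, expected_b, atol=ATOL, rtol=RTOL):
    # Compare the two learned parameters of the toy regression model against
    # reference values, allowing backend-dependent floating point drift.
    torch.testing.assert_close(model.a, expected_a, atol=atol, rtol=rtol)
    torch.testing.assert_close(model.b, expected_b, atol=atol, rtol=rtol)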
@@ -3582,6 +3592,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         )
 
     @slow
+    @run_first
     def test_trainer_eval_mrpc(self):
         MODEL_ID = "google-bert/bert-base-cased-finetuned-mrpc"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -3598,6 +3609,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertLess(result["eval_loss"], 0.2)
 
     @slow
+    @run_first
     def test_trainer_eval_multiple(self):
         MODEL_ID = "openai-community/gpt2"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
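@run_first is imported from transformers.testing_utils (see the FSDP test imports further down) and moves these slow integration tests to the front of a pytest session, so the expensive accelerator jobs fail fast instead of queuing behind the rest of the suite. A minimal sketch of how such a marker can be built on top of the pytest-order plugin (an assumption here); the actual helper may differ:

import pytest

def run_first(test_case):
    """Schedule the decorated test at the start of the session via pytest-order."""
    return pytest.mark.order("first")(test_case)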
@@ -3897,6 +3909,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer = get_regression_trainer(skip_memory_metrics=True, output_dir=tmp_dir)
             self.check_mem_metrics(trainer, self.assertNotIn)
 
+    @require_torch_fp16
     @require_torch_accelerator
     def test_fp16_full_eval(self):
         # this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
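@require_torch_fp16 gates the fp16 full-eval test on actual float16 support of the active backend rather than on the backend type. A hedged sketch of what such a capability probe can look like, in the style of the other require_* helpers; the real implementation in transformers.testing_utils may probe the device differently:

import unittest

import torch

from transformers.testing_utils import torch_device


def _fp16_matmul_works(device: str) -> bool:
    # A small fp16 matmul is a cheap smoke test for float16 kernel support.
    try:
        x = torch.ones(2, 2, dtype=torch.float16, device=device)
        torch.matmul(x, x)
        return True
    except Exception:
        return False


def require_torch_fp16(test_case):
    """Skip the test unless the active accelerator can run float16 kernels."""
    return unittest.skipUnless(
        _fp16_matmul_works(torch_device), "test requires a device with fp16 support"
    )(test_case)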
@@ -4152,6 +4165,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
 
     @slow
+    @require_non_hpu
     @require_torch_multi_accelerator
     def test_end_to_end_example(self):
         # Tests that `translation.py` will run without issues
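@require_non_hpu excludes the multi-accelerator end-to-end translation example when the suite runs on Gaudi/HPU hardware. A hedged sketch of such a decorator, assuming it is built on an is_torch_hpu_available() check like the other backend guards; the actual helper may differ. The hunks that follow appear to come from the torchrun-launched distributed Trainer test script (the file referenced as test_trainer_distributed.py in the commands below).

import unittest

from transformers import is_torch_hpu_available


def require_non_hpu(test_case):
    """Skip the decorated test when an Intel Gaudi (HPU) backend is available."""
    return unittest.skipIf(is_torch_hpu_available(), "test requires a non-HPU backend")(test_case)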
@@ -19,12 +19,11 @@ import numpy as np
 from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
     get_torch_dist_unique_port,
-    require_torch_multi_gpu,
-    require_torch_multi_xpu,
-    require_torch_neuroncore,
-    require_torch_npu,
+    require_torch_multi_accelerator,
+    torch_device,
 )
 from transformers.training_args import ParallelMode
 from transformers.utils import logging
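The distributed test drops the CUDA-only torch.cuda.device_count() and the per-backend require_* decorators in favour of the device-agnostic torch_device string and backend_device_count helper, so one test class covers CUDA, XPU, NPU and HPU machines alike. A hedged sketch of the kind of dispatch such a helper performs; the real backend_device_count in transformers.testing_utils may be organised differently:

import torch


def backend_device_count(device: str) -> int:
    # Number of visible devices for the given torch backend name.
    if device == "cuda":
        return torch.cuda.device_count()
    if device == "xpu" and hasattr(torch, "xpu"):
        return torch.xpu.device_count()
    if device == "npu" and hasattr(torch, "npu"):
        return torch.npu.device_count()
    if device == "hpu" and hasattr(torch, "hpu"):
        return torch.hpu.device_count()
    return 1  # cpu or unknown backend: treat as a single device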
@@ -117,38 +116,10 @@ if is_torch_available():
         return result
 
 
-class TestTrainerDistributedNeuronCore(TestCasePlus):
-    @require_torch_neuroncore
-    def test_trainer(self):
-        distributed_args = f"""--nproc_per_node=2
-            --master_port={get_torch_dist_unique_port()}
-            {self.test_file_dir}/test_trainer_distributed.py
-        """.split()
-        output_dir = self.get_auto_remove_tmp_dir()
-        args = f"--output_dir {output_dir}".split()
-        cmd = ["torchrun"] + distributed_args + args
-        execute_subprocess_async(cmd, env=self.get_env())
-        # successful return here == success - any errors would have caused an error in the sub-call
-
-
-class TestTrainerDistributedNPU(TestCasePlus):
-    @require_torch_npu
-    def test_trainer(self):
-        distributed_args = f"""--nproc_per_node=2
-            --master_port={get_torch_dist_unique_port()}
-            {self.test_file_dir}/test_trainer_distributed.py
-        """.split()
-        output_dir = self.get_auto_remove_tmp_dir()
-        args = f"--output_dir {output_dir}".split()
-        cmd = ["torchrun"] + distributed_args + args
-        execute_subprocess_async(cmd, env=self.get_env())
-        # successful return here == success - any errors would have caused an error in the sub-call
-
-
 class TestTrainerDistributed(TestCasePlus):
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_trainer(self):
-        distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
+        distributed_args = f"""--nproc_per_node={backend_device_count(torch_device)}
             --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_trainer_distributed.py
         """.split()
@@ -159,20 +130,6 @@ class TestTrainerDistributed(TestCasePlus):
         # successful return here == success - any errors would have caused an error in the sub-call
 
 
-@require_torch_multi_xpu
-class TestTrainerDistributedXPU(TestCasePlus):
-    def test_trainer(self):
-        distributed_args = f"""--nproc_per_node={torch.xpu.device_count()}
-            --master_port={get_torch_dist_unique_port()}
-            {self.test_file_dir}/test_trainer_distributed.py
-        """.split()
-        output_dir = self.get_auto_remove_tmp_dir()
-        args = f"--output_dir {output_dir}".split()
-        cmd = ["torchrun"] + distributed_args + args
-        execute_subprocess_async(cmd, env=self.get_env())
-        # successful return here == success - any errors would have caused an error in the sub-call
-
-
 if __name__ == "__main__":
     # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
     #
@@ -17,12 +17,15 @@ from typing import Dict
 from transformers import is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
     get_torch_dist_unique_port,
     require_accelerate,
     require_fp8,
     require_fsdp,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
+    run_first,
+    torch_device,
 )
 
 
@@ -64,9 +67,10 @@
 
 
 class TestFSDPTrainer(TestCasePlus):
+    @require_torch_multi_accelerator
     @require_accelerate
-    @require_torch_multi_gpu
     @require_fsdp
+    @run_first
     def test_trainer(self):
         output_dir = self.get_auto_remove_tmp_dir()
         cmd = [
@@ -76,7 +80,7 @@ class TestFSDPTrainer(TestCasePlus):
             "--main_process_port",
             f"{get_torch_dist_unique_port()}",
             "--num_processes",
-            f"{torch.cuda.device_count()}",
+            f"{backend_device_count(torch_device)}",
             "--fsdp_transformer_layer_cls_to_wrap",
             "GPT2Block",
             f"{self.test_file_dir}/test_trainer_fsdp.py",
@@ -90,10 +94,11 @@ class TestFSDPTrainer(TestCasePlus):
 
 
 class TestFSDPTrainerFP8(TestCasePlus):
+    @require_torch_multi_accelerator
     @require_accelerate
-    @require_torch_multi_gpu
     @require_fsdp
     @require_fp8
+    @run_first
     def test_trainer(self):
         output_dir = self.get_auto_remove_tmp_dir()
         cmd = [
@@ -103,7 +108,7 @@ class TestFSDPTrainerFP8(TestCasePlus):
             "--main_process_port",
             f"{get_torch_dist_unique_port()}",
             "--num_processes",
-            f"{torch.cuda.device_count()}",
+            f"{backend_device_count(torch_device)}",
             "--mixed_precision",
             "fp8",
             "--fsdp_transformer_layer_cls_to_wrap",
@@ -117,32 +122,34 @@ class TestFSDPTrainerFP8(TestCasePlus):
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call
 
+
 class TestFSDPTrainerWrap(TestCasePlus):
+    @require_torch_multi_accelerator
     @require_accelerate
-    @require_torch_multi_gpu
     @require_fsdp
+    @run_first
     def test_trainer(self):
         output_dir = self.get_auto_remove_tmp_dir()
         cmd = [
             "accelerate",
             "launch",
             "--use_fsdp",
             "--main_process_port",
             f"{get_torch_dist_unique_port()}",
             "--num_processes",
-            f"{torch.cuda.device_count()}",
+            f"{backend_device_count(torch_device)}",
             "--fsdp_transformer_layer_cls_to_wrap",
             "GPT2Block",
             f"{self.test_file_dir}/test_trainer_fsdp.py",
             "--output_dir",
             f"{output_dir}",
             "--report_to",
             "none",
             "--auto_find_batch_size",
             "True",
         ]
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call
 
 
 if __name__ == "__main__":