Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
enable trainer test cases on xpu (#38138)
* enable trainer test cases on xpu

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
parent b11b28cc4e
commit 7caa57e85e
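The pattern throughout this commit is mechanical: CUDA-only test constructs (torch.cuda.device_count(), device="cuda", @require_torch_multi_gpu) are swapped for the accelerator-agnostic equivalents from transformers.testing_utils (backend_device_count(torch_device), device=torch_device, @require_torch_multi_accelerator), so the same tests run on Intel XPU as well as CUDA GPUs. Below is a minimal sketch of what such device-agnostic helpers can look like, assuming a recent PyTorch with the torch.xpu backend; the detection chain and helper bodies are illustrative assumptions, not the library's exact implementation.

# Sketch of device-agnostic test helpers. The real torch_device and
# backend_device_count live in transformers.testing_utils and differ in
# detail; this detection chain is an illustrative assumption.
import torch

if torch.cuda.is_available():
    torch_device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device = "xpu"
else:
    torch_device = "cpu"


def backend_device_count(device: str) -> int:
    # Dispatch to the matching backend's device_count(); CPU counts as one.
    if device == "cuda":
        return torch.cuda.device_count()
    if device == "xpu":
        return torch.xpu.device_count()
    return 1


# Test code can then stay backend-neutral:
local_tensor = torch.tensor([1.0, 2.0, 3.0], device=torch_device)
print(f"{torch_device}: {backend_device_count(torch_device)} device(s)")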
@@ -2039,7 +2039,7 @@ class TestCasePlus(unittest.TestCase):
         """
         env = os.environ.copy()
-        paths = [self.src_dir_str]
+        paths = [self.repo_root_dir_str, self.src_dir_str]
         if "/examples" in self.test_file_dir_str:
             paths.append(self.examples_dir_str)
         else:
@@ -97,7 +97,6 @@ from transformers.testing_utils import (
     require_torch_fp16,
     require_torch_gpu,
     require_torch_multi_accelerator,
-    require_torch_multi_gpu,
     require_torch_non_multi_accelerator,
     require_torch_non_multi_gpu,
     require_torch_tensorrt_fx,
@@ -3766,7 +3765,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, int(self.n_epochs))
 
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_num_batches_in_training_with_gradient_accumulation(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             for num_train_epochs in [1, 2]:
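The @require_torch_multi_gpu to @require_torch_multi_accelerator swap above is what actually gates these tests on XPU machines. A sketch of how such a skip decorator can be built, assuming unittest-style skips; the real decorator in transformers.testing_utils may differ in detail.

# Sketch of a multi-accelerator skip decorator; _accelerator_count is a
# hypothetical helper, not part of the library.
import unittest

import torch


def _accelerator_count() -> int:
    # Count CUDA devices, else XPU devices, else report none.
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.device_count()
    return 0


def require_torch_multi_accelerator(test_case):
    # Skip the decorated test unless more than one accelerator is visible,
    # regardless of whether the backend is CUDA or XPU.
    return unittest.skipUnless(
        _accelerator_count() > 1, "test requires multiple accelerators"
    )(test_case)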
@@ -1,7 +1,6 @@
 import json
 
 import datasets
-import torch
 
 from tests.trainer.test_trainer import StoreLossCallback
 from transformers import (
@@ -15,16 +14,18 @@ from transformers import (
 )
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
     get_torch_dist_unique_port,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
+    torch_device,
 )
 
 
 class TestTrainerDistributedLoss(TestCasePlus):
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_trainer(self):
-        device_count = torch.cuda.device_count()
+        device_count = backend_device_count(torch_device)
         min_bs = 1
         output_dir = self.get_auto_remove_tmp_dir()
         for gpu_num, enable, bs, name in (
@@ -14,9 +14,11 @@ from transformers import (
 )
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
     get_torch_dist_unique_port,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
+    torch_device,
 )
 
 
@@ -47,7 +49,7 @@ class DummyModel(nn.Module):
         self.fc = nn.Linear(3, 1)
 
     def forward(self, x):
-        local_tensor = torch.tensor(x, device="cuda")
+        local_tensor = torch.tensor(x, device=torch_device)
         gathered = gather_from_all_gpus(local_tensor, dist.get_world_size())
         assert not all(torch.allclose(t, gathered[0]) for t in gathered[1:])
         y = self.fc(x)
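The forward above gathers each rank's input tensor and asserts the copies are not all identical, i.e. that every worker really received a different seed; the only change needed for XPU is building the tensor on torch_device instead of a hard-coded "cuda". A sketch of an all-gather helper like gather_from_all_gpus, assuming an initialized torch.distributed process group; the helper actually defined in the test file may differ.

# Sketch of an all-gather helper; assumes dist.init_process_group() has
# already run on every rank.
import torch
import torch.distributed as dist


def gather_from_all_gpus(local_tensor: torch.Tensor, world_size: int) -> list[torch.Tensor]:
    # all_gather fills the list with every rank's copy of local_tensor,
    # so each rank ends up seeing all ranks' tensors.
    gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
    dist.all_gather(gathered, local_tensor)
    return gathered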
@@ -55,9 +57,9 @@
 
 
 class TestTrainerDistributedWorkerSeed(TestCasePlus):
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_trainer(self):
-        device_count = torch.cuda.device_count()
+        device_count = backend_device_count(torch_device)
         output_dir = self.get_auto_remove_tmp_dir()
         distributed_args = f"""--nproc_per_node={device_count}
             --master_port={get_torch_dist_unique_port()}