Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[testing] ensure concurrent pytest workers use a unique port for torch.dist (#12166)
* ensure concurrent pytest workers use a unique port for torch.distributed.launch
* reword
This commit is contained in:
parent b9d66f4c4b
commit 6e7cc5cc51
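For context on why the change is needed: when several pytest-xdist workers spawn ``torch.distributed.launch`` at the same time, they all try to use the launcher's default ``--master_port`` (29500, the same base the new helper starts from), and only one of them can bind it. A minimal standalone sketch of that collision, using plain sockets rather than the launcher (not part of the commit):

# Standalone sketch: two processes cannot bind the same TCP port at once,
# which is what happens when two pytest-xdist workers both let
# torch.distributed.launch fall back to --master_port=29500.
import socket


def try_bind(port):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("127.0.0.1", port))
        print(f"port {port}: bound ok")
        return sock  # keep the socket open so the port stays taken
    except OSError as exc:
        print(f"port {port}: {exc}")
        return None


worker0 = try_bind(29500)  # stands in for xdist worker gw0's launcher
worker1 = try_bind(29500)  # gw1 on the same port -> "Address already in use"
worker1 = try_bind(29501)  # gw1 with a unique per-worker port -> bound ok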
@@ -1249,6 +1249,28 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False
     return result
 
 
+def pytest_xdist_worker_id():
+    """
+    Returns an int value of worker's numerical id under ``pytest-xdist``'s concurrent workers ``pytest -n N`` regime,
+    or 0 if ``-n 1`` or ``pytest-xdist`` isn't being used.
+    """
+    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
+    worker = re.sub(r"^gw", "", worker, 0, re.M)
+    return int(worker)
+
+
+def get_torch_dist_unique_port():
+    """
+    Returns a port number that can be fed to ``torch.distributed.launch``'s ``--master_port`` argument.
+
+    Under ``pytest-xdist`` it adds a delta number based on a worker id so that concurrent tests don't try to use the
+    same port at once.
+    """
+    port = 29500
+    uniq_delta = pytest_xdist_worker_id()
+    return port + uniq_delta
+
+
 def nested_simplify(obj, decimals=3):
     """
     Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
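Assuming the two helpers above land in ``transformers.testing_utils`` (as the imports in the test files below suggest), a quick sketch of what they return for a few simulated ``PYTEST_XDIST_WORKER`` values (``pytest-xdist`` normally sets this variable itself):

# Illustrative only: simulate the PYTEST_XDIST_WORKER values that pytest-xdist
# assigns to its workers and print the derived worker id and master port.
import os

from transformers.testing_utils import get_torch_dist_unique_port, pytest_xdist_worker_id

for worker in ("gw0", "gw1", "gw3"):
    os.environ["PYTEST_XDIST_WORKER"] = worker
    print(worker, pytest_xdist_worker_id(), get_torch_dist_unique_port())
# expected output:
# gw0 0 29500
# gw1 1 29501
# gw3 3 29503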
@@ -25,6 +25,7 @@ from transformers.testing_utils import (
     TestCasePlus,
     execute_subprocess_async,
     get_gpu_count,
+    get_torch_dist_unique_port,
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torch_non_multi_gpu,
@@ -223,9 +224,11 @@ class TestTrainerExt(TestCasePlus):
 
         if distributed:
             n_gpu = get_gpu_count()
+            master_port = get_torch_dist_unique_port()
             distributed_args = f"""
                 -m torch.distributed.launch
                 --nproc_per_node={n_gpu}
+                --master_port={master_port}
                 {self.examples_dir_str}/pytorch/translation/run_translation.py
             """.split()
             cmd = [sys.executable] + distributed_args + args
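If the triple-quoted f-string idiom above looks unusual: ``.split()`` with no arguments splits on any whitespace, so the block collapses into a flat argument list. A small sketch with stand-in values (``n_gpu``, ``master_port`` and the examples path are made up here; in the test they come from ``get_gpu_count()``, ``get_torch_dist_unique_port()`` and ``self.examples_dir_str``):

# Stand-in values, for illustration only.
n_gpu = 2
master_port = 29501
examples_dir_str = "examples"

distributed_args = f"""
    -m torch.distributed.launch
    --nproc_per_node={n_gpu}
    --master_port={master_port}
    {examples_dir_str}/pytorch/translation/run_translation.py
""".split()

print(distributed_args)
# ['-m', 'torch.distributed.launch', '--nproc_per_node=2',
#  '--master_port=29501', 'examples/pytorch/translation/run_translation.py']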
@@ -16,7 +16,12 @@ import sys
 from typing import Dict
 
 from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
-from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    get_torch_dist_unique_port,
+    require_torch_multi_gpu,
+)
 from transformers.utils import logging
 
 
@@ -64,6 +69,7 @@ class TestTrainerDistributed(TestCasePlus):
         distributed_args = f"""
             -m torch.distributed.launch
             --nproc_per_node={torch.cuda.device_count()}
+            --master_port={get_torch_dist_unique_port()}
             {self.test_file_dir}/test_trainer_distributed.py
         """.split()
         output_dir = self.get_auto_remove_tmp_dir()