mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support XPU in --nproc-per-node option to torchrun (#159474)
Support both --nproc-per-node=xpu and autodetection of XPU device in case of --nproc-per-node=auto Pull Request resolved: https://github.com/pytorch/pytorch/pull/159474 Approved by: https://github.com/tsocha, https://github.com/guangyey, https://github.com/d4l3k Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
972e409829
commit
66c0f14ecc
@ -273,10 +273,29 @@ class ElasticLaunchTest(TestCase):
|
||||
)
|
||||
@patch("torch.cuda.is_available", return_value=True)
|
||||
@patch("torch.cuda.device_count", return_value=3)
|
||||
def test_nproc_gpu_launch_configurations(self, _mock1, _mock2):
|
||||
@patch("torch.accelerator.is_available", return_value=True)
|
||||
@patch("torch.accelerator.device_count", return_value=3)
|
||||
@patch("torch.accelerator.current_accelerator", return_value=MagicMock(type="gpu"))
|
||||
def test_nproc_gpu_launch_configurations(
|
||||
self, _mock1, _mock2, _mock3, _mock4, _mock5
|
||||
):
|
||||
self._test_nproc_launch_configuration("auto", 3)
|
||||
self._test_nproc_launch_configuration("gpu", 3)
|
||||
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
|
||||
)
|
||||
@patch("torch.xpu.is_available", return_value=True)
|
||||
@patch("torch.xpu.device_count", return_value=3)
|
||||
@patch("torch.accelerator.is_available", return_value=True)
|
||||
@patch("torch.accelerator.device_count", return_value=3)
|
||||
@patch("torch.accelerator.current_accelerator", return_value=MagicMock(type="xpu"))
|
||||
def test_nproc_xpu_launch_configurations(
|
||||
self, _mock1, _mock2, _mock3, _mock4, _mock5
|
||||
):
|
||||
self._test_nproc_launch_configuration("auto", 3)
|
||||
self._test_nproc_launch_configuration("xpu", 3)
|
||||
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
|
||||
)
|
||||
|
@ -77,7 +77,9 @@ Single-node multi-worker
|
||||
.. note:: ``--nproc-per-node`` may be
|
||||
``"gpu"`` (spawn one process per GPU),
|
||||
``"cpu"`` (spawn one process per CPU),
|
||||
``"xpu"`` (spawn one process per XPU),
|
||||
``"auto"`` (equivalent to ``"gpu"`` if CUDA is available,
|
||||
else equivalent to ``"xpu"`` if XPU is available,
|
||||
else equivalent to ``"cpu"``),
|
||||
or an integer specifying the number of processes.
|
||||
See `torch.distributed.run.determine_local_world_size
|
||||
@ -413,7 +415,7 @@ def get_args_parser() -> ArgumentParser:
|
||||
action=env,
|
||||
type=str,
|
||||
default="1",
|
||||
help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
|
||||
help="Number of workers per node; supported values: [auto, cpu, gpu, xpu, int].",
|
||||
)
|
||||
|
||||
#
|
||||
@ -705,21 +707,20 @@ def determine_local_world_size(nproc_per_node: str):
|
||||
raise ValueError("Cuda is not available.") from e
|
||||
device_type = "gpu"
|
||||
num_proc = torch.cuda.device_count()
|
||||
elif nproc_per_node == "xpu":
|
||||
if not torch.xpu.is_available():
|
||||
raise ValueError("Xpu is not available.") from e
|
||||
device_type = "xpu"
|
||||
num_proc = torch.xpu.device_count()
|
||||
elif nproc_per_node == torch._C._get_privateuse1_backend_name():
|
||||
if not _get_custom_mod_func("is_available")():
|
||||
raise ValueError(f"{nproc_per_node} is not available.") from e
|
||||
device_type = nproc_per_node
|
||||
num_proc = _get_custom_mod_func("device_count")()
|
||||
elif nproc_per_node == "auto":
|
||||
if torch.cuda.is_available():
|
||||
num_proc = torch.cuda.device_count()
|
||||
device_type = "gpu"
|
||||
elif (
|
||||
hasattr(torch, torch._C._get_privateuse1_backend_name())
|
||||
and _get_custom_mod_func("is_available")()
|
||||
):
|
||||
num_proc = _get_custom_mod_func("device_count")()
|
||||
device_type = torch._C._get_privateuse1_backend_name()
|
||||
if torch.accelerator.is_available():
|
||||
num_proc = torch.accelerator.device_count()
|
||||
device_type = torch.accelerator.current_accelerator().type # type: ignore[union-attr]
|
||||
else:
|
||||
num_proc = os.cpu_count()
|
||||
device_type = "cpu"
|
||||
|
Reference in New Issue
Block a user