Support NUMA Binding for Callable Entrypoints (#160163)

# Context
This is an extension of #149334.

# This PR
Add support for NUMA bindings with `Callable` entrypoints (e.g., a `do_train` function), in addition to the already-supported `str` entrypoints such as `/usr/local/bin/python`.

Most notably, we rely on a hack to force `Process.start()` to apply custom NUMA bindings to each subprocess. Search for `HACK:` in the code for a description of the implementation we chose, and see #160006 for discussion of alternatives and why this is necessary.
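
The pattern, roughly: write the per-rank `numactl` invocation into a self-deleting wrapper script, point the spawn context's executable at that script just for the duration of `Process.start()`, then restore the original interpreter. Below is a self-contained, hypothetical sketch of that technique (not an excerpt of this PR); it assumes `numactl` is installed and that NUMA node 0 exists.

```python
import multiprocessing
import os
import stat
import sys
import tempfile


def make_numa_wrapper(numactl_options: tuple[str, ...]) -> str:
    """Write a self-deleting bash script that applies NUMA bindings and then
    runs the real Python interpreter with whatever args it was given."""
    fd, path = tempfile.mkstemp(prefix="numa-bind-", suffix=".sh")
    with os.fdopen(fd, "w") as f:
        f.write(
            "#!/bin/bash\n"
            'rm -- "$0"\n'  # the script deletes itself on first run
            f'numactl {" ".join(numactl_options)} {sys.executable} "$@"\n'
        )
    os.chmod(path, stat.S_IRWXU)
    return path


def worker(rank: int) -> None:
    print(f"rank {rank} running in pid {os.getpid()}")


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    for rank in range(2):
        # In the real code the bindings are computed per rank; node 0 is just
        # a placeholder here.
        wrapper = make_numa_wrapper(("--cpunodebind=0",))
        try:
            # Process.start() launches the child via the wrapper script, which
            # applies the bindings and then hands control to the interpreter.
            ctx.set_executable(wrapper)
            p = ctx.Process(target=worker, args=(rank,))
            p.start()
        finally:
            # Restore immediately so other ranks (and unrelated spawns) are
            # unaffected.
            ctx.set_executable(sys.executable)
        p.join()
```

The real implementation also refuses to combine this with parallel start (and with non-`spawn` start methods), since `set_executable` mutates shared state.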

Other changes:
* Remove the unnecessary `--preferred` option from all binding strategies. By default, Linux already allocates memory on the NUMA node local to the CPU that triggered the allocation (see [MPOL_LOCAL](https://man7.org/linux/man-pages/man2/set_mempolicy.2.html)).
* Refactor so that the main API is `maybe_wrap_command_with_numa_bindings`, which computes bindings for a single rank at a time, rather than `maybe_wrap_with_numa_bindings`, which computed bindings for all ranks at once. This allows more code sharing between `Callable` and `str` entrypoints (see the sketch after this list).
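
A rough usage sketch of the new per-rank API (signatures taken from the new `torch.numa.binding` module in this diff; the command, paths, and GPU index are illustrative):

```python
from torch.numa.binding import (
    AffinityMode,
    NumaOptions,
    maybe_get_temporary_python_executable_with_numa_bindings,
    maybe_wrap_command_with_numa_bindings,
)

numa_options = NumaOptions(affinity_mode=AffinityMode.NODE)

# str entrypoints: wrap the fully assembled command for one local rank.
wrapped = maybe_wrap_command_with_numa_bindings(
    command_args=("/usr/local/bin/python", "train.py"),
    gpu_index=3,  # the local rank / GPU this command should bind to
    numa_options=numa_options,
)
# -> e.g. ("numactl", "--cpunodebind=<node for GPU 3>", "/usr/local/bin/python", "train.py"),
#    or None when numa_options is None (or fallback is enabled and binding fails).

# Callable entrypoints: get a temporary "python" wrapper that applies the
# bindings before handing off to the real interpreter (see the HACK above).
wrapper_path = maybe_get_temporary_python_executable_with_numa_bindings(
    python_executable_path="/usr/local/bin/python",
    gpu_index=3,
    numa_options=numa_options,
)
```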

# Test Plan
## Automated
`$ pytest test/test_numa_binding.py`

## Manual
Using [this benchmark](https://gist.github.com/pdesupinski/bbe01ade455d86e989794f2c612e2d91), I ran

```
$ PYTHONUNBUFFERED=1 LOGLEVEL=INFO perf stat -e ls_dmnd_fills_from_sys.dram_io_far,ls_dmnd_fills_from_sys.dram_io_near -- python -m torch.distributed.run --standalone --nproc-per-node=8 --numa-binding=node --run-path mlp_train.py 2>&1 | tee node_callable.txt && PYTHONUNBUFFERED=1 LOGLEVEL=INFO perf stat -e ls_dmnd_fills_from_sys.dram_io_far,ls_dmnd_fills_from_sys.dram_io_near -- python -u -m torch.distributed.run --standalone --nproc-per-node=8 --run-path mlp_train.py 2>&1 | tee none_callable.txt
```

and observed
* 6.6% remote memory accesses with 'node' bindings
* 11.6% remote memory accesses without bindings
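
These percentages are presumably the far-DRAM share of demand fills from the two counters above, i.e. roughly:

```python
# Assumed derivation of the remote-access percentage from the two perf counters.
def remote_access_pct(dram_io_far: int, dram_io_near: int) -> float:
    return 100 * dram_io_far / (dram_io_far + dram_io_near)
```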

I also ran a similar comparison with `str` entrypoints, as before, to confirm they still work.

NOTE: [--run-path triggers the code to be run inside a `Callable`.](017259f9c6/torch/distributed/run.py (L870))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160163
Approved by: https://github.com/d4l3k
Paul de Supinski
2025-08-12 20:08:45 +00:00
committed by PyTorch MergeBot
parent 89654db1ab
commit 7e91394955
12 changed files with 431 additions and 222 deletions

View File

@ -3,8 +3,8 @@
NUMA Binding Utilities
======================
.. automodule:: torch.distributed.numa
.. automodule:: torch.numa
:members:
.. automodule:: torch.distributed.numa.binding
.. automodule:: torch.numa.binding
:members:

View File

@ -2,16 +2,19 @@
from __future__ import annotations
import multiprocessing.spawn as spawn
import os
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from typing import Any, Optional
from unittest import skipIf, skipUnless
from unittest import skipUnless
from unittest.mock import mock_open, patch
import torch
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes
from torch.distributed.numa.binding import (
from torch.numa.binding import (
_get_ranges_str_from_ints,
_get_set_of_int_from_ranges_str,
AffinityMode,
@ -35,12 +38,10 @@ class MockDeviceProperties:
_real_open = open
_real_mkstemp = tempfile.mkstemp
@skipIf(
sys.platform == "win32",
"Windows is missing various os module attributes like sched_getaffinity",
)
@skipUnless(sys.platform == "linux", "Only linux currently supported")
@skipUnless(
torch.distributed.is_available(), "Need access to some distributed submodules"
)
@ -53,26 +54,44 @@ class NumaBindingTest(TestCase):
self._mock_num_logical_cpus = 0
self._mock_num_numa_nodes = 0
self._mock_num_sockets = 0
self._temp_file_paths = []
self._context_managers_to_apply_to_all_tests = [
patch("torch.cuda.device_count", self._mock_device_count),
patch("torch.cuda.get_device_properties", self._mock_get_device_properties),
patch("torch.cuda.is_available", self._mock_is_available),
# Implicitly used by dynamo
patch("torch.cuda.get_rng_state"),
patch("builtins.open", new=self._mock_open),
patch("os.listdir", new=self._mock_listdir),
patch("os.sched_getaffinity", new=self._mock_sched_getaffinity),
patch("shutil.which", return_value="/usr/bin/numactl"),
patch("subprocess.run"),
patch("torch.numa.binding.run"),
patch("torch.numa.binding.mkstemp", self._mock_mkstemp),
]
for context_manager in self._context_managers_to_apply_to_all_tests:
context_manager.__enter__()
def tearDown(self) -> None:
# Clean up temporary files
for temp_file_path in self._temp_file_paths:
try:
os.unlink(temp_file_path)
except FileNotFoundError:
# File may have already been deleted or doesn't exist
pass
for context_manager in self._context_managers_to_apply_to_all_tests:
context_manager.__exit__(None, None, None)
super().tearDown()
def _mock_mkstemp(self, *args, **kwargs):
# Just keep track of temp files so we can delete them
fd, path = _real_mkstemp(*args, **kwargs)
self._temp_file_paths.append(path)
return fd, path
def _add_mock_hardware(
self,
*,
@ -204,7 +223,7 @@ class NumaBindingTest(TestCase):
def _mock_open(self, path: str, *args, **kwargs) -> Any:
if path in self._mock_file_path_to_contents:
return mock_open(read_data=self._mock_file_path_to_contents[path])()
if path.startswith("/sys/"):
if isinstance(path, str) and path.startswith("/sys/"):
raise FileNotFoundError(f"File {path} was not mocked.")
# Looks like CI is calling open and intending to open an actual file in some places.
# Need this to make the CI pass.
@ -222,8 +241,8 @@ class NumaBindingTest(TestCase):
def _mock_sched_getaffinity(self, pid: int) -> set[int]:
return set(range(self._mock_num_logical_cpus))
def _start_test_processes_and_get_command_args_for_local_rank(
self, *, numa_options: Optional[NumaOptions], local_rank: int
def _start_processes_for_str_entrypoint_and_get_Popen_args(
self, *, numa_options: Optional[NumaOptions], target_local_rank: int
) -> tuple[str, ...]:
"""
Calls start_processes like elastic_launch ultimately would
@ -250,10 +269,58 @@ class NumaBindingTest(TestCase):
call_args = next(
call_args
for call_args in mock_popen.call_args_list
if call_args.kwargs.get("env", {}).get("LOCAL_RANK") == str(local_rank)
if call_args.kwargs.get("env", {}).get("LOCAL_RANK")
== str(target_local_rank)
)
return call_args.kwargs["args"]
def _start_processes_for_callable_entrypoint_and_get_executable_contents(
self, *, numa_options: Optional[NumaOptions], target_local_rank: int
) -> str:
active_local_rank = None
executable_path = None
def _mock_process_start(self: Any) -> None:
nonlocal active_local_rank
active_local_rank = self._args[1]
spawn.get_command_line()
self._target(*self._args)
original_get_command_line = spawn.get_command_line
def _mock_get_command_line(*args, **kwargs) -> list[str]:
nonlocal executable_path
result = original_get_command_line(*args, **kwargs)
if active_local_rank == target_local_rank:
executable_path = result[0]
return result
with (
patch("multiprocessing.context.SpawnProcess.start", _mock_process_start),
patch("multiprocessing.spawn.get_command_line", _mock_get_command_line),
patch("multiprocessing.process.BaseProcess.sentinel", 1),
# Prevent hanging
patch(
"multiprocessing.synchronize.Event.wait",
lambda self, timeout=None: None,
),
):
start_processes(
name="test_process",
entrypoint=lambda x: x,
args=dict.fromkeys(range(self._mock_device_count()), (0,)),
envs={
i: {"LOCAL_RANK": str(i)} for i in range(self._mock_device_count())
},
logs_specs=DefaultLogsSpecs(),
numa_options=numa_options,
)
assert executable_path is not None
with open(executable_path) as executable_file:
return executable_file.read()
def test_node_numa_binding(self) -> None:
self._add_mock_hardware(
num_sockets=4,
@ -263,8 +330,9 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
command_args = self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=11
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=11,
)
self.assertEqual(
command_args,
@ -273,7 +341,6 @@ class NumaBindingTest(TestCase):
(
"numactl",
"--cpunodebind=5",
"--preferred=5",
"echo",
"Hello, world!",
),
@ -288,8 +355,8 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
command_args = self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=None, local_rank=11
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=None, target_local_rank=11
)
self.assertEqual(
command_args,
@ -340,20 +407,18 @@ class NumaBindingTest(TestCase):
)
with (
patch("torch.distributed.numa.binding.signpost_event") as signpost_patch,
patch("torch.numa.binding.signpost_event") as signpost_patch,
patch(
"subprocess.run",
"torch.numa.binding.run",
side_effect=subprocess.CalledProcessError(1, "numactl"),
),
):
command_args = (
self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(
affinity_mode=AffinityMode.NODE,
should_fall_back_if_binding_fails=True,
),
local_rank=0,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(
affinity_mode=AffinityMode.NODE,
should_fall_back_if_binding_fails=True,
),
target_local_rank=0,
)
self.assertIn(
"subprocess.CalledProcessError",
@ -387,6 +452,25 @@ class NumaBindingTest(TestCase):
NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
)
def test_fork_start_method_does_not_call_get_default_numa_options(self) -> None:
# Inner import to avoid crashing if not torch.distributed.is_available()
from torch.distributed.launcher.api import LaunchConfig
with patch(
"torch.distributed.launcher.api.get_default_numa_options"
) as mock_get_default_numa_options:
launch_config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=1,
start_method="fork",
# Don't provide numa_options
)
# Verify get_default_numa_options was not called
mock_get_default_numa_options.assert_not_called()
# Verify numa_options is None when start_method is fork
self.assertIsNone(launch_config.numa_options)
def test_socket_numa_binding_with_multiple_numa_per_socket(self) -> None:
self._add_mock_hardware(
num_sockets=4,
@ -396,15 +480,15 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
command_args = self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), local_rank=15
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=15,
)
self.assertEqual(
command_args,
(
"numactl",
"--cpunodebind=6-7",
"--preferred-many=6-7",
"echo",
"Hello, world!",
),
@ -419,15 +503,15 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
command_args = self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET), local_rank=7
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=7,
)
self.assertEqual(
command_args,
(
"numactl",
"--cpunodebind=3",
"--preferred=3",
"echo",
"Hello, world!",
),
@ -442,8 +526,9 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
command_args_0 = self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), local_rank=0
command_args_0 = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=0,
)
self.assertEqual(
command_args_0,
@ -451,14 +536,14 @@ class NumaBindingTest(TestCase):
"numactl",
# Gets an extra physical core due to odd number of physical cores on numa node
"--physcpubind=0-3",
"--preferred=0",
"echo",
"Hello, world!",
),
)
command_args_1 = self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE), local_rank=1
command_args_1 = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=1,
)
self.assertEqual(
command_args_1,
@ -466,7 +551,6 @@ class NumaBindingTest(TestCase):
"numactl",
# Does not get an extra physical core, since the 1st GPU already took the extra.
"--physcpubind=4-5",
"--preferred=0",
"echo",
"Hello, world!",
),
@ -485,9 +569,9 @@ class NumaBindingTest(TestCase):
RuntimeError,
"There are only 1 physical cores on numa_node_index=0, but there are 2 GPUs associated with this NUMA node.",
):
self._start_test_processes_and_get_command_args_for_local_rank(
self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
local_rank=1,
target_local_rank=1,
)
def test_core_complex_numa_binding_with_extra_l3(self) -> None:
@ -499,9 +583,9 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
command_args = self._start_test_processes_and_get_command_args_for_local_rank(
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
local_rank=3,
target_local_rank=3,
)
self.assertEqual(
command_args,
@ -509,7 +593,6 @@ class NumaBindingTest(TestCase):
"numactl",
# The second L3 on the second numa node
"--physcpubind=24-29",
"--preferred=1",
"echo",
"Hello, world!",
),
@ -524,9 +607,9 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
command_args = self._start_test_processes_and_get_command_args_for_local_rank(
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
local_rank=3,
target_local_rank=3,
)
self.assertEqual(
command_args,
@ -535,7 +618,6 @@ class NumaBindingTest(TestCase):
# There are only 2 L3 caches, so the 4th GPU shares the same
# cores as the 3rd GPU.
"--physcpubind=6-11",
"--preferred=1",
"echo",
"Hello, world!",
),
@ -552,11 +634,9 @@ class NumaBindingTest(TestCase):
# Only some subset of the CPUs are available this time.
with patch("os.sched_getaffinity", return_value={0, 4, 6, 7, 9}):
command_args = (
self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
local_rank=0,
)
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=0,
)
self.assertEqual(
@ -565,7 +645,6 @@ class NumaBindingTest(TestCase):
"numactl",
# Binds to the second L3 because it has the most available CPUs
"--physcpubind=6-7,9",
"--preferred=0",
"echo",
"Hello, world!",
),
@ -584,42 +663,20 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=1,
)
command_args = self._start_test_processes_and_get_command_args_for_local_rank(
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
local_rank=0,
target_local_rank=0,
)
self.assertEqual(
command_args,
(
"numactl",
"--physcpubind=0-1",
"--preferred=0",
"echo",
"Hello, world!",
),
)
def test_raises_error_if_numa_options_provided_for_callable_entrypoint(
self,
) -> None:
# Inner import to avoid crashing if not torch.distributed.is_available()
from torch.distributed.elastic.agent.server.api import WorkerSpec
def mock_entrypoint() -> None:
pass
with self.assertRaisesRegex(ValueError, r".*numa_options.*"):
# not relevant to test, just pass in an arbitrary value
mock_rdzv_handler: Any = 0
WorkerSpec(
role="trainer",
# Only str entrypoint (e.g. "echo") is currently supported
entrypoint=mock_entrypoint,
local_world_size=8,
rdzv_handler=mock_rdzv_handler,
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
)
def test_raises_error_if_numactl_unavailable(self) -> None:
self._add_mock_hardware(
num_sockets=1,
@ -632,8 +689,9 @@ class NumaBindingTest(TestCase):
patch("shutil.which", return_value=None),
self.assertRaisesRegex(RuntimeError, r".*numactl.*"),
):
self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=0
self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=0,
)
def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None:
@ -654,20 +712,50 @@ class NumaBindingTest(TestCase):
contents="-1",
)
command_args = self._start_test_processes_and_get_command_args_for_local_rank(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE), local_rank=0
command_args = self._start_processes_for_str_entrypoint_and_get_Popen_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=0,
)
self.assertEqual(
command_args,
(
"numactl",
"--cpunodebind=0",
"--preferred=0",
"echo",
"Hello, world!",
),
)
def test_callable_entrypoint_basic(self) -> None:
self._add_mock_hardware(
num_sockets=4,
num_numa_nodes_per_socket=2,
num_gpus_per_numa_node=2,
num_l3_caches_per_numa_node=4,
num_physical_core_per_l3_cache=2,
)
executable_contents = (
self._start_processes_for_callable_entrypoint_and_get_executable_contents(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=11,
)
)
self.assertEqual(
executable_contents,
# There are 8 numa nodes and 2 GPUs per numa node, so GPU 11 would be
# on numa node 11 // 2 = 5.
f"""#!/bin/bash
# If this file is more than a few minutes old and still exists on your machine,
# that is NOT expected. It should have deleted itself. If you are seeing an accumulation of such
# files, that could suggest a bug in pytorch. See https://github.com/pytorch/pytorch/pull/160163.
rm -- "$0"
numactl --cpunodebind=5 {sys.executable} "$@"
""",
)
def test_get_set_of_int_from_ranges_str(self) -> None:
self.assertEqual(
_get_set_of_int_from_ranges_str("0-2,4,6-7"), {0, 1, 2, 4, 6, 7}

View File

@ -27,7 +27,7 @@ from torch.distributed.elastic.metrics import prof, put_metric
from torch.distributed.elastic.multiprocessing import ProcessFailure, SignalException
from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.numa.binding import NumaOptions
from torch.numa.binding import NumaOptions
__all__ = [
@ -104,13 +104,6 @@ class WorkerSpec:
self.entrypoint = self.fn
assert self.entrypoint
if (
self.numa_options is not None
and not self.numa_options.should_fall_back_if_binding_fails
and not isinstance(self.entrypoint, str)
):
raise ValueError("numa_options is only supported for str entrypoints.")
def get_entrypoint_name(self):
"""Get the entry point name.

View File

@ -80,7 +80,7 @@ from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401
to_map,
)
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.numa.binding import NumaOptions
from torch.numa.binding import NumaOptions
__all__ = [
@ -227,6 +227,7 @@ def start_processes(
log_line_prefixes=log_line_prefixes,
start_method=start_method,
logs_specs=logs_specs,
numa_options=numa_options,
)
try:

View File

@ -37,7 +37,7 @@ from torch.distributed.elastic.multiprocessing.subprocess_handler import (
SubprocessHandler,
)
from torch.distributed.elastic.multiprocessing.tail_log import TailLog
from torch.distributed.numa.binding import maybe_wrap_with_numa_bindings, NumaOptions
from torch.numa.binding import NumaOptions
IS_WINDOWS = sys.platform == "win32"
@ -631,6 +631,7 @@ class MultiprocessContext(PContext):
start_method: str,
logs_specs: LogsSpecs,
log_line_prefixes: Optional[dict[int, str]] = None,
numa_options: Optional[NumaOptions] = None,
):
super().__init__(
name,
@ -655,6 +656,8 @@ class MultiprocessContext(PContext):
# successfully. If any process died on event.wait() calling set() method will deadlock.
self._worker_finished_event = mp.get_context(self.start_method).Event()
self._numa_options: Optional[NumaOptions] = numa_options
def _start(self):
if self._pc:
raise ValueError(
@ -676,6 +679,7 @@ class MultiprocessContext(PContext):
join=False,
daemon=False,
start_method=self.start_method,
numa_options=self._numa_options,
)
def _is_done(self) -> bool:
@ -814,10 +818,6 @@ class SubprocessContext(PContext):
log_line_prefixes: Optional[dict[int, str]] = None,
numa_options: Optional[NumaOptions] = None,
):
entrypoint, args = maybe_wrap_with_numa_bindings(
entrypoint=entrypoint, local_rank_to_args=args, numa_options=numa_options
)
super().__init__(
name,
entrypoint,
@ -831,6 +831,7 @@ class SubprocessContext(PContext):
self._running_local_ranks: set[int] = set(range(self.nprocs))
self._failures: dict[int, ProcessFailure] = {}
self.subprocess_handlers: dict[int, SubprocessHandler] = {}
self._numa_options: Optional[NumaOptions] = numa_options
def _start(self):
if self.subprocess_handlers:
@ -845,6 +846,7 @@ class SubprocessContext(PContext):
stdout=self.stdouts[local_rank],
stderr=self.stderrs[local_rank],
local_rank_id=local_rank,
numa_options=self._numa_options,
)
for local_rank in range(self.nprocs)
}

View File

@ -3,10 +3,12 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
from torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
)
from torch.numa.binding import NumaOptions
__all__ = ["get_subprocess_handler"]
@ -19,6 +21,7 @@ def get_subprocess_handler(
stdout: str,
stderr: str,
local_rank_id: int,
numa_options: Optional[NumaOptions] = None,
) -> SubprocessHandler:
return SubprocessHandler(
entrypoint=entrypoint,
@ -27,4 +30,5 @@ def get_subprocess_handler(
stdout=stdout,
stderr=stderr,
local_rank_id=local_rank_id,
numa_options=numa_options,
)

View File

@ -11,6 +11,8 @@ import sys
from subprocess import Popen
from typing import Any, Optional
from torch.numa.binding import maybe_wrap_command_with_numa_bindings, NumaOptions
__all__ = ["SubprocessHandler"]
@ -39,6 +41,7 @@ class SubprocessHandler:
stdout: Optional[str],
stderr: Optional[str],
local_rank_id: int,
numa_options: Optional[NumaOptions],
):
self._stdout = open(stdout, "w") if stdout else None
self._stderr = open(stderr, "w") if stderr else None
@ -47,6 +50,15 @@ class SubprocessHandler:
env_vars.update(env)
args_str = (entrypoint, *[str(e) for e in args])
args_str = (
maybe_wrap_command_with_numa_bindings(
command_args=args_str,
gpu_index=local_rank_id,
numa_options=numa_options,
)
or args_str
)
self.local_rank_id = local_rank_id
self.proc: Popen = self._popen(args_str, env_vars)

View File

@ -26,7 +26,7 @@ from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.numa.binding import NumaOptions
from torch.numa.binding import NumaOptions
__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]
@ -107,7 +107,13 @@ class LaunchConfig:
if self.logs_specs is None:
self.logs_specs = DefaultLogsSpecs()
if self.numa_options is None and torch.cuda.is_available():
if (
self.numa_options is None
# NOTE: This filter isn't relevant for str entrypoints,
# but it's the default anyway.
and self.start_method == "spawn"
and torch.cuda.is_available()
):
self.numa_options = get_default_numa_options()
logger.info("Using default numa options = %r", self.numa_options)

View File

@ -382,7 +382,7 @@ from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torch.distributed.numa.binding import (
from torch.numa.binding import (
AffinityMode as _AffinityMode, # Signify as private with _
NumaOptions as _NumaOptions,
)

View File

@ -2,6 +2,7 @@
import logging
import multiprocessing
import multiprocessing.connection
import multiprocessing.spawn as mp_spawn
import os
import pickle
import signal
@ -12,6 +13,11 @@ import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional
from torch.numa.binding import (
maybe_get_temporary_python_executable_with_numa_bindings,
NumaOptions,
)
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
@ -236,6 +242,7 @@ def start_processes(
join=True,
daemon=False,
start_method="spawn",
numa_options: Optional[NumaOptions] = None,
):
# To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
# this func will start processes in parallel if start_method is 'forkserver'.
@ -251,11 +258,43 @@ def start_processes(
# Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start
start_parallel = False
if numa_options is not None and start_method != "spawn":
raise ValueError("NUMA binding is only compatible with spawn")
if numa_options is not None and start_parallel:
raise ValueError("NUMA binding is not compatible with parallel start")
mp = multiprocessing.get_context(start_method)
error_files = [None] * nprocs
processes = [None] * nprocs
original_executable = mp_spawn.get_executable()
def start_process(i):
# HACK: We want to force Process.start() to kick off the subprocess
# using a custom numactl command per rank. However, the API exposed
# by multiprocessing only allows us to override the executable for
# the entire context, and only with a single str rather than a tuple.
# Furthermore, there is no API for passing additional options, e.g.
# to make LOCAL_RANK available to the executable.
#
# In order to get around these limitations, we pre-compute
# the appropriate command containing NUMA bindings and store it in a
# temporary executable which passes Python args on to the original
# executable. Then, we call set_executable before and after each
# Process.start() call.
#
# This assumes that, under the hood, Process.start() for rank n
# will not call get_executable after start_process for rank n+1
# calls set_executable again. We guarantee this by
# raising an exception if `start_parallel`, above. (Not clear
# if there would be a race condition otherwise, but we want to be safe.)
temporary_executable_path = (
maybe_get_temporary_python_executable_with_numa_bindings(
python_executable_path=original_executable,
gpu_index=i,
numa_options=numa_options,
)
)
# Each process is assigned a file to write tracebacks to. We
# use the file being non-empty to indicate an exception
# occurred (vs an expected shutdown). Note: this previously
@ -267,12 +306,19 @@ def start_processes(
)
tf.close()
os.unlink(tf.name)
process = mp.Process(
target=_wrap,
args=(fn, i, args, tf.name),
daemon=daemon,
)
process.start()
try:
if temporary_executable_path is not None:
mp.set_executable(temporary_executable_path)
process = mp.Process(
target=_wrap,
args=(fn, i, args, tf.name),
daemon=daemon,
)
process.start()
finally:
if temporary_executable_path is not None:
mp.set_executable(original_executable)
return i, process, tf.name
if not start_parallel:

View File

@ -1,28 +1,31 @@
import os
import shutil
import stat
import subprocess
import traceback
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
from logging import getLogger
from subprocess import run
from tempfile import mkstemp
from typing import Callable, Optional, TypeVar
import torch
from torch._utils_internal import signpost_event
from torch.distributed.elastic.utils.logging import get_logger
__all__ = [
"maybe_wrap_with_numa_bindings",
"AffinityMode",
"maybe_get_temporary_python_executable_with_numa_bindings",
"maybe_wrap_command_with_numa_bindings",
"NumaOptions",
]
_NUMACTL_COMMAND = "numactl"
logger = get_logger(__file__)
logger = getLogger(__name__)
class AffinityMode(str, Enum):
@ -40,10 +43,10 @@ class AffinityMode(str, Enum):
@dataclass(frozen=True)
class NumaOptions:
affinity_mode: AffinityMode
"""
If true, we will silently return the original command if any of the following occur:
- An exception is raised as we compute the wrapped command.
- During a dry run of the wrapped command, numactl fails for any reason.
If true, we will fall back to using the original command/entrypoint if we fail to compute
or apply NUMA bindings.
You should avoid using this option! It is only intended as a safety mechanism for facilitating
mass rollouts of numa binding.
@ -51,52 +54,156 @@ class NumaOptions:
should_fall_back_if_binding_fails: bool = False
def maybe_wrap_with_numa_bindings(
*,
entrypoint: str,
local_rank_to_args: dict[int, tuple],
numa_options: Optional[NumaOptions],
) -> tuple[str, dict[int, tuple]]:
def maybe_get_temporary_python_executable_with_numa_bindings(
*, python_executable_path: str, gpu_index: int, numa_options: Optional[NumaOptions]
) -> Optional[str]:
"""
Args:
entrypoint: The entrypoint to the program, such as might be input to Popen.
Example: "python"
local_rank_to_args: A mapping from local rank to args for the entrypoint.
Example: {0: ("trainer.py",)}
numa_options: See NumaOptions for details.
python_executable_path: E.g., "/usr/local/bin/python"
Returns:
A tuple of (entrypoint, local_rank_to_args), basically transforming the inputs,
where the entrypoint and args may now involve numa binding.
Example: ("numactl", {"0": ("--cpunodebind=0", "--preferred=0", "python", "trainer.py")})
Path to a temporary file. This file can be executed just like the original python
executable, except it will first apply NUMA bindings.
"""
if numa_options is None:
return (entrypoint, local_rank_to_args)
logger.info("Received numa_options=None, not creating numa executable.")
return None
wrapped_local_rank_to_args = {}
for local_rank, args in local_rank_to_args.items():
try:
numactl_command_options = _maybe_get_numactl_options(
command_args=(entrypoint, *[str(arg) for arg in args]),
gpu_index=local_rank,
numa_options=numa_options,
)
except Exception:
if numa_options.should_fall_back_if_binding_fails:
# NOTE: If any element of the batch fails to apply NUMA bindings
# for any reason, we do not apply NUMA bindings to any element of the batch,
# for maximum safety. This only applies if fallback is enabled.
return (entrypoint, local_rank_to_args)
raise
wrapped_local_rank_to_args[local_rank] = (
*numactl_command_options,
entrypoint,
*args,
if isinstance(python_executable_path, bytes):
python_executable_path = python_executable_path.decode()
full_numactl_command = maybe_wrap_command_with_numa_bindings(
# "$@", i.e. pass through any args the python executable would have
# received.
command_args=(python_executable_path, '"$@"'),
gpu_index=gpu_index,
numa_options=numa_options,
)
if full_numactl_command is None:
return None
executable_path = _get_temporary_executable_for_command(
command_args=full_numactl_command
)
logger.info("Returning python executable with NUMA bindings %s", executable_path)
return executable_path
def maybe_wrap_command_with_numa_bindings(
*,
command_args: tuple[str, ...],
gpu_index: int,
numa_options: Optional[NumaOptions],
) -> Optional[tuple[str, ...]]:
"""
Args:
command_args: Full shell command, like ("/usr/local/bin/python", "train.py")
gpu_index: The index of the GPU which command_args should bind to
Returns:
command_args, but wrapped so that it runs with NUMA bindings corresponding to
gpu_index and numa_options.
E.g., ("numactl", "--cpunodebind=0", "/usr/local/bin/python", "train.py")
"""
if not numa_options:
logger.info("Received numa_options=None, not applying bindings.")
return None
kwargs = {
"command_args": command_args,
"gpu_index": gpu_index,
"numa_options": numa_options,
}
logger.info("Attempting to wrap command with NUMA bindings, given input %r", kwargs)
try:
_raise_if_numactl_not_available()
numactl_options = _get_numactl_cli_options(
command_args=command_args, gpu_index=gpu_index, numa_options=numa_options
)
return (_NUMACTL_COMMAND, wrapped_local_rank_to_args)
logger.info("Computed numactl_options=%r", numactl_options)
_raise_if_numactl_fails_dry_run(numactl_options=numactl_options)
logger.info("Validated numactl_options=%r", numactl_options)
full_numactl_command = _get_assembled_command_from_pieces(
command_args=command_args, numactl_options=numactl_options
)
logger.info(
"Successfully wrapped command with numa_bindings. Returning %r",
full_numactl_command,
)
signpost_event(
category="numa_binding",
name="wrap_command_success",
parameters={**kwargs, "result": full_numactl_command},
)
return full_numactl_command
except Exception:
signpost_event(
category="numa_binding",
name="wrap_command_exception",
parameters={
**kwargs,
"traceback": traceback.format_exc(),
},
)
logger.exception(
"Failed to wrap command with NUMA bindings for input = %r", kwargs
)
if numa_options.should_fall_back_if_binding_fails:
logger.warning("Falling back to original command without NUMA bindings.")
return None
raise
def _maybe_get_numactl_options(
def _get_temporary_executable_for_command(
*,
command_args: tuple[str, ...],
) -> str:
"""
Returns:
Path to a temporary file which executes the specified command. The executable
deletes itself the first time it runs, so do not try to run it multiple times.
"""
fd, path = mkstemp(
prefix="pytorch-numa-bind",
suffix=".sh",
)
# We do rm first to guarantee the file deletes itself. The rest of the file
# will still run as intended.
contents = f"""#!/bin/bash
# If this file is more than a few minutes old and still exists on your machine,
# that is NOT expected. It should have deleted itself. If you are seeing an accumulation of such
# files, that could suggest a bug in pytorch. See https://github.com/pytorch/pytorch/pull/160163.
rm -- "$0"
{" ".join(command_args)}
"""
with os.fdopen(fd, "w") as file:
file.write(contents)
# Ensure the file is fully synced, in order to avoid race condition
# from trying to execute it too early.
file.flush()
os.fsync(fd)
# Make the script executable
os.chmod(path, stat.S_IRWXU)
logger.info(
"Created temporary executable at path %s, with contents\n%s", path, contents
)
return path
def _get_numactl_cli_options(
*,
command_args: tuple[str, ...],
gpu_index: int,
@ -112,63 +219,20 @@ def _maybe_get_numactl_options(
Returns:
Depending on numa_options, something like
("--cpunodebind=0", "--preferred=0")
("--cpunodebind=0")
"""
try:
_raise_if_numactl_not_available()
if numa_options.affinity_mode == AffinityMode.NODE:
numactl_command_options = _get_node_numactl_options(gpu_index=gpu_index)
elif numa_options.affinity_mode == AffinityMode.SOCKET:
numactl_command_options = _get_socket_numactl_options(gpu_index=gpu_index)
elif numa_options.affinity_mode == AffinityMode.EXCLUSIVE:
numactl_command_options = _get_exclusive_numactl_options(
gpu_index=gpu_index
)
elif numa_options.affinity_mode == AffinityMode.CORE_COMPLEX:
numactl_command_options = _get_core_complex_numactl_options(
gpu_index=gpu_index
)
else:
raise ValueError(
f"Affinity mode {numa_options.affinity_mode} not supported."
)
if numa_options.affinity_mode == AffinityMode.NODE:
numactl_command_options = _get_node_numactl_options(gpu_index=gpu_index)
elif numa_options.affinity_mode == AffinityMode.SOCKET:
numactl_command_options = _get_socket_numactl_options(gpu_index=gpu_index)
elif numa_options.affinity_mode == AffinityMode.EXCLUSIVE:
numactl_command_options = _get_exclusive_numactl_options(gpu_index=gpu_index)
elif numa_options.affinity_mode == AffinityMode.CORE_COMPLEX:
numactl_command_options = _get_core_complex_numactl_options(gpu_index=gpu_index)
else:
raise ValueError(f"Affinity mode {numa_options.affinity_mode} not supported.")
if numa_options.should_fall_back_if_binding_fails:
_raise_if_numactl_fails_dry_run(numactl_options=numactl_command_options)
signpost_event(
category="numa_binding",
name="wrap_command_success",
parameters={
"original_command_args": command_args,
"gpu_index": gpu_index,
"numa_options": numa_options,
"numactl_command_options": numactl_command_options,
},
)
return numactl_command_options
except Exception:
signpost_event(
category="numa_binding",
name="wrap_command_exception",
parameters={
"traceback": traceback.format_exc(),
"original_command_args": command_args,
"gpu_index": gpu_index,
"numa_options": numa_options,
},
)
logger.exception(
"""Failed to wrap command with NUMA bindings.
Input:
command_args=%r,
gpu_index=%d,
numa_options=%r,
""",
command_args,
gpu_index,
numa_options,
)
raise
return numactl_command_options
def _raise_if_numactl_fails_dry_run(*, numactl_options: tuple[str, ...]) -> None:
@ -177,9 +241,14 @@ def _raise_if_numactl_fails_dry_run(*, numactl_options: tuple[str, ...]) -> None
command_args=("true",),
numactl_options=numactl_options,
)
temporary_executable_path = _get_temporary_executable_for_command(
command_args=noop_args
)
try:
subprocess.run(
noop_args,
run(
(temporary_executable_path,),
stdout=subprocess.DEVNULL,
# These allow us to capture the stderr as text
stderr=subprocess.PIPE,
@ -219,14 +288,11 @@ def _get_node_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
Core logic of 'node' numa strategy.
Returns options to be used with numactl. E.g.,
("--cpunodebind=0", "--preferred=0").
("--cpunodebind=0").
"""
numa_node_index = _get_numa_node_index_for_gpu_index(gpu_index=gpu_index)
return (
f"--cpunodebind={numa_node_index}",
f"--preferred={numa_node_index}",
)
return (f"--cpunodebind={numa_node_index}",)
def _get_socket_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
@ -242,14 +308,7 @@ def _get_socket_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
)
numa_node_indices_str = _get_ranges_str_from_ints(numa_node_indices)
return (
f"--cpunodebind={numa_node_indices_str}",
(
f"--preferred-many={numa_node_indices_str}"
if len(numa_node_indices) > 1
else f"--preferred={numa_node_indices_str}"
),
)
return (f"--cpunodebind={numa_node_indices_str}",)
def _get_exclusive_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
@ -321,7 +380,6 @@ def _get_exclusive_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
return (
f"--physcpubind={_get_ranges_str_from_ints(logical_cpu_indices_for_original_gpu)}",
f"--preferred={numa_node_index}",
)
@ -371,7 +429,6 @@ def _get_core_complex_numactl_options(*, gpu_index: int) -> tuple[str, ...]:
return (
f"--physcpubind={_get_ranges_str_from_ints(logical_cpu_indices_for_original_gpu)}",
f"--preferred={numa_node_index}",
)