Compare commits

...

1 Commits

Author SHA1 Message Date
9545b961df Revert "[Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)"
This reverts commit e9d89734274a4a2640fa77b898c800a87d1d874e.

Reverted https://github.com/pytorch/pytorch/pull/163316 on behalf of https://github.com/clee2000 due to seems to have broken some no_gpu tests? test/inductor/test_cpu_repro.py::CPUReproTests::test_double_reduction_vec [GH job link](https://github.com/pytorch/pytorch/actions/runs/18689033019/job/53290772740) [HUD commit link](e9d8973427) ([comment](https://github.com/pytorch/pytorch/pull/163316#issuecomment-3428210509))
2025-10-21 11:23:58 -07:00
7 changed files with 310 additions and 323 deletions

View File

@ -4810,22 +4810,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
).run(code)
def test_double_reduction_vec(self):
def fn(x):
return x.sum(dim=1)
@ -4835,22 +4819,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::VectorizedN<double,2>::loadu", 2, exactly=True
).run(code)
def test_convert_fp32_to_double_vec(self):
def fn(x):
return x.to(torch.double)
@ -4860,22 +4828,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn(22, 22)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::convert<double,2,float,1>", 2, exactly=True
).run(code)
def test_convert_double_to_fp32_vec(self):
def fn(x):
return x.to(torch.float32)
@ -4885,22 +4837,6 @@ class CPUReproTests(TestCase):
self.common(fn, (x,))
check_metrics_vec_kernel_count(1)
# Tail vectorization case
x = torch.randn((22, 22), dtype=torch.double)
torch._dynamo.reset()
metrics.reset()
with torch.no_grad():
expected = fn(x)
compiled_fn = torch.compile(fn)
actual, code = run_and_get_cpp_code(compiled_fn, x)
self.assertEqual(expected, actual)
# 1 generated vec kernel
self.assertEqual(metrics.generated_cpp_vec_kernel_count, 1)
# Check that both main and tail loops are vectorized
FileCheck().check_count(
"at::vec::convert<float,1,double,2>", 2, exactly=True
).run(code)
def test_no_redundant_to_dtypes_between_fused_scheduler_node(self):
# https://github.com/pytorch/pytorch/issues/115260
p0 = torch.tensor([1.0879], dtype=torch.float16)

View File

@ -5,7 +5,6 @@ from __future__ import annotations
import json
import sys
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
from typing import Any, Optional
from unittest import skipUnless
from unittest.mock import mock_open, patch
@ -13,10 +12,9 @@ from unittest.mock import mock_open, patch
import torch
from torch._utils_internal import signpost_event
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, start_processes
from torch.distributed.elastic.multiprocessing.subprocess_handler import (
SubprocessHandler,
)
from torch.distributed.elastic.multiprocessing.api import _wrap
from torch.numa.binding import (
_bind_all_threads_in_current_process_to_logical_cpus,
_get_ranges_str_from_ints,
_get_set_of_int_from_ranges_str,
AffinityMode,
@ -66,6 +64,7 @@ class NumaBindingTest(TestCase):
patch("os.listdir", new=self._mock_listdir),
patch("os.sched_getaffinity", new=self._mock_sched_getaffinity),
patch("torch.numa.binding.signpost_event", self._mock_signpost_event),
patch("torch.numa.binding.shutil.which", return_value="/usr/bin/numactl"),
]
for context_manager in self._context_managers_to_apply_to_all_tests:
@ -230,41 +229,12 @@ class NumaBindingTest(TestCase):
def _mock_sched_getaffinity(self, pid: int) -> set[int]:
    """Stand-in for ``os.sched_getaffinity`` used by the test fixtures.

    Reports CPU indices ``0 .. self._mock_num_logical_cpus - 1`` as allowed.
    ``pid`` is accepted only to match the real signature and is ignored.
    """
    num_logical_cpus = self._mock_num_logical_cpus
    return {*range(num_logical_cpus)}
def _start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
def _start_processes_for_str_entrypoint_and_get_command_args(
self, *, numa_options: Optional[NumaOptions], target_local_rank: int
) -> Optional[set[int]]:
active_local_rank = None
target_sched_setaffinity_logical_cpu_indices = None
real_subprocess_handler_init = SubprocessHandler.__init__
def mock_SubprocessHandler__init__(*args, **kwargs) -> None:
nonlocal active_local_rank
active_local_rank = kwargs["local_rank_id"]
return real_subprocess_handler_init(*args, **kwargs)
def mock_sched_setaffinity(*args, **kwargs) -> None:
nonlocal target_sched_setaffinity_logical_cpu_indices
if (
active_local_rank == target_local_rank
# We only care about the first call, not the second
# one where it gets reset
and target_sched_setaffinity_logical_cpu_indices is None
):
target_sched_setaffinity_logical_cpu_indices = args[1]
with (
patch(
"os.sched_setaffinity", mock_sched_setaffinity
) as mock_sched_setaffinity,
patch(
"torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler.Popen"
),
patch(
"torch.distributed.elastic.multiprocessing.subprocess_handler.SubprocessHandler.__init__",
mock_SubprocessHandler__init__,
),
):
) -> tuple[str, ...]:
with patch(
"torch.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler.Popen"
) as mock_popen:
start_processes(
name="test_process",
entrypoint="echo",
@ -277,55 +247,44 @@ class NumaBindingTest(TestCase):
logs_specs=DefaultLogsSpecs(),
numa_options=numa_options,
)
return target_sched_setaffinity_logical_cpu_indices
call_args = next(
call_args
for call_args in mock_popen.call_args_list
if call_args.kwargs.get("env", {}).get("LOCAL_RANK")
== str(target_local_rank)
)
return call_args.kwargs["args"]
def _start_processes_for_callable_entrypoint_and_get_sched_setaffinity_cpus(
self, *, numa_options: Optional[NumaOptions], target_local_rank: int
) -> Optional[set[int]]:
active_local_rank = None
target_sched_setaffinity_logical_cpu_indices = None
real_process__init__ = SpawnProcess.__init__
def _mock_process__init__(*args, **kwargs) -> None:
nonlocal active_local_rank
active_local_rank = kwargs["args"][1]
return real_process__init__(*args, **kwargs)
def mock_sched_setaffinity(*args, **kwargs) -> None:
nonlocal target_sched_setaffinity_logical_cpu_indices
if (
active_local_rank == target_local_rank
# We only care about the first call, not the second
# one where it gets reset
and target_sched_setaffinity_logical_cpu_indices is None
):
target_sched_setaffinity_logical_cpu_indices = args[1]
target_sched_setaffinity_logical_cpu_indices = args[1]
with (
patch(
"os.sched_setaffinity", mock_sched_setaffinity
) as mock_sched_setaffinity,
patch("multiprocessing.context.SpawnProcess.start"),
patch(
"multiprocessing.context.SpawnProcess.__init__", _mock_process__init__
),
patch("multiprocessing.process.BaseProcess.sentinel", 1),
# Prevent hanging
patch(
"multiprocessing.synchronize.Event.wait",
lambda self, timeout=None: None,
),
):
start_processes(
name="test_process",
entrypoint=lambda x: x,
args=dict.fromkeys(range(self._mock_device_count()), (0,)),
envs={
i: {"LOCAL_RANK": str(i)} for i in range(self._mock_device_count())
},
logs_specs=DefaultLogsSpecs(),
def dummy_fn():
pass
import torch.multiprocessing as mp
ctx = mp.get_context()
mock_queue = ctx.SimpleQueue()
mock_event = ctx.Event()
with patch("os.sched_setaffinity", mock_sched_setaffinity):
mock_event.set() # Prevent hanging
# This is the entrypoint for subprocesses with Callable entrypoints
_wrap(
local_rank=target_local_rank,
fn=dummy_fn,
args={target_local_rank: ()},
envs={target_local_rank: {}},
stdout_redirects={target_local_rank: ""},
stderr_redirects={target_local_rank: ""},
ret_vals={target_local_rank: mock_queue},
queue_finished_reading_event=mock_event,
numa_options=numa_options,
)
@ -340,19 +299,17 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=11,
)
command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=11,
)
self.assertEqual(
bound_logical_cpu_indices,
command_args,
# There are 8 numa nodes and 2 GPUs per numa node, so GPU 11 would be
# on numa node 11 // 2 = 5.
# Each numa node has 4 * 2 * 2 = 16 logical CPUs
# Numa node 5 has CPUs 80-95
set(range(80, 96)),
("numactl", "--physcpubind=80-95", "echo", "Hello, world!"),
)
def test_no_numa_binding_if_numa_options_not_provided(self) -> None:
@ -364,14 +321,12 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=None, target_local_rank=11
)
command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=None, target_local_rank=11
)
self.assertEqual(
bound_logical_cpu_indices,
None,
command_args,
("echo", "Hello, world!"),
)
def test_default_numa_binding(self) -> None:
@ -421,8 +376,8 @@ class NumaBindingTest(TestCase):
side_effect=Exception("Mock exception!"),
),
):
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
command_args = (
self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(
affinity_mode=AffinityMode.NODE,
should_fall_back_if_binding_fails=True,
@ -435,9 +390,39 @@ class NumaBindingTest(TestCase):
signpost_patch.call_args.kwargs["parameters"]["traceback"],
)
self.assertEqual(
bound_logical_cpu_indices,
# We should just reset to the original CPU affinity, which is all the CPUs
set(range(4)),
command_args,
("echo", "Hello, world!"),
)
def test_fallback_if_numactl_not_available(self) -> None:
    """With fallback enabled, a missing ``numactl`` leaves the command unwrapped.

    ``shutil.which`` is patched to return ``None`` so the binding code sees no
    ``numactl`` executable; the failure must be reported through
    ``signpost_event`` and the original command args returned unchanged.
    """
    self._add_mock_hardware(
        num_sockets=2,
        num_numa_nodes_per_socket=1,
        num_gpus_per_numa_node=1,
        num_l3_caches_per_numa_node=1,
        num_physical_core_per_l3_cache=1,
    )
    fallback_options = NumaOptions(
        affinity_mode=AffinityMode.NODE,
        should_fall_back_if_binding_fails=True,
    )

    with (
        patch("torch.numa.binding.signpost_event") as signpost_patch,
        patch("torch.numa.binding.shutil.which", return_value=None),
    ):
        command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
            numa_options=fallback_options,
            target_local_rank=0,
        )

    # The mock keeps its recorded calls after the patch is undone.
    reported_traceback = signpost_patch.call_args.kwargs["parameters"]["traceback"]
    self.assertIn("numactl CLI is required for NUMA binding", reported_traceback)
    self.assertEqual(command_args, ("echo", "Hello, world!"))
def test_explicit_numa_options_overrides_default(self) -> None:
@ -493,18 +478,16 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=15,
)
command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=15,
)
self.assertEqual(
bound_logical_cpu_indices,
command_args,
# GPU 15 is on numa node 15 // 2 = 7, which is on socket 3 (numa nodes 6 and 7)
# Each numa node has 4 * 2 * 2 = 16 logical CPUs
# Numa nodes 6 and 7 have CPUs 96-111 and 112-127
set(range(96, 128)),
("numactl", "--physcpubind=96-127", "echo", "Hello, world!"),
)
def test_socket_numa_binding_with_single_numa_per_socket(self) -> None:
@ -516,18 +499,16 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=2,
)
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=7,
)
command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.SOCKET),
target_local_rank=7,
)
self.assertEqual(
bound_logical_cpu_indices,
command_args,
# GPU 7 is on numa node 7 // 2 = 3, which is socket 3 by itself
# Each numa node has 4 * 2 * 2 = 16 logical CPUs
# Numa node 3 has CPUs 48-63
set(range(48, 64)),
("numactl", "--physcpubind=48-63", "echo", "Hello, world!"),
)
def test_exclusive_numa_binding(self) -> None:
@ -539,30 +520,26 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
bound_logical_cpu_indices_0 = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=0,
)
command_args_0 = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=0,
)
self.assertEqual(
bound_logical_cpu_indices_0,
command_args_0,
# Gets an extra physical core due to odd number of physical cores on numa node
# 3 physical cores total, 2 GPUs: GPU 0 gets 2 physical cores (CPUs 0-3)
set(range(4)),
("numactl", "--physcpubind=0-3", "echo", "Hello, world!"),
)
bound_logical_cpu_indices_1 = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=1,
)
command_args_1 = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=1,
)
self.assertEqual(
bound_logical_cpu_indices_1,
command_args_1,
# Does not get an extra physical core, since the 1st GPU already took the extra.
# GPU 1 gets 1 physical core (CPUs 4-5)
set(range(4, 6)),
("numactl", "--physcpubind=4-5", "echo", "Hello, world!"),
)
def test_exclusive_raises_if_too_few_physical_cores(self) -> None:
@ -578,7 +555,7 @@ class NumaBindingTest(TestCase):
RuntimeError,
"There are only 1 physical cores on numa_node_index=0, but there are 2 GPUs associated with this NUMA node.",
):
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.EXCLUSIVE),
target_local_rank=1,
)
@ -592,18 +569,16 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=3,
)
command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=3,
)
self.assertEqual(
bound_logical_cpu_indices,
command_args,
# GPU 3 is on numa node 3 // 2 = 1, relative GPU index is 3 % 2 = 1
# The second L3 on the second numa node (numa node 1)
# Second numa node starts at CPU 18, second L3 cache is CPUs 24-29
set(range(24, 30)),
("numactl", "--physcpubind=24-29", "echo", "Hello, world!"),
)
def test_core_complex_numa_binding_with_fewer_l3_than_gpu(self) -> None:
@ -615,18 +590,16 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=3,
)
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=3,
)
command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=3,
)
self.assertEqual(
bound_logical_cpu_indices,
command_args,
# GPU 3 is on numa node 3 // 2 = 1, relative GPU index is 3 % 2 = 1
# With 1 L3 cache per numa node, GPU 3 uses L3 cache index 1 % 1 = 0 (the only cache)
# Second numa node starts at CPU 6, single L3 cache spans CPUs 6-11
set(range(6, 12)),
("numactl", "--physcpubind=6-11", "echo", "Hello, world!"),
)
def test_core_complex_prefers_caches_with_more_cpus(self) -> None:
@ -640,17 +613,17 @@ class NumaBindingTest(TestCase):
# Only some subset of the CPUs are available this time.
with patch("os.sched_getaffinity", return_value={0, 4, 6, 7, 9}):
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
command_args = (
self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=0,
)
)
# Binds to the second L3 because it has the most available CPUs
self.assertEqual(
bound_logical_cpu_indices,
# Binds to the second L3 because it has the most available CPUs
{6, 7, 9},
command_args,
("numactl", "--physcpubind=6-7,9", "echo", "Hello, world!"),
)
def test_core_complex_tiebreak_prefers_lower_cache_key(self) -> None:
@ -666,18 +639,16 @@ class NumaBindingTest(TestCase):
num_physical_core_per_l3_cache=1,
)
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=0,
)
command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.CORE_COMPLEX),
target_local_rank=0,
)
# 1 numa node, 2 L3 caches, 1 physical core per L3 cache = 2 logical CPUs per cache
# L3 cache 0: CPUs 0-1, L3 cache 1: CPUs 2-3
# Both have same number of CPUs, so prefer lower cache key (0)
self.assertEqual(
bound_logical_cpu_indices,
# 1 numa node, 2 L3 caches, 1 physical core per L3 cache = 2 logical CPUs per cache
# L3 cache 0: CPUs 0-1, L3 cache 1: CPUs 2-3
# Both have same number of CPUs, so prefer lower cache key (0)
set(range(2)),
command_args,
("numactl", "--physcpubind=0-1", "echo", "Hello, world!"),
)
def test_binds_to_node_0_if_node_stored_as_minus_one(self) -> None:
@ -698,18 +669,16 @@ class NumaBindingTest(TestCase):
contents="-1",
)
bound_logical_cpu_indices = (
self._start_processes_for_str_entrypoint_and_get_sched_setaffinity_cpus(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=0,
)
command_args = self._start_processes_for_str_entrypoint_and_get_command_args(
numa_options=NumaOptions(affinity_mode=AffinityMode.NODE),
target_local_rank=0,
)
# GPU 0 has numa node stored as -1, which is treated as numa node 0
# Each numa node has 1 * 1 * 2 = 2 logical CPUs
# Numa node 0 has CPUs 0-1
self.assertEqual(
bound_logical_cpu_indices,
# GPU 0 has numa node stored as -1, which is treated as numa node 0
# Each numa node has 1 * 1 * 2 = 2 logical CPUs
# Numa node 0 has CPUs 0-1
set(range(2)),
command_args,
("numactl", "--physcpubind=0-1", "echo", "Hello, world!"),
)
def test_callable_entrypoint_basic(self) -> None:
@ -764,6 +733,31 @@ class NumaBindingTest(TestCase):
def test_get_range_str_from_ints(self) -> None:
    """Unsorted indices are rendered as sorted, comma-separated ranges."""
    unsorted_indices = [7, 0, 1, 6, 2, 4]
    rendered = _get_ranges_str_from_ints(unsorted_indices)
    self.assertEqual(rendered, "0-2,4,6-7")
def test_bind_all_threads_in_current_process_to_logical_cpus(self) -> None:
    """The caller (id 0) and each numeric task entry get the requested affinity.

    Two entries are mocked under ``/proc/self/task``: one numeric id and one
    non-numeric name. Only the numeric one may result in a
    ``os.sched_setaffinity`` call, in addition to the call for id 0.
    """
    self._add_mock_hardware(
        num_sockets=1,
        num_numa_nodes_per_socket=1,
        num_gpus_per_numa_node=1,
        num_l3_caches_per_numa_node=1,
        num_physical_core_per_l3_cache=1,
    )
    self._mock_file_contents(file_path="/proc/self/task/8675309", contents="")
    self._mock_file_contents(
        # The exception from casting not_an_integer to int should get silenced.
        file_path="/proc/self/task/not_an_integer",
        contents="",
    )
    requested_cpus = {2, 0, 9}  # arbitrary

    with patch("os.sched_setaffinity") as mock_sched_setaffinity:
        _bind_all_threads_in_current_process_to_logical_cpus(
            logical_cpu_indices=requested_cpus
        )

    self.assertEqual(mock_sched_setaffinity.call_count, 2)
    mock_sched_setaffinity.assert_any_call(0, requested_cpus)
    mock_sched_setaffinity.assert_any_call(8675309, requested_cpus)
if __name__ == "__main__":
run_tests()

View File

@ -159,7 +159,6 @@ VECTORIZABLE_DTYPES: list[torch.dtype] = [
]
MASKED_VECTORIZABLE_DTYPES: list[torch.dtype] = [
torch.float64,
torch.float,
torch.bfloat16,
torch.float16,

View File

@ -38,7 +38,7 @@ from torch.distributed.elastic.multiprocessing.subprocess_handler import (
SubprocessHandler,
)
from torch.distributed.elastic.multiprocessing.tail_log import TailLog
from torch.numa.binding import NumaOptions
from torch.numa.binding import maybe_with_numa_binding, NumaOptions
IS_WINDOWS = sys.platform == "win32"
@ -627,6 +627,7 @@ def _wrap(
stderr_redirects: dict[int, str], # redirect file for stderr (to console if None)
ret_vals: dict[int, mp.SimpleQueue],
queue_finished_reading_event: synchronize.Event,
numa_options: Optional[NumaOptions],
) -> None:
# get the per-rank params up front so we fail fast if no mapping is found
args_ = args[local_rank]
@ -643,6 +644,9 @@ def _wrap(
os.environ[k] = v
with stdout_cm, stderr_cm:
fn = maybe_with_numa_binding(
fn, gpu_index=local_rank, numa_options=numa_options
)
ret = record(fn)(*args_)
ret_val_.put(ret)
queue_finished_reading_event.wait()
@ -703,12 +707,12 @@ class MultiprocessContext(PContext):
self.stderrs,
self._ret_vals,
self._worker_finished_event,
self._numa_options,
),
nprocs=self.nprocs,
join=False,
daemon=False,
start_method=self.start_method,
numa_options=self._numa_options,
)
def _is_done(self) -> bool:

View File

@ -11,10 +11,7 @@ import sys
from subprocess import Popen
from typing import Any, Optional
from torch.numa.binding import (
maybe_temporarily_apply_numa_binding_to_current_thread,
NumaOptions,
)
from torch.numa.binding import maybe_apply_numa_binding_to_command_args, NumaOptions
__all__ = ["SubprocessHandler"]
@ -53,14 +50,15 @@ class SubprocessHandler:
env_vars.update(env)
args_str = (entrypoint, *[str(e) for e in args])
args_str = maybe_apply_numa_binding_to_command_args(
args_str,
gpu_index=local_rank_id,
numa_options=numa_options,
)
self.local_rank_id = local_rank_id
# See HACK [NUMA inheritance] in spawn.py for context.
with maybe_temporarily_apply_numa_binding_to_current_thread(
gpu_index=local_rank_id, numa_options=numa_options
):
self.proc: Popen = self._popen(args_str, env_vars)
self.proc: Popen = self._popen(args_str, env_vars)
def _popen(self, args: tuple, env: dict[str, str]) -> Popen:
kwargs: dict[str, Any] = {}

View File

@ -12,11 +12,6 @@ import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional
from torch.numa.binding import (
maybe_temporarily_apply_numa_binding_to_current_thread,
NumaOptions,
)
from . import _prctl_pr_set_pdeathsig # type: ignore[attr-defined]
@ -242,7 +237,6 @@ def start_processes(
join=True,
daemon=False,
start_method="spawn",
numa_options: Optional[NumaOptions] = None,
):
# To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
# this func will start processes in parallel if start_method is 'forkserver'.
@ -281,23 +275,7 @@ def start_processes(
daemon=daemon,
)
# HACK [NUMA inheritance]: Subprocesses inherit the parent thread's CPU
# affinity. So, we temporarily apply the bindings to the current thread,
# and then immediately undo them.
# This is necessary because the alternatives would be to
# either
# 1. Use numactl CLI. However, Python's multiprocessing library
# does not provide an API which would allow us to prepend
# the command it runs with numactl options.
# 2. Wrap the provided function such that it first applies
# NUMA bindings, and then executes as expected. However, this
# can result in worse memory locality, because torch and CUDA
# initialization would occur before applying the bindings, thus
# allowing some memory to be allocated on the wrong NUMA nodes.
with maybe_temporarily_apply_numa_binding_to_current_thread(
gpu_index=i, numa_options=numa_options
):
process.start()
process.start()
return i, process, tf.name
if not start_parallel:

View File

@ -1,12 +1,13 @@
import os
import shutil
import traceback
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator
from contextlib import contextmanager
from collections.abc import Callable, Iterable
from dataclasses import asdict, dataclass
from enum import Enum
from functools import wraps
from logging import getLogger
from typing import Optional, TypeVar
from typing import Optional, ParamSpec, TypeVar
import torch
from torch._utils_internal import signpost_event
@ -14,7 +15,8 @@ from torch._utils_internal import signpost_event
__all__ = [
"AffinityMode",
"maybe_temporarily_apply_numa_binding_to_current_thread",
"maybe_apply_numa_binding_to_command_args",
"maybe_with_numa_binding",
"NumaOptions",
]
@ -39,7 +41,7 @@ class NumaOptions:
"""
If true, we will fall back to using the original command/entrypoint if we fail to compute
or apply NUMA bindings.
NUMA bindings.
You should avoid using this option! It is only intended as a safety mechanism for facilitating
mass rollouts of numa binding.
@ -47,57 +49,85 @@ class NumaOptions:
should_fall_back_if_binding_fails: bool = False
@contextmanager
def maybe_temporarily_apply_numa_binding_to_current_thread(
*, gpu_index: int, numa_options: Optional[NumaOptions]
) -> Iterator[None]:
"""
1. Applies NUMA binding to the current thread, suitable for the thread
which will be interacting with GPU gpu_index.
2. Resets to the original CPU affinity before exiting the context manager.
"""
def maybe_apply_numa_binding_to_command_args(
command_args: tuple[str, ...],
*,
gpu_index: int,
numa_options: Optional[NumaOptions],
) -> tuple[str, ...]:
if numa_options is None:
yield
return
return command_args
original_logical_cpu_indices = _get_allowed_cpu_indices_for_current_thread()
_apply_numa_binding_to_current_thread(
gpu_index=gpu_index, numa_options=numa_options
)
yield
_bind_current_thread_to_logical_cpus(
logical_cpu_indices=original_logical_cpu_indices
)
kwargs = {
"command_args": command_args,
"gpu_index": gpu_index,
"numa_options": asdict(numa_options),
}
try:
logical_cpu_indices = _get_validated_logical_cpus_to_bind_to(
gpu_index=gpu_index,
numa_options=numa_options,
)
wrapped_command_args = _assemble_numactl_command_args(
original_command_args=command_args,
logical_cpu_indices=logical_cpu_indices,
)
signpost_event(
category="numa_binding",
name="apply_success",
parameters={
**kwargs,
"wrapped_command": wrapped_command_args,
},
)
return wrapped_command_args
except Exception:
_handle_exception(numa_options=numa_options, logger_kwargs=kwargs)
return command_args
def _apply_numa_binding_to_current_thread(
_TParams = ParamSpec("_TParams")
_TReturn = TypeVar("_TReturn")
def maybe_with_numa_binding(
    func: Callable[_TParams, _TReturn],
    *,
    gpu_index: int,
    numa_options: Optional[NumaOptions],
) -> Callable[_TParams, _TReturn]:
    """Return ``func``, optionally wrapped so NUMA binding is attempted first.

    When ``numa_options`` is ``None`` the original callable is returned as-is.
    Otherwise the returned wrapper calls
    ``_maybe_apply_numa_binding_to_current_thread`` for ``gpu_index`` on every
    invocation, then forwards all arguments to ``func`` and returns its result.
    """
    if numa_options is not None:

        @wraps(func)
        def bound_call(*args: _TParams.args, **kwargs: _TParams.kwargs) -> _TReturn:
            _maybe_apply_numa_binding_to_current_thread(
                gpu_index=gpu_index,
                numa_options=numa_options,
            )
            return func(*args, **kwargs)

        return bound_call

    return func
def _maybe_apply_numa_binding_to_current_thread(
*, gpu_index: int, numa_options: NumaOptions
) -> None:
kwargs = {
"gpu_index": gpu_index,
"numa_options": asdict(numa_options),
}
logger.info("Attempting to apply NUMA binding, given input %r", kwargs)
try:
logical_cpu_indices = _get_logical_cpus_to_bind_to(
gpu_index=gpu_index, numa_options=numa_options
)
logger.info(
"Computed logical_cpu_indices=%s for NUMA binding",
_get_ranges_str_from_ints(logical_cpu_indices),
logical_cpu_indices = _get_validated_logical_cpus_to_bind_to(
gpu_index=gpu_index,
numa_options=numa_options,
)
_raise_if_logical_cpu_indices_invalid(logical_cpu_indices=logical_cpu_indices)
logger.info(
"Validated logical_cpu_indices=%s for NUMA binding",
_get_ranges_str_from_ints(logical_cpu_indices),
)
_bind_current_thread_to_logical_cpus(logical_cpu_indices=logical_cpu_indices)
logger.info(
"Successfully bound to logical_cpu_indices=%s for NUMA binding",
_get_ranges_str_from_ints(logical_cpu_indices),
_bind_all_threads_in_current_process_to_logical_cpus(
logical_cpu_indices=logical_cpu_indices
)
signpost_event(
@ -109,34 +139,82 @@ def _apply_numa_binding_to_current_thread(
},
)
except Exception:
signpost_event(
category="numa_binding",
name="apply_exception",
parameters={
**kwargs,
"traceback": traceback.format_exc(),
},
_handle_exception(numa_options=numa_options, logger_kwargs=kwargs)
def _assemble_numactl_command_args(
    *, original_command_args: tuple[str, ...], logical_cpu_indices: set[int]
) -> tuple[str, ...]:
    """Prefix a command with ``numactl --physcpubind=<cpus>``.

    The CPU list is rendered with ``_get_ranges_str_from_ints``; the original
    arguments follow the two ``numactl`` tokens unchanged and in order.
    """
    cpu_ranges = _get_ranges_str_from_ints(logical_cpu_indices)
    numactl_prefix = ("numactl", f"--physcpubind={cpu_ranges}")
    return numactl_prefix + tuple(original_command_args)
def _handle_exception(
*, numa_options: NumaOptions, logger_kwargs: dict[str, object]
) -> None:
signpost_event(
category="numa_binding",
name="apply_exception",
parameters={
**logger_kwargs,
"traceback": traceback.format_exc(),
},
)
logger.exception("Failed to apply NUMA binding for input=%r", logger_kwargs)
if numa_options.should_fall_back_if_binding_fails:
logger.warning(
"Continuing executing without applying NUMA binding, despite exception %s",
traceback.format_exc(),
)
logger.exception("Failed to apply NUMA binding for input=%r", kwargs)
if numa_options.should_fall_back_if_binding_fails:
logger.warning(
"Continuing executing without applying NUMA binding, despite exception %s",
traceback.format_exc(),
)
return None
raise
return
# This function is called within an except block, so silence the warning
# about raise without an exception.
raise # noqa: PLE0704
def _raise_if_logical_cpu_indices_invalid(*, logical_cpu_indices: set[int]) -> None:
def _get_validated_logical_cpus_to_bind_to(
    *,
    gpu_index: int,
    numa_options: NumaOptions,
) -> set[int]:
    """Compute the logical CPUs for ``gpu_index`` and validate them.

    Delegates the computation to ``_get_logical_cpus_to_bind_to`` and the
    validation to ``_raise_if_binding_invalid``; any exception raised by
    either propagates to the caller.
    """
    candidate_cpus = _get_logical_cpus_to_bind_to(
        gpu_index=gpu_index, numa_options=numa_options
    )
    _raise_if_binding_invalid(logical_cpu_indices=candidate_cpus)
    return candidate_cpus
def _raise_if_binding_invalid(*, logical_cpu_indices: set[int]) -> None:
    """Reject a binding that cannot be applied.

    Raises:
        RuntimeError: if ``shutil.which`` cannot locate a ``numactl``
            executable, or if ``logical_cpu_indices`` is empty.
    """
    # NOTE: numactl CLI is only actually necessary for the str entrypoint path,
    # but for simplicity we will just check it no matter what.
    numactl_path = shutil.which("numactl")
    if numactl_path is None:
        raise RuntimeError("numactl CLI is required for NUMA binding")

    if not logical_cpu_indices:
        raise RuntimeError("Must bind to a non-empty set of CPU indices")
def _bind_current_thread_to_logical_cpus(*, logical_cpu_indices: set[int]) -> None:
# 0 represents the current thread
def _bind_all_threads_in_current_process_to_logical_cpus(
    *, logical_cpu_indices: set[int]
) -> None:
    """Set the CPU affinity of the caller and of every task under /proc/self/task.

    The caller (id 0) is bound first and any failure there propagates. The
    remaining entries are bound on a best-effort basis: entries whose name is
    not an integer, and threads the OS refuses to bind, are skipped.

    Args:
        logical_cpu_indices: CPU indices passed to ``os.sched_setaffinity``.

    Raises:
        OSError: if binding the caller itself fails, or if
            ``/proc/self/task`` cannot be listed.
    """
    # 0 represents the current thread.
    # This is outside the try/except because the main thread should always bind successfully.
    # pyrefly: ignore # missing-attribute
    os.sched_setaffinity(0, logical_cpu_indices)  # type: ignore[attr-defined]

    for tid_str in os.listdir("/proc/self/task"):
        try:
            tid = int(tid_str)
        except ValueError:
            # Not a numeric thread id, so there is nothing to bind.
            continue

        try:
            # pyrefly: ignore # missing-attribute
            os.sched_setaffinity(tid, logical_cpu_indices)  # type: ignore[attr-defined]
        except OSError:
            # Thread may have exited or otherwise become invalid
            continue
def _get_logical_cpus_to_bind_to(
*,