Compare commits


3 Commits

Author SHA1 Message Date
fa9d5c2dd7 Update on "conv: refactor for lookup table support"
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc. into template_heuristics
- add conv-specific kernel inputs
- add lookup table e2e test for conv (entry format sketched below)

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```
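
For reference, a minimal sketch (mirroring the new e2e test in `test/inductor/test_lookup_table.py`) of what a conv entry in the lookup table looks like. The lookup key is normally produced by `LookupTableChoices().make_lookup_key(...)` from a `ConvKernelInputs` built over the conv's input/weight nodes and scalar params; the key below is a placeholder, and the block sizes are illustrative only.

```python
# Sketch based on the e2e test added in this stack; values are illustrative.
from torch._inductor import config as inductor_config
from torch._inductor.kernel.conv import conv2d_template

# In the test, the key comes from
#   LookupTableChoices().make_lookup_key(conv_kernel_inputs, "convolution")
# where conv_kernel_inputs wraps the conv's (x, weight) nodes plus the scalar
# params (stride, padding, dilation, transposed, output_padding, groups).
lookup_key = ...  # placeholder for the derived key

# An entry carries the template id plus per-config tunables only; static
# template kwargs (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are filled
# in by the heuristic's get_extra_kwargs() and must not appear in the table.
inductor_config.lookup_table.table = {
    lookup_key: [
        {
            "template_id": conv2d_template.uid,
            "BLOCK_M": 64,
            "BLOCK_N": 64,
            "BLOCK_K": 32,
            "num_stages": 2,
            "num_warps": 4,
        }
    ]
}
```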

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

Differential Revision: [D86474839](https://our.internmc.facebook.com/intern/diff/D86474839)

[ghstack-poisoned]
2025-11-10 17:28:12 -08:00
f048cb1f3c Update on "conv: refactor for lookup table support"
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc. into template_heuristics
- add conv-specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2025-11-06 16:29:43 -08:00
c277e07f77 conv: refactor for lookup table support
# why

enable configuring conv operations through the lookup table

# what

- move kwargs etc. into template_heuristics
- add conv-specific kernel inputs
- add lookup table e2e test for conv

# testing

```
python3 -bb -m pytest test/inductor/test_lookup_table.py -k "conv2d" -v
python3 -bb -m pytest test/inductor/test_max_autotune.py -k "conv" -v
```

[ghstack-poisoned]
2025-11-05 18:57:57 -08:00
34 changed files with 748 additions and 838 deletions

View File

@ -226,8 +226,8 @@ template <
typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
virtual ~CachingHostAllocatorImpl() {
if (active_) {
active_ = false;
active_ = false;
if (pinned_use_background_threads()) {
getBackgroundThreadPool()->waitWorkComplete();
}
}
@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
if (pinned_use_background_threads()) {
// Launch the background thread and process events in a loop.
static bool background_thread_flag [[maybe_unused]] = [this] {
active_ = true;
getBackgroundThreadPool()->run([&]() {
while (active_) {
process_events();
@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the event-processing thread pool is active.
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{false};
std::atomic<bool> active_{true};
protected:
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};

View File

@ -24,13 +24,7 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
template<typename key_t, int value_size>
void radix_sort_pairs_impl(

View File

@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
});
}
template <typename func_t, typename vec_func_t, typename ident_t = double>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
using traits = binary_function_traits<func_t>;
static_assert(
all_same<

View File

@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) {
}
}
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};
void min_values_kernel_impl(TensorIterator& iter) {
// This case is special because of Vectorized<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
}), kLong, kUInt64);
return;
}
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
upper_bound<scalar_t>());
static_cast<double>(upper_bound<scalar_t>()));
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}

View File

@ -47,10 +47,20 @@ Tensor sgd_out_of_place(
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
// testing Tensor strides + stride
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
auto out = new_empty(param, param.sizes());
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
aoti_torch_get_device_type(param.get(), &param_device_type);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);
sgd_math(
reinterpret_cast<float*>(param.data_ptr()),
@ -334,8 +344,6 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
// This is to test that passing in a std::vector works for BC. (It gets
// implicitly converted to HeaderOnlyArrayRef too!)
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);

View File

@ -5789,229 +5789,6 @@ class NCCLTraceTest(NCCLTraceTestBase):
else:
self.assertTrue("duration_ms" not in t["entries"][0])
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
"""
Test that when the circular buffer in entries_ is full and we call reset,
then fill the buffer with new entries, dump_entries returns only the new
entries and not the old ones.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely with 10 entries
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify buffer is full with 10 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Now reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add new entries after reset - fill the buffer completely again
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we get exactly 10 new entries, not 20
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Verify all entries have the expected properties (from after reset)
# After reset, record IDs should start from 0 again
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
self.assertIn("record_id", entry)
# Record IDs should be sequential starting from 0 after reset
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
"""
Test that when the circular buffer is full, we reset, and then add fewer
entries than the buffer size, we only get the new entries.
This tests that old entries at the end of the circular buffer are properly
filtered out based on reset_epoch.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add only 3 new entries (much less than buffer size)
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we only get the 3 new entries, not 10
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 3)
# Verify record IDs start from 0 after reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_wraparound(self, timing_enabled):
"""
Test that when we reset in the middle of the circular buffer and then
wrap around, dump_entries correctly returns only entries from the current
epoch in the correct order.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill half the buffer
for _ in range(5):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset at this point (reset happens at index 5)
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Now add 8 entries, which will wrap around
# (5->9 fills rest of buffer, then 0->2 wraps around)
for _ in range(8):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should get exactly 8 entries, properly ordered
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 8)
# Entries should be in chronological order
# The dump_entries() method returns entries from next_ to end, then 0 to next_
# After filtering old entries, we should have 8 entries in order
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_multiple_resets(self, timing_enabled):
"""
Test multiple consecutive resets to ensure each reset properly increments
the epoch and filters out entries from previous epochs.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# First batch: 2 entries
for _ in range(2):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# First reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Second batch: 3 entries
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Second reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Third batch: 4 entries
for _ in range(4):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should only see the last 4 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 4)
# Verify record IDs start from 0 after the last reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
def check_if_test_is_skipped(fn):
def wrapper(self, *args, **kwargs):

View File

@ -8,11 +8,21 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
@ -30,7 +40,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)[0:3]
return extract_graph(fn, *args, **kwargs)
def run_and_get_simple_graph(self):
def fn(x, y):

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Callable, Sequence
from typing import Any, Union
from collections.abc import Sequence
from typing import Any, Callable, Union
import torch
import torch._dynamo

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import NamedTuple, Optional, TYPE_CHECKING
from typing import Callable, NamedTuple, Optional
import torch
import torch._dynamo
@ -7,10 +7,6 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -1,13 +1,11 @@
# Owner(s): ["module: dynamo"]
import functools
import re
import unittest
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
@ -17,14 +15,6 @@ requires_multigpu = functools.partial(
)
def remove_file_comment(gm_str: str) -> str:
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
def print_graph(graph: torch.fx.GraphModule) -> str:
return remove_file_comment(graph.print_readable())
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -46,7 +36,9 @@ class TestStreams(torch._dynamo.test_case.TestCase):
@requires_cuda
def test_stream_enter_exit(self):
def fn(x, y, s1, s2):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
@ -55,36 +47,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
@requires_cuda
@unittest.skip("Needs graph break support with annotation context")
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
@ -101,16 +70,9 @@ class <lambda>(torch.nn.Module):
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
self.assertEqual(len(fw_graphs), 2)
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
@requires_cuda
def test_stream_input(self):
@ -193,248 +155,22 @@ class <lambda>(torch.nn.Module):
self.assertEqual(s_act, s_exp)
def test_nested_stream_enter_exit(self):
def fn(x, y, s0, s1, s2):
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
pass
return z0, y
inp = (
torch.ones(2, 2) + 1,
torch.ones(2, 2),
torch.Stream(),
torch.Stream(),
torch.Stream(),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
@unittest.skip("Needs graph break support with annotation context")
def test_stream_enter_exit_graph_break(self):
pass
@unittest.skip("Needs graph break support with annotation context")
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
pass
def test_local_stream_nested_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
pass
def test_stream_with_mutation(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
x.add_(y)
with s0:
z1 = torch.add(y, y)
z0 = torch.add(z1, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
# Annotation: {'stream': 2}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
#
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (add_2, add_3)
""",
)
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, requires_grad=True) + 1,
torch.ones(2, 2, requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
#
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
pass
@requires_cuda
def test_run_opcheck(self):

View File

@ -2,14 +2,18 @@
import re
import unittest
from functools import partial
from typing import Any, Optional, Union
from typing import Any, Optional
from unittest.mock import patch
import torch
import torch.nn as nn
from torch._inductor import config as inductor_config
from torch._inductor.choices import InductorChoices
from torch._inductor.kernel_inputs import MMKernelInputs
from torch._inductor.kernel_inputs import (
ConvKernelInputs,
MMKernelInputs,
SerializableValue,
)
from torch._inductor.lookup_table.choices import LookupTableChoices
from torch._inductor.select_algorithm import (
add_preprocessing_fn,
@ -54,7 +58,7 @@ class MockMMKernelInputs(MMKernelInputs):
def __init__(
self,
tensors: list[torch.Tensor],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
mat1_idx: int = -2,
mat2_idx: int = -1,
):
@ -80,6 +84,37 @@ class MockMMKernelInputs(MMKernelInputs):
return self.tensors[0].device.type
class MockConvKernelInputs(ConvKernelInputs):
"""Mock ConvKernelInputs that subclasses the real class and uses real tensors"""
def __init__(
self,
tensors: list[torch.Tensor],
scalars: Optional[dict[str, SerializableValue]] = None,
x_idx: int = 0,
weight_idx: int = 1,
bias_idx: Optional[int] = None,
):
"""Initialize with real tensors, creating mock nodes for the base class"""
mock_nodes = [MockTensorNode(t) for t in tensors]
super().__init__(
mock_nodes, scalars, x_idx=x_idx, weight_idx=weight_idx, bias_idx=bias_idx
)
self.tensors = tensors # Keep reference to original tensors
def shapes_hinted(self) -> tuple[tuple[int, ...], ...]:
"""Delegate to symbolic since real tensors already have int shapes"""
return self.shapes_symbolic()
def strides_hinted(self) -> tuple[tuple[int, ...], ...]:
"""Delegate to symbolic since real tensors already have int strides"""
return self.strides_symbolic() # pyre-ignore
@property
def device_type(self) -> Optional[str]:
return self.tensors[0].device.type
class BaseLookupTableTest(TestCase):
"""Base class for lookup table tests with common setup and utilities"""
@ -103,7 +138,7 @@ class BaseLookupTableTest(TestCase):
shapes: Optional[list[tuple[int, ...]]] = None,
device: torch.device = torch.device("cuda"),
dtype: torch.dtype = torch.float32,
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
) -> MockMMKernelInputs:
"""Create MockMMKernelInputs with real tensors"""
if shapes is None:
@ -1055,6 +1090,119 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):
with patch.object(inductor_config.lookup_table, "check_src_hash", True):
self.run_model("mm", tensors)
@fresh_cache()
def test_conv2d_lookup_table_entry_e2e(self):
"""Test end-to-end conv2d with lookup table entry - verifies config is picked up and produces valid results"""
import torch._inductor.kernel.conv
# Create input tensors with specific shapes for conv2d
# Input: [batch=2, in_channels=3, height=32, width=32]
# Weight: [out_channels=64, in_channels=3, kernel_h=3, kernel_w=3]
# Make them channels-last to match what conv lowering uses
x = torch.randn(2, 3, 32, 32, device=self.device, dtype=torch.float16).to(
memory_format=torch.channels_last
)
weight = torch.randn(64, 3, 3, 3, device=self.device, dtype=torch.float16).to(
memory_format=torch.channels_last
)
# Define conv parameters - use these SAME values everywhere
stride = (1, 1)
padding = (1, 1)
dilation = (1, 1)
groups = 1
# Create MockConvKernelInputs using the SAME tensors and SAME scalar values
mock_scalars = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": False,
"output_padding": (0, 0),
"groups": groups,
}
mock_kernel_inputs = MockConvKernelInputs([x, weight], mock_scalars)
# Create lookup key for "convolution" operation
choices_handler = LookupTableChoices()
lookup_key = choices_handler.make_lookup_key(mock_kernel_inputs, "convolution")
# Get the exact template UID from conv2d_template
template_uid = torch._inductor.kernel.conv.conv2d_template.uid
# Create a precisely configured conv2d config
# IMPORTANT: Only include per-config tunable parameters!
# Static parameters (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are
# automatically generated by get_extra_kwargs() and should NOT be in the lookup table
conv2d_config = {
"template_id": template_uid,
# Per-config tunable parameters only (what you'd tune via autotuning)
"BLOCK_M": 64,
"BLOCK_N": 64,
"BLOCK_K": 32,
"num_stages": 2,
"num_warps": 4,
}
# Setup lookup table
inductor_config.lookup_table.table = {lookup_key: [conv2d_config]}
def validate_conv_choice(choices):
assert len(choices) == 1, (
f"Expected 1 choice from lookup table, got {len(choices)}"
)
assert isinstance(choices[0], TritonTemplateCaller), (
f"Expected TritonTemplateCaller, got {type(choices[0])}"
)
assert "convolution2d" in choices[0].name, (
f"Expected 'convolution2d' in name, got {choices[0].name}"
)
return choices
add_preprocessing_fn(validate_conv_choice)
# Create and compile the model using the SAME weight tensor
class SimpleConv2d(nn.Module):
def __init__(self, weight):
super().__init__()
self.register_buffer("weight", weight)
def forward(self, x):
return torch.conv2d(
x,
self.weight,
bias=None,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
model = SimpleConv2d(weight).to(self.device)
with inductor_config.patch({"max_autotune": True, "max_autotune_gemm": True}):
compiled_model = torch.compile(model)
result = compiled_model(x) # Use the SAME x tensor
# Output shape: [batch=2, out_channels=64, out_h=32, out_w=32]
# (same spatial dims due to padding=1, stride=1, kernel=3)
expected_shape = (2, 64, 32, 32)
self.assertEqual(
result.shape,
expected_shape,
f"Expected shape {expected_shape}, got {result.shape}",
)
self.assertFalse(
torch.isnan(result).any().item(),
"Output contains NaN values",
)
self.assertFalse(
torch.isinf(result).any().item(),
"Output contains Inf values",
)
if __name__ == "__main__":
from torch._inductor.utils import is_big_gpu

View File

@ -1864,8 +1864,6 @@ class TestFP8Matmul(TestCase):
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if torch.version.hip and recipe == "nvfp4":
raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")

View File

@ -1914,7 +1914,6 @@ class TestSDPAFailureModes(NNTestCase):
q, k, v, None, 0.0, is_causal=True))
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
batch_size = 2**16
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@ -1936,7 +1935,6 @@ class TestSDPAFailureModes(NNTestCase):
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@ -1950,7 +1948,6 @@ class TestSDPAFailureModes(NNTestCase):
@largeTensorTest("15GB", "cuda")
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
device = torch.device("cuda")
dtype = torch.bfloat16

View File

@ -1,5 +1,5 @@
from typing import TypeAlias, Union
from typing_extensions import assert_type
from typing import Union
from typing_extensions import assert_type, TypeAlias
from torch import randn, Tensor

View File

@ -1,9 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload, Union
from typing import Any, Callable, Optional, overload, Union
import torch
from torch import Tensor

View File

@ -87,12 +87,6 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
return gm.graph, region_tracker # type: ignore[union-attr]
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:

View File

@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence, Sized
from collections.abc import Callable, Sequence
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import hashlib
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
import sympy # noqa: TC002
@ -17,8 +17,6 @@ from .simd import SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ..ir import IRNode
from ..scheduler import BaseSchedulerNode

View File

@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
fx_node: torch.fx.Node,
override_size: Optional[int] = None,
# TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix.
use_nccl_estimator: bool = True,
use_nccl_estimator: bool = False,
) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).

View File

@ -8,6 +8,7 @@ import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
from .. import config, ir
from ..kernel_inputs import ConvKernelInputs
from ..lowering import (
add_layout_constraint,
constrain_to_fx_strides,
@ -16,7 +17,9 @@ from ..lowering import (
)
from ..select_algorithm import (
autotune_select_algorithm,
ChoiceCaller,
ExternKernelChoice,
KernelTemplate,
SymbolicGridFn,
TritonTemplate,
)
@ -76,7 +79,7 @@ LOOP_BODY_2D = """
& (idx_x_h < IN_H)[:, None]
& (idx_x_w >= 0)[:, None]
& (idx_x_w < IN_W)[:, None]
& (idx_x_c < GROUP_IN_C)[None, :]
& (idx_x_c < GROUP_IN_C)[None, :]
)
matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
@ -542,34 +545,40 @@ def convolution(
x = ir.ExternKernel.require_stride_order(x, req_stride_order) # type: ignore[assignment]
weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) # type: ignore[assignment]
ordered_kwargs_for_cpp_kernel = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
if bias is None:
args = [x, weight]
kwargs["bias"] = None # type: ignore[typeddict-unknown-key]
ordered_kwargs_for_cpp_kernel.insert(0, "bias")
else:
args = [x, weight, bias]
# Create ConvKernelInputs for unified template configuration
# Only include bias in input_nodes when it's not None
# - For Triton templates: bias is always None here (peeled off earlier), so input_nodes = [x, weight]
# - For ATEN: input_nodes = [x, weight] when bias is None, [x, weight, bias] when bias is present
if bias is not None:
bias.realize()
bias.freeze_layout()
V.graph.sizevars.guard_int_seq(bias.get_size())
input_nodes = [x, weight, bias]
bias_idx = 2
else:
input_nodes = [x, weight]
bias_idx = None
kernel_inputs = ConvKernelInputs(
input_nodes,
scalars={
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": transposed,
"output_padding": output_padding,
"groups": groups,
},
x_idx=0,
weight_idx=1,
bias_idx=bias_idx,
)
# Build list of templates to try
templates: list[ExternKernelChoice | KernelTemplate] = []
choices = []
if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
choices = [
aten_convolution.bind(
args,
layout,
ordered_kwargs_for_cpp_kernel,
**kwargs,
)
]
templates.append(aten_convolution)
if (
torch._inductor.utils._use_conv_autotune_backend("TRITON")
@ -587,60 +596,23 @@ def convolution(
and is_zeros(padding)
and groups == 1
):
choices.append(aten_conv1x1_via_mm.bind(args, layout))
templates.append(aten_conv1x1_via_mm)
conv_configs = V.choices.get_conv_configs(device_type)
# Add appropriate template based on ndim
if ndim == 2:
templates.append(conv2d_template)
elif ndim == 3:
templates.append(conv3d_template)
dtype_size = x.get_dtype().itemsize
for cfg in conv_configs(
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
dtype_size=dtype_size,
):
if ndim == 2:
conv2d_template.maybe_append_choice(
choices,
input_nodes=(x, weight),
layout=layout,
KERNEL_H=kernel_shape[0],
KERNEL_W=kernel_shape[1],
STRIDE_H=stride[0],
STRIDE_W=stride[1],
PADDING_H=padding[0],
PADDING_W=padding[1],
GROUPS=groups,
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
)
elif ndim == 3:
conv3d_template.maybe_append_choice(
choices,
input_nodes=(x, weight),
layout=layout,
KERNEL_D=kernel_shape[0],
KERNEL_H=kernel_shape[1],
KERNEL_W=kernel_shape[2],
STRIDE_D=stride[0],
STRIDE_H=stride[1],
STRIDE_W=stride[2],
PADDING_D=padding[0],
PADDING_H=padding[1],
PADDING_W=padding[2],
GROUPS=groups,
# TODO(jansel): try unroll for bigger kernels once fixed:
# https://github.com/triton-lang/triton/issues/1254
UNROLL=is_ones(kernel_shape),
ALLOW_TF32=torch.backends.cudnn.allow_tf32,
num_stages=cfg.num_stages,
num_warps=cfg.num_warps,
**cfg.kwargs,
)
# Initialize choices list and extend with template configs
choices: list[ChoiceCaller] = []
choices.extend(
V.choices.get_template_configs(
kernel_inputs,
templates,
"convolution",
)
)
if use_ck_conv_template(layout):
CKGroupedConvFwdTemplate.add_ck_conv_choices(
choices,
@ -652,7 +624,9 @@ def convolution(
groups=groups,
n_spatial_dimensions=ndim,
)
return autotune_select_algorithm("convolution", choices, args, layout)
return autotune_select_algorithm(
"convolution", choices, kernel_inputs.nodes(), layout
)
@register_lowering(aten._convolution)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
@ -12,10 +13,12 @@ from .ir import FixedLayout, FlexibleLayout, Layout
if TYPE_CHECKING:
from collections.abc import Sequence
import sympy
# Type aliases for serializable scalar values
Serializable = Union[int, float, bool]
SerializableValue = Union[Serializable, Sequence[Serializable]]
class KernelInputs(ABC):
"""
@ -27,7 +30,7 @@ class KernelInputs(ABC):
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
):
"""
@ -183,7 +186,7 @@ class KernelInputs(ABC):
The output dtype
"""
def get_scalar(self, name: str) -> Union[float, int]:
def get_scalar(self, name: str) -> SerializableValue:
"""
Get the scalar value for a given name.
@ -191,7 +194,7 @@ class KernelInputs(ABC):
name: Name of the scalar to get
Returns:
The scalar value
The scalar value (can be int, float, bool, or tuple of these types)
"""
assert name in self._scalars, f"Scalar {name} not found, but required"
return self._scalars[name]
@ -216,7 +219,7 @@ class MMKernelInputs(KernelInputs):
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, Union[float, int]]] = None,
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
mat1_idx: int = -2,
mat2_idx: int = -1,
@ -336,3 +339,113 @@ class MMKernelInputs(KernelInputs):
assert k == k_check, f"K dimensions don't match: {k} vs {k_check}"
return (m, n, k)
class ConvKernelInputs(KernelInputs):
"""
Specialized KernelInputs for convolution operations.
Stores input tensor, weight tensor, and optional bias, along with conv parameters.
"""
def __init__(
self,
input_nodes: list[Any],
scalars: Optional[dict[str, SerializableValue]] = None,
out_dtype: Optional[torch.dtype] = None,
x_idx: int = 0,
weight_idx: int = 1,
bias_idx: Optional[int] = None,
):
"""
Initialize with convolution input nodes.
Args:
input_nodes: List containing [x, weight] or [x, weight, bias]
scalars: Dict with conv params (stride, padding, dilation, groups, transposed, output_padding)
out_dtype: Optional output dtype
x_idx: Index of input tensor (default: 0)
weight_idx: Index of weight tensor (default: 1)
bias_idx: Index of bias tensor if present (default: None)
"""
super().__init__(input_nodes, scalars, out_dtype)
assert len(input_nodes) >= 2, "Expected at least 2 input nodes (x, weight)"
self._x_idx = x_idx
self._weight_idx = weight_idx
self._bias_idx = bias_idx
# Validate that required scalars are present
required_scalars = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
for key in required_scalars:
assert key in self._scalars, f"Conv requires scalar '{key}'"
def out_dtype(self) -> torch.dtype:
"""
Get the output dtype, whether passed in or inferred from the nodes
Returns:
The output dtype
"""
if self._out_dtype is not None:
return self._out_dtype
return self._input_nodes[self._x_idx].get_dtype()
def output_layout(self, flexible: bool = True) -> Layout:
"""
Handle output layout generation for convolution.
Args:
flexible: If True, return FlexibleLayout, otherwise FixedLayout
Returns:
Layout for the convolution output
"""
from torch._inductor.kernel.conv import conv_layout
x = self._input_nodes[self._x_idx]
weight = self._input_nodes[self._weight_idx]
bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
# Use existing conv_layout function
# We know the types here because conv requires these specific scalar types
layout = conv_layout(
x,
weight,
bias,
self._scalars["stride"], # type: ignore[arg-type]
self._scalars["padding"], # type: ignore[arg-type]
self._scalars["dilation"], # type: ignore[arg-type]
self._scalars["transposed"], # type: ignore[arg-type]
self._scalars["output_padding"], # type: ignore[arg-type]
self._scalars["groups"], # type: ignore[arg-type]
)
# TODO: Handle flexible vs fixed based on config if needed
return layout
def get_x_weight_bias(self) -> tuple[Any, Any, Optional[Any]]:
"""
Get x, weight, and optional bias nodes.
Returns:
Tuple of (x, weight, bias) where bias may be None
"""
bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
return self._input_nodes[self._x_idx], self._input_nodes[self._weight_idx], bias
def spatial_dims(self) -> tuple[Any, ...]:
"""
Get spatial dimensions from input tensor (H, W for 2D, D, H, W for 3D).
Returns:
Tuple of spatial dimension sizes
"""
x_shape = self._input_nodes[self._x_idx].get_size()
return x_shape[2:] # Skip batch and channel dims

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable
import torch
from torch._environment import is_fbcode

View File

@ -1,6 +1,6 @@
# NOTE: add new template heuristics here, so they get imported and registered
# TODO: write a simple glob if there are many heuristics to auto import them in the right order
from . import aten, base, contiguous_mm, decompose_k, registry, triton
from . import aten, base, contiguous_mm, conv, decompose_k, registry, triton
# expose the entry function
from .registry import get_template_heuristic

View File

@ -0,0 +1,287 @@
from __future__ import annotations
from typing import Any, cast, TYPE_CHECKING
import torch
from ..kernel.conv import aten_convolution, conv2d_template, conv3d_template
from ..kernel_inputs import ConvKernelInputs
from ..utils import is_ones, sympy_product
from ..virtualized import V
from .base import TemplateConfigHeuristics
from .registry import register_template_heuristic
from .triton import (
CPUConfigHeuristic,
CUDAConfigHeuristic,
MTIAConfigHeuristic,
ROCmConfigHeuristic,
XPUConfigHeuristic,
)
if TYPE_CHECKING:
from collections.abc import Generator
from ..kernel_inputs import KernelInputs
class ConvTemplateConfigMixin(TemplateConfigHeuristics):
"""
Mixin for conv templates that converts config lists to template kwargs.
Similar to MMTemplateConfigMixin but for convolutions.
This handles generating both the static template kwargs (KERNEL_H, STRIDE_H, etc.)
and the per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
"""
# Type hint for methods from BaseConfigHeuristic
get_conv_configs: Any
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> dict[str, Any]:
"""
Return template kwargs that don't change per-config.
These are derived from kernel_inputs and must include all template parameters.
Args:
kernel_inputs: ConvKernelInputs containing input tensors and conv params
op_name: Operation name (e.g., "convolution")
Returns:
Dict of static template kwargs (KERNEL_H, STRIDE_H, GROUPS, etc.)
"""
assert isinstance(kernel_inputs, ConvKernelInputs), (
f"ConvTemplateConfigMixin requires ConvKernelInputs, got {type(kernel_inputs)}"
)
x, weight, bias = kernel_inputs.get_x_weight_bias()
# Extract kernel shape from weight: [out_chan, in_chan, *kernel_shape]
weight_size = V.graph.sizevars.guard_int_seq(weight.get_size())
kernel_shape = weight_size[2:] # Skip out_chan, in_chan
ndim = len(kernel_shape)
# Extract scalars
stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride"))
padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding"))
groups = cast(int, kernel_inputs.get_scalar("groups"))
# Check if we should unroll (only for 1x1 kernels)
unroll = is_ones(kernel_shape)
# Build kwargs dict based on ndim
kwargs: dict[str, Any] = {
"GROUPS": groups,
"UNROLL": unroll,
"ALLOW_TF32": torch.backends.cudnn.allow_tf32,
}
if ndim == 2:
kwargs.update(
{
"KERNEL_H": kernel_shape[0],
"KERNEL_W": kernel_shape[1],
"STRIDE_H": stride[0],
"STRIDE_W": stride[1],
"PADDING_H": padding[0],
"PADDING_W": padding[1],
}
)
elif ndim == 3:
kwargs.update(
{
"KERNEL_D": kernel_shape[0],
"KERNEL_H": kernel_shape[1],
"KERNEL_W": kernel_shape[2],
"STRIDE_D": stride[0],
"STRIDE_H": stride[1],
"STRIDE_W": stride[2],
"PADDING_D": padding[0],
"PADDING_H": padding[1],
"PADDING_W": padding[2],
}
)
return kwargs
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
"""
Yield per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
Args:
kernel_inputs: ConvKernelInputs containing input tensors
op_name: Operation name
Yields:
Dict of per-config kwargs for each configuration to try
"""
assert isinstance(kernel_inputs, ConvKernelInputs), (
"ConvTemplateConfigMixin requires ConvKernelInputs"
)
x, weight, bias = kernel_inputs.get_x_weight_bias()
# Calculate dimensions for heuristics
weight_size = weight.get_size()
out_chan = weight_size[0]
in_chan = weight_size[1]
# Batch * spatial dimensions product
x_size = x.get_size()
batch_spatial_product = sympy_product([x_size[0], *x_size[2:]])
# Get conv config generator from self (which is a BaseConfigHeuristic subclass)
conv_configs_generator = self.get_conv_configs()
dtype_size = x.get_dtype().itemsize
# Generate configs (reusing mm preprocess_mm_configs machinery)
for c in conv_configs_generator(
batch_spatial_product,
out_chan,
in_chan,
dtype_size=dtype_size,
op_name="conv",
):
# Yield per-config kwargs
yield {
"BLOCK_M": c.kwargs.get("BLOCK_M"),
"BLOCK_N": c.kwargs.get("BLOCK_N"),
"BLOCK_K": c.kwargs.get("BLOCK_K"),
"num_stages": c.num_stages,
"num_warps": c.num_warps,
}
# ATEN convolution heuristic (no per-config tuning)
@register_template_heuristic(aten_convolution.uid, None)
class ATenConvConfigHeuristic(TemplateConfigHeuristics):
"""
Pseudo heuristic for ATen convolution.
ATen doesn't have configs to tune - it's a single choice.
"""
def _get_template_configs_impl(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> Generator[dict[str, Any], None, None]:
# ATen doesn't have per-config kwargs to tune
yield dict()
def get_extra_kwargs(
self,
kernel_inputs: KernelInputs,
op_name: str,
) -> dict[str, Any]:
"""
ATen gets stride, padding, etc. as ordered kwargs for the C++ kernel.
"""
assert isinstance(kernel_inputs, ConvKernelInputs)
# Extract scalar values from kernel_inputs
stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride"))
padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding"))
dilation = cast(tuple[int, ...], kernel_inputs.get_scalar("dilation"))
transposed = cast(bool, kernel_inputs.get_scalar("transposed"))
output_padding = cast(
tuple[int, ...], kernel_inputs.get_scalar("output_padding")
)
groups = cast(int, kernel_inputs.get_scalar("groups"))
# Check if bias is None to match old behavior
# When bias is None: input_nodes = [x, weight], add 'bias' to kwargs and ordered list
# When bias is present: input_nodes = [x, weight, bias], don't add 'bias' to kwargs
x, weight, bias = kernel_inputs.get_x_weight_bias()
kwargs: dict[str, Any] = {
"stride": stride,
"padding": padding,
"dilation": dilation,
"transposed": transposed,
"output_padding": output_padding,
"groups": groups,
}
if bias is None:
# When bias is None, torch.convolution expects it as a kwarg
kwargs["bias"] = None
kwargs["ordered_kwargs_for_cpp_kernel"] = [
"bias",
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
else:
# When bias is present, it's passed as a positional arg (3rd in input_nodes)
kwargs["ordered_kwargs_for_cpp_kernel"] = [
"stride",
"padding",
"dilation",
"transposed",
"output_padding",
"groups",
]
return kwargs
# CUDA Conv2D/Conv3D heuristics
@register_template_heuristic(
conv2d_template.uid,
"cuda",
register=torch.version.hip is None,
)
@register_template_heuristic(
conv3d_template.uid,
"cuda",
register=torch.version.hip is None,
)
class CUDAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CUDAConfigHeuristic):
"""Conv template heuristic for CUDA."""
# ROCm Conv2D/Conv3D heuristics
@register_template_heuristic(
conv2d_template.uid,
"cuda",
register=torch.version.hip is not None,
)
@register_template_heuristic(
conv3d_template.uid,
"cuda",
register=torch.version.hip is not None,
)
class ROCmConvTemplateConfigHeuristic(ConvTemplateConfigMixin, ROCmConfigHeuristic):
"""Conv template heuristic for ROCm."""
# CPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "cpu")
@register_template_heuristic(conv3d_template.uid, "cpu")
class CPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CPUConfigHeuristic):
"""Conv template heuristic for CPU."""
# XPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "xpu")
@register_template_heuristic(conv3d_template.uid, "xpu")
class XPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, XPUConfigHeuristic):
"""Conv template heuristic for XPU."""
# MTIA Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "mtia")
@register_template_heuristic(conv3d_template.uid, "mtia")
class MTIAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, MTIAConfigHeuristic):
"""Conv template heuristic for MTIA."""

View File

@ -52,7 +52,26 @@ __all__ = [
"MemRecordsAcc",
]
from contextlib import ContextDecorator
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
import functools
class _ContextDecorator: # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapped
# global python state - whether profiler is currently enabled
@ -725,7 +744,8 @@ class profile:
return all_function_events
class record_function(ContextDecorator):
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.

View File

@ -108,14 +108,12 @@ struct FlightRecorder {
capture_cpp_stack_ = getCvarBool(
{"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
enabled_ = max_entries_ > 0;
reset_epoch_start_idx_[0] = 0;
}
struct Entry {
size_t id_; // incremented id in the trace buffer
// used to figure out where in the circular entries
// buffer this entry will be located to
// update state information
size_t reset_epoch_; // epoch when this entry was created
size_t pg_id_;
std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
@ -185,34 +183,11 @@ struct FlightRecorder {
size_t max_entries_ = 0;
size_t next_ = 0;
size_t id_ = 0;
size_t reset_epoch_ = 0;
std::unordered_map<size_t, size_t>
reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
pg_name_to_ranks_;
std::string comm_lib_version_;
struct TraceIdentifier {
std::optional<size_t> id;
std::optional<size_t> reset_epoch;
};
TraceIdentifier recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P);
std::optional<size_t> record(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
@ -238,16 +213,8 @@ struct FlightRecorder {
std::vector<Entry> dump_entries();
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t getIdxFromId(size_t id, size_t reset_epoch) const;
// Returns the entry with the given id and reset_epoch, if it exists.
// Otherwise, returns std::nullopt.
TORCH_API std::optional<Entry> getEntry(
std::optional<size_t> id,
std::optional<size_t> reset_epoch);
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);
/*
@ -260,11 +227,6 @@ struct FlightRecorder {
never hang. (timing must also be enabled for compute_duration - see
TORCH_NCCL_ENABLE_TIMING).
*/
TORCH_API void retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration = true);
TORCH_API void retire_id(
std::optional<size_t> id,
bool compute_duration = true);

View File

@ -53,41 +53,8 @@ std::optional<size_t> FlightRecorder<EventType>::record(
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
auto result = recordWithResetEnabled(
pg_id,
pg_name,
collective_seq_id,
p2p_seq_id,
op_id,
std::move(profiling_name),
inputs,
outputs,
start,
end,
timeout_ms,
std::move(pg_status),
isP2P);
return result.id;
}
template <typename EventType>
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
if (!enabled_) {
return TraceIdentifier{std::nullopt, std::nullopt};
return std::nullopt;
}
if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
// Current pg_status is not in FR.
@ -97,13 +64,8 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
std::lock_guard<std::mutex> guard(mutex_);
TORCH_CHECK(
reset_epoch_start_idx_.find(reset_epoch_) !=
reset_epoch_start_idx_.end());
auto te = Entry{
id_,
reset_epoch_,
pg_id,
pg_name,
collective_seq_id,
@ -142,20 +104,15 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
}
const auto next = next_++;
if (entries_.size() < max_entries_) {
entries_.emplace_back(std::move(te));
} else {
entries_[next] = std::move(te);
entries_[next_++] = std::move(te);
if (next_ == max_entries_) {
next_ = 0;
}
}
if (next_ == max_entries_) {
next_ = 0;
}
const auto id = id_++;
return TraceIdentifier{id, reset_epoch_};
return id_++;
}
template <typename EventType>
@ -206,20 +163,15 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
std::vector<Entry> result;
{
std::lock_guard<std::mutex> guard(mutex_);
// Filter entries during insertion - only keep entries from current epoch
auto filter = [this](const Entry& e) {
return e.reset_epoch_ == reset_epoch_;
};
std::copy_if(
result.reserve(entries_.size());
result.insert(
result.end(),
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
entries_.end(),
std::back_inserter(result),
filter);
std::copy_if(
entries_.end());
result.insert(
result.end(),
entries_.begin(),
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
std::back_inserter(result),
filter);
entries_.begin() + static_cast<std::ptrdiff_t>(next_));
}
// query any remaining events
for (auto& r : result) {
@ -230,47 +182,28 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
}
template <typename EventType>
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
const {
// Look up the starting idx for the given reset epoch
auto it = reset_epoch_start_idx_.find(reset_epoch);
TORCH_CHECK(it != reset_epoch_start_idx_.end());
// Calculate idx based on where the epoch started
return (it->second + id) % max_entries_;
}
template <typename EventType>
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
// returns std::nullopt.
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::
getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
if (!enabled_ || !id || !reset_epoch) {
EventType>::getEntry(std::optional<size_t> id) {
if (!enabled_ || !id) {
return std::nullopt;
}
std::unique_lock<std::mutex> guard(mutex_);
Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
Entry entry = entries_.at(*id % max_entries_);
if (entry.id_ == *id) {
return entry;
} else {
return std::nullopt;
}
return std::nullopt;
}
template <typename EventType>
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::getEntry(std::optional<size_t> id) {
return getEntry(id, 0);
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration) {
if (!enabled_ || !id || !reset_epoch) {
if (!enabled_ || !id) {
return;
}
@ -281,8 +214,8 @@ void FlightRecorder<EventType>::retire_id(
std::unique_lock<std::mutex> guard(mutex_);
Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
Entry* entry = &entries_.at(*id % max_entries_);
if (entry->id_ == *id) {
update_state(*entry);
if (compute_duration) {
@ -304,8 +237,8 @@ void FlightRecorder<EventType>::retire_id(
guard.lock();
// Refresh the entry pointer, see if the entry has been overwritten
entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
entry = &entries_.at(*id % max_entries_);
if (entry->id_ != *id) {
LOG(INFO) << "retire_id abandoned for id " << *id
<< ", event was overwritten while waiting to compute duration.";
return;
@ -316,23 +249,12 @@ void FlightRecorder<EventType>::retire_id(
}
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
bool compute_duration) {
retire_id(id, 0, compute_duration);
}
template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
std::lock_guard<std::mutex> guard(mutex_);
if (!entries_.empty()) {
// Soft delete: increment epoch to mark all existing entries as old
// Store where the new epoch starts in the circular buffer
reset_epoch_++;
reset_epoch_start_idx_[reset_epoch_] = next_;
id_ = 0;
}
next_ = 0;
id_ = 0;
entries_.clear();
}
template <typename EventType>
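
The comments above describe the core of the epoch scheme: reset_all() performs a soft reset by bumping reset_epoch_ and remembering in reset_epoch_start_idx_ where the new epoch begins in the ring buffer, while getIdxFromId() rebases a per-epoch id onto that start with (start + id) % max_entries_. The snippet below is a self-contained toy model of that index math, not the class in this diff; all names (epoch_start, idx_from_id, ...) are simplified stand-ins and max_entries is fixed at 4.

```
#include <cstddef>
#include <iostream>
#include <unordered_map>

int main() {
  const std::size_t max_entries = 4;
  std::unordered_map<std::size_t, std::size_t> epoch_start = {{0, 0}};
  std::size_t next = 0, id = 0, epoch = 0;

  // recordWithResetEnabled(): place the next entry and hand out a per-epoch id.
  auto record = [&]() {
    std::size_t slot = next;
    next = (next + 1) % max_entries;
    std::cout << "epoch " << epoch << " id " << id++
              << " -> slot " << slot << "\n";
  };
  // getIdxFromId(): map a (reset_epoch, id) pair back to its ring-buffer slot.
  auto idx_from_id = [&](std::size_t e, std::size_t i) {
    return (epoch_start.at(e) + i) % max_entries;
  };

  record(); record(); record();  // epoch 0: ids 0..2 land in slots 0..2

  // reset_all(): soft delete -- bump the epoch, remember where it starts,
  // restart per-epoch ids; the ring buffer itself is left untouched.
  ++epoch;
  epoch_start[epoch] = next;
  id = 0;

  record();                      // epoch 1: id 0 lands in slot 3
  std::cout << "lookup (epoch=1, id=0) -> slot "
            << idx_from_id(1, 0) << "\n";  // prints 3
  return 0;
}
```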

View File

@ -708,8 +708,7 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
// TODO: We need to have numel of tensors for gloo as well.
pgStatus_->lastCompletedNumelIn = 0;
pgStatus_->lastCompletedNumelOut = 0;
FlightRecorder<c10::Event>::get()->retire_id(
work->trace_id_, work->trace_reset_epoch_, false);
FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
lock.lock();
workInProgress_[workerIndex].reset();
}
@ -781,7 +780,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
pgStatus_->lastEnqueuedNumelOut = 0;
// using c10d::FlightRecorder;
// TODO: We need to have a way to use c10::Event inside gloo as well.
auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
collectiveCounter_,
@ -796,8 +795,6 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
work->getTimeout(),
pgStatus_,
false);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
workQueue_.push_back(std::move(work));
lock.unlock();

View File

@ -99,7 +99,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
std::shared_ptr<gloo::Context> context_;
const std::chrono::milliseconds timeout_;

View File

@ -575,7 +575,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
futureWorkResult_(w.futureWorkResult_),
timingEnabled_(w.timingEnabled_),
trace_id_(w.trace_id_),
trace_reset_epoch_(w.trace_reset_epoch_),
distDebugLevel_(w.distDebugLevel_) {
exception_ = w.exception_;
}
@ -705,9 +704,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
// Print the traceback of the collective at call time
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
// First step we get the corresponding record entry from FR, based on work's
// trace_id_ and trace_reset_epoch_
// trace_id_
std::optional<FlightRecorderCUDA::Entry> entry =
FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
FlightRecorderCUDA::get()->getEntry(trace_id_);
if (entry.has_value()) {
auto entryVal = entry.value();
// Get stack trace from FR entry, in string format
@ -2395,8 +2394,7 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
FlightRecorderCUDA::get()->retire_id(
work.trace_id_, work.trace_reset_epoch_, true);
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
if (pg_->onCompletionHook_) {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
@ -3362,7 +3360,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
// these objects to the Work because it has implications for keeping those
// tensors alive longer and adds overhead when copying Work objects
// between threads
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
r->trace_id_ = FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -3376,8 +3374,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
options_->timeout,
pgStatus_,
isP2P);
r->trace_id_ = traceId.id;
r->trace_reset_epoch_ = traceId.reset_epoch;
}
return r;
}
@ -3597,7 +3593,6 @@ float ProcessGroupNCCL::endTimeEstimate() {
#ifdef NCCL_SIM_INFO_INITIALIZER
ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
--ncclActiveGroupCounter_;
return simInfo.estimatedTime;
#else
TORCH_CHECK(
@ -3681,7 +3676,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// later in endCoalescing we record a 'coalesced' Work which has
// timing/state updates via watchdog thread, but lacks op metadata such as
// input/output sizes and profilingTitle per-op in the group.
FlightRecorderCUDA::get()->recordWithResetEnabled(
FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4173,7 +4168,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// TODO(whc) because we don't pass output {tensor} to initWork, we tell
// initWork to not record, and then we manually call record passing all the
// information it wants.
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
work->trace_id_ = FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4187,8 +4182,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
options_->timeout,
pgStatus_,
/*isP2P=*/true);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
}
// Only check for NaN for send ops, for recv ops `tensor` can be a random

View File

@ -505,7 +505,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
DebugLevel distDebugLevel_;
friend class ProcessGroupNCCL;
};

View File

@ -4,7 +4,6 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -14,7 +13,6 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)
using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -95,32 +93,6 @@ class Tensor {
return numel;
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
// a Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef sizes() const {
int64_t* sizes;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
return IntHeaderOnlyArrayRef(sizes, dim());
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the strides of a
// Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef strides() const {
int64_t* strides;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
return IntHeaderOnlyArrayRef(strides, dim());
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.
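
Both accessors above hand out a borrowed view that is only valid while the owning Tensor is alive. A minimal usage sketch, assuming `t` is an existing torch::stable::Tensor obtained through the stable ABI and that iteration over the returned view behaves like ArrayRef; nothing below is part of this diff.

```
// Sketch only: `t` is assumed to be a valid torch::stable::Tensor.
auto sizes = t.sizes();      // IntHeaderOnlyArrayRef borrowed from `t`
auto strides = t.strides();  // also borrowed; copy the values out if they
                             // must outlive `t`

int64_t numel = 1;
for (int64_t s : sizes) {
  numel *= s;                // recomputes what Tensor::numel() reports
}
// For a contiguous tensor, the innermost stride strides[strides.size() - 1]
// is expected to be 1.
```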

View File

@ -1,8 +1,9 @@
import functools
import math
import operator
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import timedelta
from typing import Callable
import torch
from torch._C import ScriptObject

View File

@ -10,7 +10,6 @@ from ._context_parallel._attention import (
_enable_context_parallel_dispatcher,
_is_causal_behavior,
_RotateMethod,
_templated_ring_attention,
context_parallel,
context_parallel_unshard,
set_rotate_method,
@ -23,7 +22,6 @@ from ._context_parallel._load_balancer import (
)
# TODO(fegin): add deprecation message once the final interfaces are concluded.
__all__ = [
"_CausalBehavior",
"_context_parallel_shard",
@ -33,7 +31,6 @@ __all__ = [
"_enable_context_parallel_dispatcher",
"_is_causal_behavior",
"_RotateMethod",
"_templated_ring_attention",
"context_parallel",
"context_parallel_unshard",
"set_rotate_method",