mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-16 07:24:54 +08:00

Compare commits (3 commits): ngimel/hos... ... ciflow/tru...

| Author | SHA1 | Date |
|---|---|---|
| | fa9d5c2dd7 | |
| | f048cb1f3c | |
| | c277e07f77 | |
@@ -226,8 +226,8 @@ template <
    typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
  virtual ~CachingHostAllocatorImpl() {
    if (active_) {
      active_ = false;
      active_ = false;
      if (pinned_use_background_threads()) {
        getBackgroundThreadPool()->waitWorkComplete();
      }
    }

@@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
    if (pinned_use_background_threads()) {
      // Launch the background thread and process events in a loop.
      static bool background_thread_flag [[maybe_unused]] = [this] {
-       active_ = true;
        getBackgroundThreadPool()->run([&]() {
          while (active_) {
            process_events();

@@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
  alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
  std::deque<std::pair<E, B*>> events_; // event queue paired with block

- // Indicates whether the event-processing thread pool is active.
+ // Indicates whether the object is active.
  // Set to false in the destructor to signal background threads to stop.
- std::atomic<bool> active_{false};
+ std::atomic<bool> active_{true};
 protected:
  alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
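Taken together, these hunks flip `active_` to start out true, drop the now-redundant set in the launch path, and make the destructor only tear the pool down when it is active. A rough Python analogy of that lifecycle (illustrative only; the class and method names below are hypothetical, the real code is the C++ above):

import threading
import time

class EventWorkerPool:
    """Toy analogue of CachingHostAllocatorImpl's background-thread lifecycle."""

    def __init__(self):
        self.active = True                  # std::atomic<bool> active_{true};
        self._thread = threading.Thread(target=self._loop)
        self._thread.start()                # launch path no longer sets the flag

    def _loop(self):
        while self.active:                  # while (active_) { process_events(); }
            self.process_events()

    def process_events(self):
        time.sleep(0.01)                    # stand-in for draining the event queue

    def shutdown(self):                     # destructor analogue
        if self.active:
            self.active = False             # signal the loop to stop
            self._thread.join()             # waitWorkComplete()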
@@ -24,13 +24,7 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
-// We use native integer types for 1/2/4/8-byte values to reduce
-// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
-template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
-template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
-template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
-template <> struct alignas(8) OpaqueType<8> { uint64_t data; };

template<typename key_t, int value_size>
void radix_sort_pairs_impl(
@@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
  });
}

-template <typename func_t, typename vec_func_t, typename ident_t = double>
-void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
+template <typename func_t, typename vec_func_t>
+void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
  using traits = binary_function_traits<func_t>;
  static_assert(
    all_same<
@@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) {
  }
}

+template<typename scalar_t>
+struct MinValuesOps: public at::native::MinOps<scalar_t> {
+  using arg_t = typename MinOps<scalar_t>::arg_t;
+  static scalar_t project(arg_t arg) {
+    return arg.first;
+  }
+};
+
void min_values_kernel_impl(TensorIterator& iter) {
+  // This case is special because of Vectorized<int64_t> does not
+  // handle upper_bound<int64_t>().
+  // See: https://github.com/pytorch/pytorch/issues/43254
+  if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
+    AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
+      binary_kernel_reduce(
+        iter,
+        MinValuesOps<scalar_t>{},
+        std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
+    }), kLong, kUInt64);
+    return;
+  }
  AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
    binary_kernel_reduce_vec(
      iter,
      [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
      [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
-     upper_bound<scalar_t>());
+     static_cast<double>(upper_bound<scalar_t>()));
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
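The dedicated kLong/kUInt64 path above exists because the vectorized reduction now carries its identity as a double, and a double cannot represent every 64-bit integer. A quick standalone check (plain Python):

# double (Python float) has a 53-bit mantissa, so the int64 identity
# upper_bound<int64_t>() == 2**63 - 1 does not survive the round-trip:
ident = 2**63 - 1
assert float(ident) != ident           # rounds up to 2.0**63
assert int(float(ident)) == 2**63      # out of range for int64
# hence the scalar binary_kernel_reduce path above, which keeps the
# identity exact inside std::pair<scalar_t, int64_t>.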
@@ -47,10 +47,20 @@ Tensor sgd_out_of_place(
  STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
  STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");

  // testing Tensor strides + stride
  STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
+ int64_t *param_sizes;
+ int64_t *param_strides;
+ aoti_torch_get_sizes(param.get(), &param_sizes);
+ aoti_torch_get_strides(param.get(), &param_strides);

- auto out = new_empty(param, param.sizes());
+ int32_t param_dtype;
+ aoti_torch_get_dtype(param.get(), &param_dtype);
+
+ int32_t param_device_type;
+ aoti_torch_get_device_type(param.get(), &param_device_type);
+
+ AtenTensorHandle out_ath;
+ aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
+ auto out = Tensor(out_ath);

  sgd_math(
      reinterpret_cast<float*>(param.data_ptr()),
@@ -334,8 +344,6 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
- // Still using a std::vector below even though people can just pass in an
- // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
- // directly.
+ // This is to test that passing in a std::vector works for BC. (It gets
+ // implicitly converted to HeaderOnlyArrayRef too!)
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
  return new_empty(t, sizes, dtype);
@@ -5789,229 +5789,6 @@ class NCCLTraceTest(NCCLTraceTestBase):
        else:
            self.assertTrue("duration_ms" not in t["entries"][0])

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
        """
        Test that when the circular buffer in entries_ is full and we call reset,
        then fill the buffer with new entries, dump_entries returns only the new
        entries and not the old ones.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill the buffer completely with 10 entries
        for _ in range(10):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify buffer is full with 10 entries
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 10)

        # Now reset the flight recorder
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Add new entries after reset - fill the buffer completely again
        for _ in range(10):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify we get exactly 10 new entries, not 20
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 10)

        # Verify all entries have the expected properties (from after reset)
        # After reset, record IDs should start from 0 again
        for i, entry in enumerate(t["entries"]):
            self.assertIn("profiling_name", entry)
            self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
            self.assertIn("record_id", entry)
            # Record IDs should be sequential starting from 0 after reset
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_partial_overwrite(self, timing_enabled):
        """
        Test that when the circular buffer is full, we reset, and then add fewer
        entries than the buffer size, we only get the new entries.
        This tests that old entries at the end of the circular buffer are properly
        filtered out based on reset_epoch.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill the buffer completely
        for _ in range(10):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Reset the flight recorder
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Add only 3 new entries (much less than buffer size)
        for _ in range(3):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify we only get the 3 new entries, not 10
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 3)

        # Verify record IDs start from 0 after reset
        for i, entry in enumerate(t["entries"]):
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_wraparound(self, timing_enabled):
        """
        Test that when we reset in the middle of the circular buffer and then
        wrap around, dump_entries correctly returns only entries from the current
        epoch in the correct order.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill half the buffer
        for _ in range(5):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Reset at this point (reset happens at index 5)
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Now add 8 entries, which will wrap around
        # (5->9 fills rest of buffer, then 0->2 wraps around)
        for _ in range(8):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Should get exactly 8 entries, properly ordered
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 8)

        # Entries should be in chronological order
        # The dump_entries() method returns entries from next_ to end, then 0 to next_
        # After filtering old entries, we should have 8 entries in order
        # Verify record IDs start from 0 after reset (id_ is reset in reset_all())
        for i, entry in enumerate(t["entries"]):
            self.assertIn("profiling_name", entry)
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_multiple_resets(self, timing_enabled):
        """
        Test multiple consecutive resets to ensure each reset properly increments
        the epoch and filters out entries from previous epochs.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # First batch: 2 entries
        for _ in range(2):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # First reset
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Second batch: 3 entries
        for _ in range(3):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Second reset
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Third batch: 4 entries
        for _ in range(4):
            f = pg.allreduce(a)
            f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Should only see the last 4 entries
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 4)

        # Verify record IDs start from 0 after the last reset
        for i, entry in enumerate(t["entries"]):
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()
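These tests (deleted in this diff) pin down the flight recorder's reset semantics: each reset bumps an epoch, dumping filters out pre-epoch entries, and record IDs restart from zero. A toy model of that mechanism (plain Python; the field names mirror the C++ `entries_`/`next_`/`reset_epoch`/`id_`, but the code is illustrative only):

class RingRecorder:
    """Toy circular buffer with reset-epoch filtering (illustrative only)."""

    def __init__(self, capacity=10):
        self.buf = [None] * capacity   # entries_
        self.next = 0                  # next_ write position (monotonic)
        self.epoch = 0                 # reset_epoch
        self.next_id = 0               # id_, restarts on reset

    def record(self, name):
        self.buf[self.next % len(self.buf)] = {
            "profiling_name": name,
            "record_id": self.next_id,
            "epoch": self.epoch,
        }
        self.next += 1
        self.next_id += 1

    def reset(self):                   # reset_all(): bump epoch, restart ids
        self.epoch += 1
        self.next_id = 0

    def dump_entries(self):
        # next_ .. end, then 0 .. next_, keeping only current-epoch entries
        n = len(self.buf)
        start = self.next % n
        ordered = self.buf[start:] + self.buf[:start]
        return [e for e in ordered if e and e["epoch"] == self.epoch]

r = RingRecorder()
for _ in range(5):
    r.record("nccl:all_reduce")
r.reset()
for _ in range(8):                     # wraps around the 10-slot buffer
    r.record("nccl:all_reduce")
assert [e["record_id"] for e in r.dump_entries()] == list(range(8))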


def check_if_test_is_skipped(fn):
    def wrapper(self, *args, **kwargs):
@@ -8,11 +8,21 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
-from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
+from torch._dynamo.testing import (
+    AotEagerAndRecordGraphs,
+    extract_graph_and_tracker,
+    normalize_gm,
+)
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet


+def extract_graph(fn, *args, **kwargs):
+    backend = AotEagerAndRecordGraphs()
+    result = torch.compile(backend=backend)(fn)(*args, **kwargs)
+    return result, backend.graphs, backend.fw_graphs
+
+
def graph_str(gm):
    return normalize_gm(gm.print_readable(print_output=False))

@@ -30,7 +40,7 @@ class GraphDededuplicationTests(TestCase):
        super().tearDown()

    def run_and_return_graphs(self, fn, *args, **kwargs):
-       return extract_graph(fn, *args, **kwargs)[0:3]
+       return extract_graph(fn, *args, **kwargs)

    def run_and_get_simple_graph(self):
        def fn(x, y):
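For reference, the restored module-local helper is used like this (toy function and inputs, illustrative only):

def fn(x, y):
    return (x + y).relu()

result, graphs, fw_graphs = extract_graph(fn, torch.ones(2, 2), torch.ones(2, 2))
assert len(fw_graphs) == 1
print(graph_str(fw_graphs[0]))  # normalized, comparable forward-graph dump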
@@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
-from collections.abc import Callable, Sequence
-from typing import Any, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Union

import torch
import torch._dynamo
@@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
-from typing import NamedTuple, Optional, TYPE_CHECKING
+from typing import Callable, NamedTuple, Optional

import torch
import torch._dynamo
@@ -7,10 +7,6 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same


-if TYPE_CHECKING:
-    from collections.abc import Callable


"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo
@@ -1,13 +1,11 @@
# Owner(s): ["module: dynamo"]
import functools
-import re
import unittest
import weakref

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
-from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda

@@ -17,14 +15,6 @@ requires_multigpu = functools.partial(
)


-def remove_file_comment(gm_str: str) -> str:
-    return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))


-def print_graph(graph: torch.fx.GraphModule) -> str:
-    return remove_file_comment(graph.print_readable())


class TestStreams(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
@@ -46,7 +36,9 @@ class TestStreams(torch._dynamo.test_case.TestCase):

    @requires_cuda
    def test_stream_enter_exit(self):
-       def fn(x, y, s1, s2):
+       def fn(x, y):
+           s2 = torch.Stream()
+           s1 = torch.Stream()
            with s1:
                z1 = torch.add(x, y)
                with s2:
@@ -55,36 +47,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):

            return y

-       inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
+       inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
-       (
-           actual,
-           _,
-           fw_graphs,
-           _,
-       ) = extract_graph(fn, *inp)
-       self.assertEqual(len(fw_graphs), 1)
+       fn_opt = torch.compile(fn, fullgraph=True)
+       actual = fn_opt(*inp)
        self.assertEqual(expected, actual)
-       self.assertExpectedInline(
-           print_graph(fw_graphs[0]),
-           """\
-class <lambda>(torch.nn.Module):
-    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
-        # Annotation: {'stream': None}
-        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
-
-        # Annotation: {'stream': None}
-        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
-
-        # Annotation: {'stream': None}
-        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
-        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
-        return (add_3,)
-""",
-       )
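The rewritten test now compiles a function that builds its own streams instead of receiving them as graph inputs. A standalone sketch of that shape (the body between the two context managers is reconstructed from the expected graph above, and running it requires an accelerator build; illustrative only):

import torch

def fn(x, y):
    s2 = torch.Stream()
    s1 = torch.Stream()
    with s1:
        z1 = torch.add(x, y)
        with s2:
            z = torch.add(x, y)
            y = z + 2 + z1
    return y

inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
actual = torch.compile(fn, fullgraph=True)(*inp)
assert torch.equal(expected, actual)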
    @requires_cuda
    @unittest.skip("Needs graph break support with annotation context")
    def test_stream_context_graph_break(self):
        def fn(x, y):
            s2 = torch.Stream()
@@ -101,16 +70,9 @@ class <lambda>(torch.nn.Module):

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
-       (
-           actual,
-           _,
-           fw_graphs,
-           _,
-       ) = extract_graph(fn, *inp)
+       fn_opt = torch.compile(fn)
+       actual = fn_opt(*inp)
        self.assertEqual(expected, actual)
-       self.assertEqual(len(fw_graphs), 2)
-       self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
-       self.assertExpectedInline(print_graph(fw_graphs[1]), """""")

    @requires_cuda
    def test_stream_input(self):
@@ -193,248 +155,22 @@ class <lambda>(torch.nn.Module):
        self.assertEqual(s_act, s_exp)

    def test_nested_stream_enter_exit(self):
        def fn(x, y, s0, s1, s2):
            with s1:
                with s2:
                    z1 = torch.add(x, y)
                with s0:
                    z0 = torch.add(x, y)
                with s2:
                    y = 2 + z1
            pass

            return z0, y

        inp = (
            torch.ones(2, 2) + 1,
            torch.ones(2, 2),
            torch.Stream(),
            torch.Stream(),
            torch.Stream(),
        )
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            _,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
        # Annotation: {'stream': None}
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

        # Annotation: {'stream': None}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

        # Annotation: {'stream': None}
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
        return (add_1, add_2)
""",
        )

    @unittest.skip("Needs graph break support with annotation context")
    def test_stream_enter_exit_graph_break(self):
        pass

    @unittest.skip("Needs graph break support with annotation context")
    def test_nested_stream_enter_exit_graph_break(self):
        pass

    def test_local_stream_enter_exit(self):
        def fn(x, y):
            s2 = torch.Stream()
            s1 = torch.Stream()
            with s1:
                z1 = torch.add(x, y)
                with s2:
                    z = torch.add(x, y)
                    y = z + 2 + z1

            return y

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            _,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
        # Annotation: {'stream': 1}
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

        # Annotation: {'stream': 0}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

        # Annotation: {'stream': 0}
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
        return (add_3,)
""",
        )
        pass

    def test_local_stream_nested_enter_exit(self):
        def fn(x, y):
            s2 = torch.Stream()
            s1 = torch.Stream()
            s0 = torch.Stream()
            with s1:
                with s2:
                    z1 = torch.add(x, y)
                with s0:
                    z0 = torch.add(x, y)
                with s2:
                    y = 2 + z1

            return z0, y

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            _,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

        # Annotation: {'stream': 2}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

        # Annotation: {'stream': 0}
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
        return (add_1, add_2)
""",
        )
        pass

    def test_stream_with_mutation(self):
        def fn(x, y):
            s2 = torch.Stream()
            s1 = torch.Stream()
            s0 = torch.Stream()
            with s1:
                with s2:
                    x.add_(y)
                with s0:
                    z1 = torch.add(y, y)
                    z0 = torch.add(z1, y)
                with s2:
                    y = 2 + z1

            return z0, y

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            _,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

        # Annotation: {'stream': 2}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)

        # Annotation: {'stream': 2}
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None

        # Annotation: {'stream': 0}
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None

        #
        copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
        return (add_2, add_3)
""",
        )

    def test_stream_backward(self) -> None:
        def fn(x, y):
            s2 = torch.Stream()
            s0 = torch.Stream()
            with s0:
                y0 = 2 * x + y
            with s2:
                z = 2 * x + y

            return y0, z

        inp = (
            torch.ones(2, 2, requires_grad=True) + 1,
            torch.ones(2, 2, requires_grad=True),
        )
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            bw_graphs,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
        # Annotation: {'stream': 1}
        mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)

        # Annotation: {'stream': 0}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
        return (add, add_1)
""",
        )

        actual[1].sum().backward()
        self.assertExpectedInline(
            print_graph(bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)

        #
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None

        # Annotation: {'stream': 1}
        mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None

        #
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
        return (add_3, add_2)
""",
        )
        pass

    @requires_cuda
    def test_run_opcheck(self):
@@ -2,14 +2,18 @@
import re
import unittest
from functools import partial
-from typing import Any, Optional, Union
+from typing import Any, Optional
from unittest.mock import patch

import torch
import torch.nn as nn
from torch._inductor import config as inductor_config
from torch._inductor.choices import InductorChoices
-from torch._inductor.kernel_inputs import MMKernelInputs
+from torch._inductor.kernel_inputs import (
+    ConvKernelInputs,
+    MMKernelInputs,
+    SerializableValue,
+)
from torch._inductor.lookup_table.choices import LookupTableChoices
from torch._inductor.select_algorithm import (
    add_preprocessing_fn,
@@ -54,7 +58,7 @@ class MockMMKernelInputs(MMKernelInputs):
    def __init__(
        self,
        tensors: list[torch.Tensor],
-       scalars: Optional[dict[str, Union[float, int]]] = None,
+       scalars: Optional[dict[str, SerializableValue]] = None,
        mat1_idx: int = -2,
        mat2_idx: int = -1,
    ):
@@ -80,6 +84,37 @@ class MockMMKernelInputs(MMKernelInputs):
        return self.tensors[0].device.type


class MockConvKernelInputs(ConvKernelInputs):
    """Mock ConvKernelInputs that subclasses the real class and uses real tensors"""

    def __init__(
        self,
        tensors: list[torch.Tensor],
        scalars: Optional[dict[str, SerializableValue]] = None,
        x_idx: int = 0,
        weight_idx: int = 1,
        bias_idx: Optional[int] = None,
    ):
        """Initialize with real tensors, creating mock nodes for the base class"""
        mock_nodes = [MockTensorNode(t) for t in tensors]
        super().__init__(
            mock_nodes, scalars, x_idx=x_idx, weight_idx=weight_idx, bias_idx=bias_idx
        )
        self.tensors = tensors  # Keep reference to original tensors

    def shapes_hinted(self) -> tuple[tuple[int, ...], ...]:
        """Delegate to symbolic since real tensors already have int shapes"""
        return self.shapes_symbolic()

    def strides_hinted(self) -> tuple[tuple[int, ...], ...]:
        """Delegate to symbolic since real tensors already have int strides"""
        return self.strides_symbolic()  # pyre-ignore

    @property
    def device_type(self) -> Optional[str]:
        return self.tensors[0].device.type


class BaseLookupTableTest(TestCase):
    """Base class for lookup table tests with common setup and utilities"""

@@ -103,7 +138,7 @@ class BaseLookupTableTest(TestCase):
        shapes: Optional[list[tuple[int, ...]]] = None,
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.float32,
-       scalars: Optional[dict[str, Union[float, int]]] = None,
+       scalars: Optional[dict[str, SerializableValue]] = None,
    ) -> MockMMKernelInputs:
        """Create MockMMKernelInputs with real tensors"""
        if shapes is None:
@@ -1055,6 +1090,119 @@ class TestLookupTableE2E(BaseE2ELookupTableTest):
        with patch.object(inductor_config.lookup_table, "check_src_hash", True):
            self.run_model("mm", tensors)

    @fresh_cache()
    def test_conv2d_lookup_table_entry_e2e(self):
        """Test end-to-end conv2d with lookup table entry - verifies config is picked up and produces valid results"""
        import torch._inductor.kernel.conv

        # Create input tensors with specific shapes for conv2d
        # Input: [batch=2, in_channels=3, height=32, width=32]
        # Weight: [out_channels=64, in_channels=3, kernel_h=3, kernel_w=3]
        # Make them channels-last to match what conv lowering uses
        x = torch.randn(2, 3, 32, 32, device=self.device, dtype=torch.float16).to(
            memory_format=torch.channels_last
        )
        weight = torch.randn(64, 3, 3, 3, device=self.device, dtype=torch.float16).to(
            memory_format=torch.channels_last
        )

        # Define conv parameters - use these SAME values everywhere
        stride = (1, 1)
        padding = (1, 1)
        dilation = (1, 1)
        groups = 1

        # Create MockConvKernelInputs using the SAME tensors and SAME scalar values
        mock_scalars = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "transposed": False,
            "output_padding": (0, 0),
            "groups": groups,
        }
        mock_kernel_inputs = MockConvKernelInputs([x, weight], mock_scalars)

        # Create lookup key for "convolution" operation
        choices_handler = LookupTableChoices()
        lookup_key = choices_handler.make_lookup_key(mock_kernel_inputs, "convolution")

        # Get the exact template UID from conv2d_template
        template_uid = torch._inductor.kernel.conv.conv2d_template.uid

        # Create a precisely configured conv2d config
        # IMPORTANT: Only include per-config tunable parameters!
        # Static parameters (KERNEL_H, STRIDE_H, GROUPS, UNROLL, ALLOW_TF32) are
        # automatically generated by get_extra_kwargs() and should NOT be in the lookup table
        conv2d_config = {
            "template_id": template_uid,
            # Per-config tunable parameters only (what you'd tune via autotuning)
            "BLOCK_M": 64,
            "BLOCK_N": 64,
            "BLOCK_K": 32,
            "num_stages": 2,
            "num_warps": 4,
        }

        # Setup lookup table
        inductor_config.lookup_table.table = {lookup_key: [conv2d_config]}

        def validate_conv_choice(choices):
            assert len(choices) == 1, (
                f"Expected 1 choice from lookup table, got {len(choices)}"
            )
            assert isinstance(choices[0], TritonTemplateCaller), (
                f"Expected TritonTemplateCaller, got {type(choices[0])}"
            )
            assert "convolution2d" in choices[0].name, (
                f"Expected 'convolution2d' in name, got {choices[0].name}"
            )
            return choices

        add_preprocessing_fn(validate_conv_choice)

        # Create and compile the model using the SAME weight tensor
        class SimpleConv2d(nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.register_buffer("weight", weight)

            def forward(self, x):
                return torch.conv2d(
                    x,
                    self.weight,
                    bias=None,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=groups,
                )

        model = SimpleConv2d(weight).to(self.device)

        with inductor_config.patch({"max_autotune": True, "max_autotune_gemm": True}):
            compiled_model = torch.compile(model)
            result = compiled_model(x)  # Use the SAME x tensor

        # Output shape: [batch=2, out_channels=64, out_h=32, out_w=32]
        # (same spatial dims due to padding=1, stride=1, kernel=3)
        expected_shape = (2, 64, 32, 32)
        self.assertEqual(
            result.shape,
            expected_shape,
            f"Expected shape {expected_shape}, got {result.shape}",
        )

        self.assertFalse(
            torch.isnan(result).any().item(),
            "Output contains NaN values",
        )

        self.assertFalse(
            torch.isinf(result).any().item(),
            "Output contains Inf values",
        )


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu
@@ -1864,8 +1864,6 @@ class TestFP8Matmul(TestCase):
    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
    @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
    def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
-       if torch.version.hip and recipe == "nvfp4":
-           raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
        if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
            raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
@@ -1914,7 +1914,6 @@ class TestSDPAFailureModes(NNTestCase):
            q, k, v, None, 0.0, is_causal=True))

-   @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
        batch_size = 2**16
        query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@@ -1936,7 +1935,6 @@ class TestSDPAFailureModes(NNTestCase):
        self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)

-   @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@@ -1950,7 +1948,6 @@ class TestSDPAFailureModes(NNTestCase):

    @largeTensorTest("15GB", "cuda")
-   @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_large_seq_len_uniform_attention(self):
        device = torch.device("cuda")
        dtype = torch.bfloat16
@@ -1,5 +1,5 @@
-from typing import TypeAlias, Union
-from typing_extensions import assert_type
+from typing import Union
+from typing_extensions import assert_type, TypeAlias

from torch import randn, Tensor
@@ -1,9 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
-from collections.abc import Callable
from datetime import timedelta
from enum import Enum
-from typing import Any, Optional, overload, Union
+from typing import Any, Callable, Optional, overload, Union

import torch
from torch import Tensor
@@ -87,12 +87,6 @@ def extract_graph_and_tracker(fn, *args, **kwargs):  # type: ignore[no-untyped-def]
    return gm.graph, region_tracker  # type: ignore[union-attr]


-def extract_graph(fn, *args, **kwargs):  # type: ignore[no-untyped-def]
-    backend = AotEagerAndRecordGraphs()
-    result = torch.compile(backend=backend)(fn)(*args, **kwargs)
-    return result, backend.graphs, backend.fw_graphs, backend.bw_graphs


def collect_results(
    model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:
@@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
-from collections.abc import Callable, Sequence, Sized
+from collections.abc import Callable, Sequence
from contextlib import ExitStack
-from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
+from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union

import torch._C
from torch._guards import Guard
@@ -2,7 +2,7 @@
from __future__ import annotations

import hashlib
-from typing import Any, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING

import sympy  # noqa: TC002

@@ -17,8 +17,6 @@ from .simd import SIMDKernel, SIMDScheduling


if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence
-
    from ..ir import IRNode
    from ..scheduler import BaseSchedulerNode
@@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
    fx_node: torch.fx.Node,
    override_size: Optional[int] = None,
    # TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix.
-   use_nccl_estimator: bool = True,
+   use_nccl_estimator: bool = False,
) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -8,6 +8,7 @@ import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate

from .. import config, ir
+from ..kernel_inputs import ConvKernelInputs
from ..lowering import (
    add_layout_constraint,
    constrain_to_fx_strides,
@@ -16,7 +17,9 @@ from ..lowering import (
)
from ..select_algorithm import (
    autotune_select_algorithm,
+   ChoiceCaller,
    ExternKernelChoice,
+   KernelTemplate,
    SymbolicGridFn,
    TritonTemplate,
)
@@ -76,7 +79,7 @@ LOOP_BODY_2D = """
        & (idx_x_h < IN_H)[:, None]
        & (idx_x_w >= 0)[:, None]
        & (idx_x_w < IN_W)[:, None]
        & (idx_x_c < GROUP_IN_C)[None, :]
        & (idx_x_c < GROUP_IN_C)[None, :]
    )
    matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)

@@ -542,34 +545,40 @@ def convolution(
    x = ir.ExternKernel.require_stride_order(x, req_stride_order)  # type: ignore[assignment]
    weight = ir.ExternKernel.require_stride_order(weight, req_stride_order)  # type: ignore[assignment]

-   ordered_kwargs_for_cpp_kernel = [
-       "stride",
-       "padding",
-       "dilation",
-       "transposed",
-       "output_padding",
-       "groups",
-   ]
-   if bias is None:
-       args = [x, weight]
-       kwargs["bias"] = None  # type: ignore[typeddict-unknown-key]
-       ordered_kwargs_for_cpp_kernel.insert(0, "bias")
-   else:
-       args = [x, weight, bias]
+   # Create ConvKernelInputs for unified template configuration
+   # Only include bias in input_nodes when it's not None
+   # - For Triton templates: bias is always None here (peeled off earlier), so input_nodes = [x, weight]
+   # - For ATEN: input_nodes = [x, weight] when bias is None, [x, weight, bias] when bias is present
+   if bias is not None:
+       bias.realize()
+       bias.freeze_layout()
+       V.graph.sizevars.guard_int_seq(bias.get_size())
+       input_nodes = [x, weight, bias]
+       bias_idx = 2
+   else:
+       input_nodes = [x, weight]
+       bias_idx = None
+
+   kernel_inputs = ConvKernelInputs(
+       input_nodes,
+       scalars={
+           "stride": stride,
+           "padding": padding,
+           "dilation": dilation,
+           "transposed": transposed,
+           "output_padding": output_padding,
+           "groups": groups,
+       },
+       x_idx=0,
+       weight_idx=1,
+       bias_idx=bias_idx,
+   )
+
+   # Build list of templates to try
+   templates: list[ExternKernelChoice | KernelTemplate] = []

-   choices = []
    if torch._inductor.utils._use_conv_autotune_backend("ATEN"):
-       choices = [
-           aten_convolution.bind(
-               args,
-               layout,
-               ordered_kwargs_for_cpp_kernel,
-               **kwargs,
-           )
-       ]
+       templates.append(aten_convolution)

    if (
        torch._inductor.utils._use_conv_autotune_backend("TRITON")
@@ -587,60 +596,23 @@ def convolution(
        and is_zeros(padding)
        and groups == 1
    ):
-       choices.append(aten_conv1x1_via_mm.bind(args, layout))
+       templates.append(aten_conv1x1_via_mm)

-   conv_configs = V.choices.get_conv_configs(device_type)
+   # Add appropriate template based on ndim
+   if ndim == 2:
+       templates.append(conv2d_template)
+   elif ndim == 3:
+       templates.append(conv3d_template)

-   dtype_size = x.get_dtype().itemsize
-   for cfg in conv_configs(
-       sympy_product([x.get_size()[0], *x.get_size()[2:]]),
-       out_chan,
-       in_chan,
-       dtype_size=dtype_size,
-   ):
-       if ndim == 2:
-           conv2d_template.maybe_append_choice(
-               choices,
-               input_nodes=(x, weight),
-               layout=layout,
-               KERNEL_H=kernel_shape[0],
-               KERNEL_W=kernel_shape[1],
-               STRIDE_H=stride[0],
-               STRIDE_W=stride[1],
-               PADDING_H=padding[0],
-               PADDING_W=padding[1],
-               GROUPS=groups,
-               # TODO(jansel): try unroll for bigger kernels once fixed:
-               # https://github.com/triton-lang/triton/issues/1254
-               UNROLL=is_ones(kernel_shape),
-               ALLOW_TF32=torch.backends.cudnn.allow_tf32,
-               num_stages=cfg.num_stages,
-               num_warps=cfg.num_warps,
-               **cfg.kwargs,
-           )
-       elif ndim == 3:
-           conv3d_template.maybe_append_choice(
-               choices,
-               input_nodes=(x, weight),
-               layout=layout,
-               KERNEL_D=kernel_shape[0],
-               KERNEL_H=kernel_shape[1],
-               KERNEL_W=kernel_shape[2],
-               STRIDE_D=stride[0],
-               STRIDE_H=stride[1],
-               STRIDE_W=stride[2],
-               PADDING_D=padding[0],
-               PADDING_H=padding[1],
-               PADDING_W=padding[2],
-               GROUPS=groups,
-               # TODO(jansel): try unroll for bigger kernels once fixed:
-               # https://github.com/triton-lang/triton/issues/1254
-               UNROLL=is_ones(kernel_shape),
-               ALLOW_TF32=torch.backends.cudnn.allow_tf32,
-               num_stages=cfg.num_stages,
-               num_warps=cfg.num_warps,
-               **cfg.kwargs,
-           )
+   # Initialize choices list and extend with template configs
+   choices: list[ChoiceCaller] = []
+   choices.extend(
+       V.choices.get_template_configs(
+           kernel_inputs,
+           templates,
+           "convolution",
+       )
+   )
    if use_ck_conv_template(layout):
        CKGroupedConvFwdTemplate.add_ck_conv_choices(
            choices,
@@ -652,7 +624,9 @@ def convolution(
            groups=groups,
            n_spatial_dimensions=ndim,
        )
-   return autotune_select_algorithm("convolution", choices, args, layout)
+   return autotune_select_algorithm(
+       "convolution", choices, kernel_inputs.nodes(), layout
+   )


@register_lowering(aten._convolution)
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
+from collections.abc import Sequence
from typing import Any, Optional, TYPE_CHECKING, Union

import torch
@@ -12,10 +13,12 @@ from .ir import FixedLayout, FlexibleLayout, Layout


if TYPE_CHECKING:
-   from collections.abc import Sequence
-
    import sympy

+# Type aliases for serializable scalar values
+Serializable = Union[int, float, bool]
+SerializableValue = Union[Serializable, Sequence[Serializable]]


class KernelInputs(ABC):
    """
@@ -27,7 +30,7 @@ class KernelInputs(ABC):
    def __init__(
        self,
        input_nodes: list[Any],
-       scalars: Optional[dict[str, Union[float, int]]] = None,
+       scalars: Optional[dict[str, SerializableValue]] = None,
        out_dtype: Optional[torch.dtype] = None,
    ):
        """
@@ -183,7 +186,7 @@ class KernelInputs(ABC):
        The output dtype
        """

-   def get_scalar(self, name: str) -> Union[float, int]:
+   def get_scalar(self, name: str) -> SerializableValue:
        """
        Get the scalar value for a given name.

@@ -191,7 +194,7 @@ class KernelInputs(ABC):
        name: Name of the scalar to get

        Returns:
-           The scalar value
+           The scalar value (can be int, float, bool, or tuple of these types)
        """
        assert name in self._scalars, f"Scalar {name} not found, but required"
        return self._scalars[name]
@@ -216,7 +219,7 @@ class MMKernelInputs(KernelInputs):
    def __init__(
        self,
        input_nodes: list[Any],
-       scalars: Optional[dict[str, Union[float, int]]] = None,
+       scalars: Optional[dict[str, SerializableValue]] = None,
        out_dtype: Optional[torch.dtype] = None,
        mat1_idx: int = -2,
        mat2_idx: int = -1,
@@ -336,3 +339,113 @@ class MMKernelInputs(KernelInputs):
        assert k == k_check, f"K dimensions don't match: {k} vs {k_check}"

        return (m, n, k)


class ConvKernelInputs(KernelInputs):
    """
    Specialized KernelInputs for convolution operations.
    Stores input tensor, weight tensor, and optional bias, along with conv parameters.
    """

    def __init__(
        self,
        input_nodes: list[Any],
        scalars: Optional[dict[str, SerializableValue]] = None,
        out_dtype: Optional[torch.dtype] = None,
        x_idx: int = 0,
        weight_idx: int = 1,
        bias_idx: Optional[int] = None,
    ):
        """
        Initialize with convolution input nodes.

        Args:
            input_nodes: List containing [x, weight] or [x, weight, bias]
            scalars: Dict with conv params (stride, padding, dilation, groups, transposed, output_padding)
            out_dtype: Optional output dtype
            x_idx: Index of input tensor (default: 0)
            weight_idx: Index of weight tensor (default: 1)
            bias_idx: Index of bias tensor if present (default: None)
        """
        super().__init__(input_nodes, scalars, out_dtype)
        assert len(input_nodes) >= 2, "Expected at least 2 input nodes (x, weight)"

        self._x_idx = x_idx
        self._weight_idx = weight_idx
        self._bias_idx = bias_idx

        # Validate that required scalars are present
        required_scalars = [
            "stride",
            "padding",
            "dilation",
            "transposed",
            "output_padding",
            "groups",
        ]
        for key in required_scalars:
            assert key in self._scalars, f"Conv requires scalar '{key}'"

    def out_dtype(self) -> torch.dtype:
        """
        Get the output dtype, whether passed in or inferred from the nodes

        Returns:
            The output dtype
        """
        if self._out_dtype is not None:
            return self._out_dtype
        return self._input_nodes[self._x_idx].get_dtype()

    def output_layout(self, flexible: bool = True) -> Layout:
        """
        Handle output layout generation for convolution.

        Args:
            flexible: If True, return FlexibleLayout, otherwise FixedLayout

        Returns:
            Layout for the convolution output
        """
        from torch._inductor.kernel.conv import conv_layout

        x = self._input_nodes[self._x_idx]
        weight = self._input_nodes[self._weight_idx]
        bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None

        # Use existing conv_layout function
        # We know the types here because conv requires these specific scalar types
        layout = conv_layout(
            x,
            weight,
            bias,
            self._scalars["stride"],  # type: ignore[arg-type]
            self._scalars["padding"],  # type: ignore[arg-type]
            self._scalars["dilation"],  # type: ignore[arg-type]
            self._scalars["transposed"],  # type: ignore[arg-type]
            self._scalars["output_padding"],  # type: ignore[arg-type]
            self._scalars["groups"],  # type: ignore[arg-type]
        )

        # TODO: Handle flexible vs fixed based on config if needed
        return layout

    def get_x_weight_bias(self) -> tuple[Any, Any, Optional[Any]]:
        """
        Get x, weight, and optional bias nodes.

        Returns:
            Tuple of (x, weight, bias) where bias may be None
        """
        bias = self._input_nodes[self._bias_idx] if self._bias_idx is not None else None
        return self._input_nodes[self._x_idx], self._input_nodes[self._weight_idx], bias

    def spatial_dims(self) -> tuple[Any, ...]:
        """
        Get spatial dimensions from input tensor (H, W for 2D, D, H, W for 3D).

        Returns:
            Tuple of spatial dimension sizes
        """
        x_shape = self._input_nodes[self._x_idx].get_size()
        return x_shape[2:]  # Skip batch and channel dims
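For orientation, here is how the lowering above drives the new class (a sketch: `x` and `weight` stand for realized IR nodes, and the scalar values are illustrative):

kernel_inputs = ConvKernelInputs(
    [x, weight],                       # no bias peeled off, so bias_idx stays None
    scalars={
        "stride": (1, 1),
        "padding": (1, 1),
        "dilation": (1, 1),
        "transposed": False,
        "output_padding": (0, 0),
        "groups": 1,
    },
)
x_node, w_node, bias_node = kernel_inputs.get_x_weight_bias()  # bias_node is None
stride = kernel_inputs.get_scalar("stride")                    # -> (1, 1)
layout = kernel_inputs.output_layout()                         # delegates to conv_layout(...)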
@@ -1,6 +1,6 @@
import os
-from collections.abc import Callable
from functools import cache, partial
+from typing import Callable

import torch
from torch._environment import is_fbcode
@@ -1,6 +1,6 @@
# NOTE: add new template heuristics here, so they get imported and registered
# TODO: write a simple glob if there are many heuristics to auto import them in the right order
-from . import aten, base, contiguous_mm, decompose_k, registry, triton
+from . import aten, base, contiguous_mm, conv, decompose_k, registry, triton

# expose the entry function
from .registry import get_template_heuristic
torch/_inductor/template_heuristics/conv.py (new file, 287 lines)
@@ -0,0 +1,287 @@
from __future__ import annotations

from typing import Any, cast, TYPE_CHECKING

import torch

from ..kernel.conv import aten_convolution, conv2d_template, conv3d_template
from ..kernel_inputs import ConvKernelInputs
from ..utils import is_ones, sympy_product
from ..virtualized import V
from .base import TemplateConfigHeuristics
from .registry import register_template_heuristic
from .triton import (
    CPUConfigHeuristic,
    CUDAConfigHeuristic,
    MTIAConfigHeuristic,
    ROCmConfigHeuristic,
    XPUConfigHeuristic,
)


if TYPE_CHECKING:
    from collections.abc import Generator

    from ..kernel_inputs import KernelInputs


class ConvTemplateConfigMixin(TemplateConfigHeuristics):
    """
    Mixin for conv templates that converts config lists to template kwargs.
    Similar to MMTemplateConfigMixin but for convolutions.

    This handles generating both the static template kwargs (KERNEL_H, STRIDE_H, etc.)
    and the per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).
    """

    # Type hint for methods from BaseConfigHeuristic
    get_conv_configs: Any

    def get_extra_kwargs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> dict[str, Any]:
        """
        Return template kwargs that don't change per-config.
        These are derived from kernel_inputs and must include all template parameters.

        Args:
            kernel_inputs: ConvKernelInputs containing input tensors and conv params
            op_name: Operation name (e.g., "convolution")

        Returns:
            Dict of static template kwargs (KERNEL_H, STRIDE_H, GROUPS, etc.)
        """
        assert isinstance(kernel_inputs, ConvKernelInputs), (
            f"ConvTemplateConfigMixin requires ConvKernelInputs, got {type(kernel_inputs)}"
        )

        x, weight, bias = kernel_inputs.get_x_weight_bias()

        # Extract kernel shape from weight: [out_chan, in_chan, *kernel_shape]
        weight_size = V.graph.sizevars.guard_int_seq(weight.get_size())
        kernel_shape = weight_size[2:]  # Skip out_chan, in_chan
        ndim = len(kernel_shape)

        # Extract scalars
        stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride"))
        padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding"))
        groups = cast(int, kernel_inputs.get_scalar("groups"))

        # Check if we should unroll (only for 1x1 kernels)
        unroll = is_ones(kernel_shape)

        # Build kwargs dict based on ndim
        kwargs: dict[str, Any] = {
            "GROUPS": groups,
            "UNROLL": unroll,
            "ALLOW_TF32": torch.backends.cudnn.allow_tf32,
        }

        if ndim == 2:
            kwargs.update(
                {
                    "KERNEL_H": kernel_shape[0],
                    "KERNEL_W": kernel_shape[1],
                    "STRIDE_H": stride[0],
                    "STRIDE_W": stride[1],
                    "PADDING_H": padding[0],
                    "PADDING_W": padding[1],
                }
            )
        elif ndim == 3:
            kwargs.update(
                {
                    "KERNEL_D": kernel_shape[0],
                    "KERNEL_H": kernel_shape[1],
                    "KERNEL_W": kernel_shape[2],
                    "STRIDE_D": stride[0],
                    "STRIDE_H": stride[1],
                    "STRIDE_W": stride[2],
                    "PADDING_D": padding[0],
                    "PADDING_H": padding[1],
                    "PADDING_W": padding[2],
                }
            )

        return kwargs

    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        """
        Yield per-config kwargs (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps).

        Args:
            kernel_inputs: ConvKernelInputs containing input tensors
            op_name: Operation name

        Yields:
            Dict of per-config kwargs for each configuration to try
        """
        assert isinstance(kernel_inputs, ConvKernelInputs), (
            "ConvTemplateConfigMixin requires ConvKernelInputs"
        )

        x, weight, bias = kernel_inputs.get_x_weight_bias()

        # Calculate dimensions for heuristics
        weight_size = weight.get_size()
        out_chan = weight_size[0]
        in_chan = weight_size[1]

        # Batch * spatial dimensions product
        x_size = x.get_size()
        batch_spatial_product = sympy_product([x_size[0], *x_size[2:]])

        # Get conv config generator from self (which is a BaseConfigHeuristic subclass)
        conv_configs_generator = self.get_conv_configs()

        dtype_size = x.get_dtype().itemsize

        # Generate configs (reusing mm preprocess_mm_configs machinery)
        for c in conv_configs_generator(
            batch_spatial_product,
            out_chan,
            in_chan,
            dtype_size=dtype_size,
            op_name="conv",
        ):
            # Yield per-config kwargs
            yield {
                "BLOCK_M": c.kwargs.get("BLOCK_M"),
                "BLOCK_N": c.kwargs.get("BLOCK_N"),
                "BLOCK_K": c.kwargs.get("BLOCK_K"),
                "num_stages": c.num_stages,
                "num_warps": c.num_warps,
            }
|
||||
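An editor's sketch of how the two halves compose; the actual call sites live elsewhere in inductor's template machinery, so the function name here is illustrative only.

def iter_full_template_kwargs(heuristic, kernel_inputs, op_name="convolution"):
    # Static kwargs are computed once; per-config kwargs vary per candidate.
    static_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name)
    for per_config in heuristic._get_template_configs_impl(kernel_inputs, op_name):
        yield {**static_kwargs, **per_config}
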
# ATEN convolution heuristic (no per-config tuning)
@register_template_heuristic(aten_convolution.uid, None)
class ATenConvConfigHeuristic(TemplateConfigHeuristics):
    """
    Pseudo heuristic for ATen convolution.
    ATen doesn't have configs to tune - it's a single choice.
    """

    def _get_template_configs_impl(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> Generator[dict[str, Any], None, None]:
        # ATen doesn't have per-config kwargs to tune
        yield dict()

    def get_extra_kwargs(
        self,
        kernel_inputs: KernelInputs,
        op_name: str,
    ) -> dict[str, Any]:
        """
        ATen gets stride, padding, etc. as ordered kwargs for the C++ kernel.
        """
        assert isinstance(kernel_inputs, ConvKernelInputs)

        # Extract scalar values from kernel_inputs
        stride = cast(tuple[int, ...], kernel_inputs.get_scalar("stride"))
        padding = cast(tuple[int, ...], kernel_inputs.get_scalar("padding"))
        dilation = cast(tuple[int, ...], kernel_inputs.get_scalar("dilation"))
        transposed = cast(bool, kernel_inputs.get_scalar("transposed"))
        output_padding = cast(
            tuple[int, ...], kernel_inputs.get_scalar("output_padding")
        )
        groups = cast(int, kernel_inputs.get_scalar("groups"))

        # Check if bias is None to match old behavior
        # When bias is None: input_nodes = [x, weight], add 'bias' to kwargs and ordered list
        # When bias is present: input_nodes = [x, weight, bias], don't add 'bias' to kwargs
        x, weight, bias = kernel_inputs.get_x_weight_bias()

        kwargs: dict[str, Any] = {
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "transposed": transposed,
            "output_padding": output_padding,
            "groups": groups,
        }

        if bias is None:
            # When bias is None, torch.convolution expects it as a kwarg
            kwargs["bias"] = None
            kwargs["ordered_kwargs_for_cpp_kernel"] = [
                "bias",
                "stride",
                "padding",
                "dilation",
                "transposed",
                "output_padding",
                "groups",
            ]
        else:
            # When bias is present, it's passed as a positional arg (3rd in input_nodes)
            kwargs["ordered_kwargs_for_cpp_kernel"] = [
                "stride",
                "padding",
                "dilation",
                "transposed",
                "output_padding",
                "groups",
            ]

        return kwargs


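The ordered kwargs above mirror the schema of aten::convolution. As a hedged illustration of that argument order (independent of this diff):

import torch

x = torch.randn(1, 16, 8, 8)
w = torch.randn(32, 16, 3, 3)
# bias, stride, padding, dilation, transposed, output_padding, groups
out = torch.convolution(
    x, w, None,
    stride=(1, 1), padding=(1, 1), dilation=(1, 1),
    transposed=False, output_padding=(0, 0), groups=1,
)
print(out.shape)  # torch.Size([1, 32, 8, 8])
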
# CUDA Conv2D/Conv3D heuristics
@register_template_heuristic(
    conv2d_template.uid,
    "cuda",
    register=torch.version.hip is None,
)
@register_template_heuristic(
    conv3d_template.uid,
    "cuda",
    register=torch.version.hip is None,
)
class CUDAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CUDAConfigHeuristic):
    """Conv template heuristic for CUDA."""


# ROCm Conv2D/Conv3D heuristics
@register_template_heuristic(
    conv2d_template.uid,
    "cuda",
    register=torch.version.hip is not None,
)
@register_template_heuristic(
    conv3d_template.uid,
    "cuda",
    register=torch.version.hip is not None,
)
class ROCmConvTemplateConfigHeuristic(ConvTemplateConfigMixin, ROCmConfigHeuristic):
    """Conv template heuristic for ROCm."""


# CPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "cpu")
@register_template_heuristic(conv3d_template.uid, "cpu")
class CPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, CPUConfigHeuristic):
    """Conv template heuristic for CPU."""


# XPU Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "xpu")
@register_template_heuristic(conv3d_template.uid, "xpu")
class XPUConvTemplateConfigHeuristic(ConvTemplateConfigMixin, XPUConfigHeuristic):
    """Conv template heuristic for XPU."""


# MTIA Conv2D/Conv3D heuristics
@register_template_heuristic(conv2d_template.uid, "mtia")
@register_template_heuristic(conv3d_template.uid, "mtia")
class MTIAConvTemplateConfigHeuristic(ConvTemplateConfigMixin, MTIAConfigHeuristic):
    """Conv template heuristic for MTIA."""
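Note that the CUDA and ROCm classes both register under the "cuda" key: the mutually exclusive register= flags mean exactly one of them is live in any given build, since ROCm builds also report their device type as "cuda". A quick check (editor's illustration):

import torch

# torch.version.hip is None on CUDA/CPU builds and a version string on ROCm,
# so exactly one of the two "cuda" registrations above takes effect at import time.
print("ROCm build" if torch.version.hip is not None else "non-ROCm build")
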
@ -52,7 +52,26 @@ __all__ = [
    "MemRecordsAcc",
]

from contextlib import ContextDecorator
try:
    # Available in Python >= 3.2
    from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
    import functools

    class _ContextDecorator:  # type: ignore[no-redef]
        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_val, exc_tb):
            raise NotImplementedError

        def __call__(self, func):
            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return wrapped


# global python state - whether profiler is currently enabled
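An editor's sketch of what the fallback __call__ above enables: any subclass usable as a context manager also works as a decorator. This assumes the _ContextDecorator defined above is in scope; the class name here is hypothetical.

class trace_block(_ContextDecorator):
    def __enter__(self):
        print("enter")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("exit")
        return False  # do not swallow exceptions

@trace_block()
def work():
    return 42

work()  # prints "enter", then "exit", and returns 42
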
@ -725,7 +744,8 @@ class profile:
        return all_function_events


class record_function(ContextDecorator):
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
    Label will only appear if CPU activity tracing is enabled.


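For reference, record_function is public API and works in both forms this inheritance enables; a minimal, hedged usage sketch:

import torch
from torch.profiler import profile, record_function

with profile() as prof:
    with record_function("my_block"):  # labeled region in the trace
        torch.mm(torch.randn(32, 32), torch.randn(32, 32))

@record_function("decorated_fn")  # decorator form via ContextDecorator
def decorated_fn():
    return torch.ones(4).sum()
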
@ -108,14 +108,12 @@ struct FlightRecorder {
    capture_cpp_stack_ = getCvarBool(
        {"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
    enabled_ = max_entries_ > 0;
    reset_epoch_start_idx_[0] = 0;
  }
  struct Entry {
    size_t id_; // incremented id in the trace buffer
                // used to figure out where in the circular entries
                // buffer this entry will be located to
                // update state information
    size_t reset_epoch_; // epoch when this entry was created
    size_t pg_id_;
    std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>

@ -185,34 +183,11 @@ struct FlightRecorder {
  size_t max_entries_ = 0;
  size_t next_ = 0;
  size_t id_ = 0;
  size_t reset_epoch_ = 0;
  std::unordered_map<size_t, size_t>
      reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
  std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
  std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
      pg_name_to_ranks_;
  std::string comm_lib_version_;

  struct TraceIdentifier {
    std::optional<size_t> id;
    std::optional<size_t> reset_epoch;
  };

  TraceIdentifier recordWithResetEnabled(
      size_t pg_id,
      const std::tuple<std::string, std::string>& pg_name,
      size_t collective_seq_id,
      size_t p2p_seq_id,
      size_t op_id,
      std::string profiling_name,
      const std::vector<at::Tensor>& inputs,
      const std::vector<at::Tensor>& outputs,
      EventType* start,
      EventType* end,
      std::chrono::milliseconds timeout_ms,
      std::shared_ptr<ProcessGroupStatus> pg_status,
      bool isP2P);

  std::optional<size_t> record(
      size_t pg_id,
      const std::tuple<std::string, std::string>& pg_name,
@ -238,16 +213,8 @@ struct FlightRecorder {

  std::vector<Entry> dump_entries();

  // Returns the index in entries_ for the given id and reset_epoch.
  // Caller must hold mutex_lock before calling this method.
  size_t getIdxFromId(size_t id, size_t reset_epoch) const;

  // Returns the entry with the given id and reset_epoch, if it exists.
  // Otherwise, returns std::nullopt.
  TORCH_API std::optional<Entry> getEntry(
      std::optional<size_t> id,
      std::optional<size_t> reset_epoch);

  // Returns the entry with the given id, if it exists. Otherwise, returns
  // std::nullopt.
  TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);

  /*
@ -260,11 +227,6 @@ struct FlightRecorder {
     never hang. (timing must also be enabled for compute_duration - see
     TORCH_NCCL_ENABLE_TIMING).
  */
  TORCH_API void retire_id(
      std::optional<size_t> id,
      std::optional<size_t> reset_epoch,
      bool compute_duration = true);

  TORCH_API void retire_id(
      std::optional<size_t> id,
      bool compute_duration = true);

@ -53,41 +53,8 @@ std::optional<size_t> FlightRecorder<EventType>::record(
    std::chrono::milliseconds timeout_ms,
    std::shared_ptr<ProcessGroupStatus> pg_status,
    bool isP2P) {
  auto result = recordWithResetEnabled(
      pg_id,
      pg_name,
      collective_seq_id,
      p2p_seq_id,
      op_id,
      std::move(profiling_name),
      inputs,
      outputs,
      start,
      end,
      timeout_ms,
      std::move(pg_status),
      isP2P);
  return result.id;
}

template <typename EventType>
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
    recordWithResetEnabled(
        size_t pg_id,
        const std::tuple<std::string, std::string>& pg_name,
        size_t collective_seq_id,
        size_t p2p_seq_id,
        size_t op_id,
        std::string profiling_name,
        const std::vector<at::Tensor>& inputs,
        const std::vector<at::Tensor>& outputs,
        EventType* start,
        EventType* end,
        std::chrono::milliseconds timeout_ms,
        std::shared_ptr<ProcessGroupStatus> pg_status,
        bool isP2P) {
  if (!enabled_) {
    return TraceIdentifier{std::nullopt, std::nullopt};
    return std::nullopt;
  }
  if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
    // Current pg_status is not in FR.
@ -97,13 +64,8 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
      torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
  std::lock_guard<std::mutex> guard(mutex_);

  TORCH_CHECK(
      reset_epoch_start_idx_.find(reset_epoch_) !=
      reset_epoch_start_idx_.end());

  auto te = Entry{
      id_,
      reset_epoch_,
      pg_id,
      pg_name,
      collective_seq_id,
@ -142,20 +104,15 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
      te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
  }

  const auto next = next_++;

  if (entries_.size() < max_entries_) {
    entries_.emplace_back(std::move(te));
  } else {
    entries_[next] = std::move(te);
    entries_[next_++] = std::move(te);
    if (next_ == max_entries_) {
      next_ = 0;
    }
  }

  if (next_ == max_entries_) {
    next_ = 0;
  }

  const auto id = id_++;
  return TraceIdentifier{id, reset_epoch_};
  return id_++;
}

template <typename EventType>
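The new write path reduces the bookkeeping to a plain ring buffer: ids grow monotonically while the write cursor wraps, so an entry with id k, if still present, sits at index k % max_entries, and a mismatched stored id marks the slot as overwritten. A minimal Python model of that invariant (editor's sketch, not the real implementation):

class Ring:
    def __init__(self, max_entries):
        self.max_entries = max_entries
        self.entries = []  # grows until full, then overwritten in place
        self.next = 0      # write cursor, mirrors next_
        self.id = 0        # monotonically increasing, mirrors id_

    def record(self, payload):
        entry = {"id": self.id, "payload": payload}
        if len(self.entries) < self.max_entries:
            self.entries.append(entry)
        else:
            self.entries[self.next] = entry
            self.next = (self.next + 1) % self.max_entries
        issued, self.id = self.id, self.id + 1
        return issued

    def get_entry(self, id_):
        e = self.entries[id_ % self.max_entries]  # only issued ids are looked up
        return e if e["id"] == id_ else None      # None: slot was overwritten
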
@ -206,20 +163,15 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
  std::vector<Entry> result;
  {
    std::lock_guard<std::mutex> guard(mutex_);
    // Filter entries during insertion - only keep entries from current epoch
    auto filter = [this](const Entry& e) {
      return e.reset_epoch_ == reset_epoch_;
    };
    std::copy_if(
    result.reserve(entries_.size());
    result.insert(
        result.end(),
        entries_.begin() + static_cast<std::ptrdiff_t>(next_),
        entries_.end(),
        std::back_inserter(result),
        filter);
    std::copy_if(
        entries_.end());
    result.insert(
        result.end(),
        entries_.begin(),
        entries_.begin() + static_cast<std::ptrdiff_t>(next_),
        std::back_inserter(result),
        filter);
        entries_.begin() + static_cast<std::ptrdiff_t>(next_));
  }
  // query any remaining events
  for (auto& r : result) {
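The two inserts above read from the cursor to the end, then from the start to the cursor, which yields entries oldest-first once the buffer has wrapped. Concretely (editor's sketch):

# Ring of max_entries=4 after recording ids 0..5; the cursor (next_) is 2.
entries = [4, 5, 2, 3]
next_idx = 2
chronological = entries[next_idx:] + entries[:next_idx]
assert chronological == [2, 3, 4, 5]  # oldest surviving entry first
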
@ -230,47 +182,28 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
}

template <typename EventType>
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
    const {
  // Look up the starting idx for the given reset epoch
  auto it = reset_epoch_start_idx_.find(reset_epoch);
  TORCH_CHECK(it != reset_epoch_start_idx_.end());
  // Calculate idx based on where the epoch started
  return (it->second + id) % max_entries_;
}

template <typename EventType>
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
// returns std::nullopt.
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
    EventType>::
    getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
  if (!enabled_ || !id || !reset_epoch) {
    EventType>::getEntry(std::optional<size_t> id) {
  if (!enabled_ || !id) {
    return std::nullopt;
  }

  std::unique_lock<std::mutex> guard(mutex_);
  Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
  if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
  Entry entry = entries_.at(*id % max_entries_);
  if (entry.id_ == *id) {
    return entry;
  } else {
    return std::nullopt;
  }
  return std::nullopt;
}

template <typename EventType>
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
    EventType>::getEntry(std::optional<size_t> id) {
  return getEntry(id, 0);
}

template <typename EventType>
void FlightRecorder<EventType>::retire_id(
    std::optional<size_t> id,
    std::optional<size_t> reset_epoch,
    bool compute_duration) {
  if (!enabled_ || !id || !reset_epoch) {
  if (!enabled_ || !id) {
    return;
  }

@ -281,8 +214,8 @@ void FlightRecorder<EventType>::retire_id(

  std::unique_lock<std::mutex> guard(mutex_);

  Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
  if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
  Entry* entry = &entries_.at(*id % max_entries_);
  if (entry->id_ == *id) {
    update_state(*entry);

    if (compute_duration) {
@ -304,8 +237,8 @@ void FlightRecorder<EventType>::retire_id(
      guard.lock();

      // Refresh the entry pointer, see if the entry has been overwritten
      entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
      if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
      entry = &entries_.at(*id % max_entries_);
      if (entry->id_ != *id) {
        LOG(INFO) << "retire_id abandoned for id " << *id
                  << ", event was overwritten while waiting to compute duration.";
        return;
@ -316,23 +249,12 @@ void FlightRecorder<EventType>::retire_id(
  }
}

template <typename EventType>
void FlightRecorder<EventType>::retire_id(
    std::optional<size_t> id,
    bool compute_duration) {
  retire_id(id, 0, compute_duration);
}

template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
  std::lock_guard<std::mutex> guard(mutex_);
  if (!entries_.empty()) {
    // Soft delete: increment epoch to mark all existing entries as old
    // Store where the new epoch starts in the circular buffer
    reset_epoch_++;
    reset_epoch_start_idx_[reset_epoch_] = next_;
    id_ = 0;
  }
  next_ = 0;
  id_ = 0;
  entries_.clear();
}

template <typename EventType>

@ -708,8 +708,7 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
      // TODO: We need to have numel of tensors for gloo as well.
      pgStatus_->lastCompletedNumelIn = 0;
      pgStatus_->lastCompletedNumelOut = 0;
      FlightRecorder<c10::Event>::get()->retire_id(
          work->trace_id_, work->trace_reset_epoch_, false);
      FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
      lock.lock();
      workInProgress_[workerIndex].reset();
    }
@ -781,7 +780,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
  pgStatus_->lastEnqueuedNumelOut = 0;
  // using c10d::FlightRecorder;
  // TODO: We need to have a way to use c10::Event inside gloo as well.
  auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
  work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
      local_id_,
      std::make_tuple(pg_uid_, pg_desc_),
      collectiveCounter_,
@ -796,8 +795,6 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
      work->getTimeout(),
      pgStatus_,
      false);
  work->trace_id_ = traceId.id;
  work->trace_reset_epoch_ = traceId.reset_epoch;
  workQueue_.push_back(std::move(work));
  lock.unlock();


@ -99,7 +99,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
    // unique id used to tell the trace buffer that this
    // work has completed
    std::optional<uint64_t> trace_id_;
    std::optional<uint64_t> trace_reset_epoch_;
    std::shared_ptr<gloo::Context> context_;
    const std::chrono::milliseconds timeout_;

@ -575,7 +575,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
      futureWorkResult_(w.futureWorkResult_),
      timingEnabled_(w.timingEnabled_),
      trace_id_(w.trace_id_),
      trace_reset_epoch_(w.trace_reset_epoch_),
      distDebugLevel_(w.distDebugLevel_) {
  exception_ = w.exception_;
}
@ -705,9 +704,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
// Print the traceback of the collective at call time
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
  // First step we get the corresponding record entry from FR, based on work's
  // trace_id_ and trace_reset_epoch_
  // trace_id_
  std::optional<FlightRecorderCUDA::Entry> entry =
      FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
      FlightRecorderCUDA::get()->getEntry(trace_id_);
  if (entry.has_value()) {
    auto entryVal = entry.value();
    // Get stack trace from FR entry, in string format
@ -2395,8 +2394,7 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
        pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
        pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
        pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
        FlightRecorderCUDA::get()->retire_id(
            work.trace_id_, work.trace_reset_epoch_, true);
        FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
        if (pg_->onCompletionHook_) {
          // Move Work object to completedWorkList_ to be consumed by the hook
          // thread
@ -3362,7 +3360,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
    // these objects to the Work because it has implications for keeping those
    // tensors alive longer and adds overhead when copying Work objects
    // between threads
    auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
    r->trace_id_ = FlightRecorderCUDA::get()->record(
        local_id_,
        std::make_tuple(pg_uid_, pg_desc_),
        seqCollective_,
@ -3376,8 +3374,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
        options_->timeout,
        pgStatus_,
        isP2P);
    r->trace_id_ = traceId.id;
    r->trace_reset_epoch_ = traceId.reset_epoch;
  }
  return r;
}
@ -3597,7 +3593,6 @@ float ProcessGroupNCCL::endTimeEstimate() {
#ifdef NCCL_SIM_INFO_INITIALIZER
  ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
  C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
  --ncclActiveGroupCounter_;
  return simInfo.estimatedTime;
#else
  TORCH_CHECK(
@ -3681,7 +3676,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
      // later in endCoalescing we record a 'coalesced' Work which has
      // timing/state updates via watchdog thread, but lacks op metadata such as
      // input/output sizes and profilingTitle per-op in the group.
      FlightRecorderCUDA::get()->recordWithResetEnabled(
      FlightRecorderCUDA::get()->record(
          local_id_,
          std::make_tuple(pg_uid_, pg_desc_),
          seqCollective_,
@ -4173,7 +4168,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
    // TODO(whc) because we don't pass output {tensor} to initWork, we tell
    // initWork to not record, and then we manually call record passing all the
    // information it wants.
    auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
    work->trace_id_ = FlightRecorderCUDA::get()->record(
        local_id_,
        std::make_tuple(pg_uid_, pg_desc_),
        seqCollective_,
@ -4187,8 +4182,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
        options_->timeout,
        pgStatus_,
        /*isP2P=*/true);
    work->trace_id_ = traceId.id;
    work->trace_reset_epoch_ = traceId.reset_epoch;
  }

  // Only check for NaN for send ops, for recv ops `tensor` can be a random

@ -505,7 +505,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
    // unique id used to tell the trace buffer that this
    // work has completed
    std::optional<uint64_t> trace_id_;
    std::optional<uint64_t> trace_reset_epoch_;
    DebugLevel distDebugLevel_;
    friend class ProcessGroupNCCL;
  };

@ -4,7 +4,6 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -14,7 +13,6 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)

using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;

// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -95,32 +93,6 @@ class Tensor {
    return numel;
  }

  // note: this API is, for all intents and purposes, the same as the one in
  // TensorBase.h: it returns a borrowed reference of the dimension sizes of
  // a Tensor.
  //
  // The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
  // which has slightly less functionality than a regular IntArrayRef. See
  // [HeaderOnlyArrayRef vs ArrayRef note] for more details.
  IntHeaderOnlyArrayRef sizes() const {
    int64_t* sizes;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
    return IntHeaderOnlyArrayRef(sizes, dim());
  }

  // note: this API is, for all intents and purposes, the same as the one in
  // TensorBase.h: it returns a borrowed reference of the strides of a
  // Tensor.
  //
  // The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
  // which has slightly less functionality than a regular IntArrayRef. See
  // [HeaderOnlyArrayRef vs ArrayRef note] for more details.
  IntHeaderOnlyArrayRef strides() const {
    int64_t* strides;
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
    return IntHeaderOnlyArrayRef(strides, dim());
  }

  // note: this is a subset of the original TensorBase API. It takes no
  // arguments whereas the original API takes in a kwarg of memory format.
  // Here, we assume the default contiguous memory format.

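These header-only accessors mirror what eager PyTorch already exposes from Python; a quick cross-check of the expected values (editor's illustration):

import torch

t = torch.empty(2, 3)
assert list(t.size()) == [2, 3]  # what sizes() returns
assert t.stride() == (3, 1)      # contiguous row-major strides, what strides() returns
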
@ -1,8 +1,9 @@
import functools
import math
import operator
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import timedelta
from typing import Callable

import torch
from torch._C import ScriptObject

@ -10,7 +10,6 @@ from ._context_parallel._attention import (
    _enable_context_parallel_dispatcher,
    _is_causal_behavior,
    _RotateMethod,
    _templated_ring_attention,
    context_parallel,
    context_parallel_unshard,
    set_rotate_method,
@ -23,7 +22,6 @@ from ._context_parallel._load_balancer import (
)


# TODO(fegin): add deprecation message once the final interfaces are concluded.
__all__ = [
    "_CausalBehavior",
    "_context_parallel_shard",
@ -33,7 +31,6 @@ __all__ = [
    "_enable_context_parallel_dispatcher",
    "_is_causal_behavior",
    "_RotateMethod",
    "_templated_ring_attention",
    "context_parallel",
    "context_parallel_unshard",
    "set_rotate_method",