Compare commits


1 Commit

SHA1: a09b0c26a7
Message: Automated submodule update: tensorpipe
Date: 2025-11-05 10:57:08 -08:00
41 changed files with 340 additions and 1052 deletions

View File

@@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `pyrefly`](#running-pyrefly)
- [Running `mypy`](#running-mypy)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `mypy` - recommended for linting
- `pytest` - recommended to run tests more selectively
Running
```
@@ -350,32 +350,15 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `pyrefly`
#### Running `mypy`
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
for more information on how to set up `mypy` and tackle type annotation
tasks.
### C++ Unit Testing

View File

@@ -226,8 +226,8 @@ template <
typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
virtual ~CachingHostAllocatorImpl() {
if (active_) {
active_ = false;
active_ = false;
if (pinned_use_background_threads()) {
getBackgroundThreadPool()->waitWorkComplete();
}
}
@@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
if (pinned_use_background_threads()) {
// Launch the background thread and process events in a loop.
static bool background_thread_flag [[maybe_unused]] = [this] {
active_ = true;
getBackgroundThreadPool()->run([&]() {
while (active_) {
process_events();
@@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the event-processing thread pool is active.
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{false};
std::atomic<bool> active_{true};
protected:
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
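The hunk above reworks the lifetime of `active_`: rather than having the lazily launched background thread set it to true, the flag now starts out true and the destructor clears it unconditionally before waiting on the pool. A rough Python analogue of this stop-flag pattern, purely illustrative (the names and event loop are not PyTorch's allocator):

```python
import threading
import time

class EventProcessor:
    """Illustrative stop-flag pattern: a worker loops while `active` is set;
    shutdown clears the flag first, then waits for the worker to drain."""

    def __init__(self):
        # Mirrors std::atomic<bool> active_{true}: constructed "active".
        self.active = threading.Event()
        self.active.set()
        self._worker = None

    def start_background(self):
        # Launched lazily, like the background thread pool above.
        if self._worker is None:
            self._worker = threading.Thread(target=self._loop, daemon=True)
            self._worker.start()

    def _loop(self):
        while self.active.is_set():      # while (active_) { process_events(); }
            self.process_events()
            time.sleep(0.001)

    def process_events(self):
        pass                             # stand-in for draining the event queue

    def close(self):
        # Mirrors the destructor: signal shutdown, then wait for outstanding work.
        self.active.clear()
        if self._worker is not None:
            self._worker.join()
```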

View File

@@ -24,13 +24,7 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
template<typename key_t, int value_size>
void radix_sort_pairs_impl(

View File

@@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
});
}
template <typename func_t, typename vec_func_t, typename ident_t = double>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
using traits = binary_function_traits<func_t>;
static_assert(
all_same<

View File

@@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) {
}
}
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};
void min_values_kernel_impl(TensorIterator& iter) {
// This case is special because Vectorized<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
}), kLong, kUInt64);
return;
}
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
upper_bound<scalar_t>());
static_cast<double>(upper_bound<scalar_t>()));
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
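The restored `kLong`/`kUInt64` branch above routes through the scalar `binary_kernel_reduce` with an exact `std::pair` identity: `Vectorized<int64_t>` cannot handle `upper_bound<int64_t>()` (see the linked issue), and the vectorized path now receives its identity as `double`, which cannot represent every 64-bit integer. A quick plain-Python check of that rounding (IEEE-754 doubles, the same as C++):

```python
INT64_MAX = 2**63 - 1

# A double has a 53-bit significand, so INT64_MAX is not representable
# and rounds up to 2**63 when cast:
assert float(INT64_MAX) != INT64_MAX
assert int(float(INT64_MAX)) == 2**63

# Integers up to 2**53 round-trip exactly; 2**53 + 1 is the first that cannot:
assert int(float(2**53)) == 2**53
assert int(float(2**53 + 1)) != 2**53 + 1
```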

View File

@@ -47,10 +47,20 @@ Tensor sgd_out_of_place(
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
// testing Tensor strides + stride
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
auto out = new_empty(param, param.sizes());
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
aoti_torch_get_device_type(param.get(), &param_device_type);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);
sgd_math(
reinterpret_cast<float*>(param.data_ptr()),
@@ -334,8 +344,6 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to a HeaderOnlyArrayRef)
// directly.
// This is to test that passing in a std::vector works for BC. (It gets
// implicitly converted to HeaderOnlyArrayRef too!)
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);

View File

@@ -5,16 +5,8 @@ import contextlib
import torch
import torch.distributed as dist
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
Partial,
Replicate,
Shard,
)
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -434,31 +426,6 @@ class TestDTensorDebugMode(TestCase):
][-1]
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
def test_pretty_print_dtensor_make_fx(self):
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
A = torch.randn(8, 32)
B = torch.randn(32, 32)
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
def f(dA, dB):
dy = dA @ dB
loss = dy.sum()
loss.backward()
return dA.grad, dB.grad
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode="fake")(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
# Colored is nice for actual viewing, not using in this test though
gm_str = gm.print_readable(colored=False, print_output=False)
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@@ -5789,229 +5789,6 @@ class NCCLTraceTest(NCCLTraceTestBase):
else:
self.assertTrue("duration_ms" not in t["entries"][0])
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
"""
Test that when the circular buffer in entries_ is full and we call reset,
then fill the buffer with new entries, dump_entries returns only the new
entries and not the old ones.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely with 10 entries
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify buffer is full with 10 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Now reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add new entries after reset - fill the buffer completely again
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we get exactly 10 new entries, not 20
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Verify all entries have the expected properties (from after reset)
# After reset, record IDs should start from 0 again
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
self.assertIn("record_id", entry)
# Record IDs should be sequential starting from 0 after reset
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
"""
Test that when the circular buffer is full, we reset, and then add fewer
entries than the buffer size, we only get the new entries.
This tests that old entries at the end of the circular buffer are properly
filtered out based on reset_epoch.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add only 3 new entries (much less than buffer size)
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we only get the 3 new entries, not 10
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 3)
# Verify record IDs start from 0 after reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_wraparound(self, timing_enabled):
"""
Test that when we reset in the middle of the circular buffer and then
wrap around, dump_entries correctly returns only entries from the current
epoch in the correct order.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill half the buffer
for _ in range(5):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset at this point (reset happens at index 5)
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Now add 8 entries, which will wrap around
# (5->9 fills rest of buffer, then 0->2 wraps around)
for _ in range(8):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should get exactly 8 entries, properly ordered
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 8)
# Entries should be in chronological order
# The dump_entries() method returns entries from next_ to end, then 0 to next_
# After filtering old entries, we should have 8 entries in order
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_multiple_resets(self, timing_enabled):
"""
Test multiple consecutive resets to ensure each reset properly increments
the epoch and filters out entries from previous epochs.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# First batch: 2 entries
for _ in range(2):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# First reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Second batch: 3 entries
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Second reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Third batch: 4 entries
for _ in range(4):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should only see the last 4 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 4)
# Verify record IDs start from 0 after the last reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
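The four tests above all assert the same contract for the flight-recorder ring buffer: `dump_entries` replays entries from `next_` to the end, then from 0 to `next_`, drops entries from earlier reset epochs, and `reset` restarts record IDs at zero. A toy Python model of those semantics (an illustrative sketch built from the test docstrings, not the C10D implementation):

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Entry:
    record_id: int
    epoch: int
    payload: Any

class ToyFlightRecorder:
    def __init__(self, capacity: int = 10):
        self.capacity = capacity
        self.buf: list[Optional[Entry]] = [None] * capacity
        self.next_ = 0      # next slot to overwrite
        self.id_ = 0        # record id, restarted by reset()
        self.epoch = 0      # bumped by reset(); stale entries are filtered out

    def record(self, payload: Any) -> None:
        self.buf[self.next_] = Entry(self.id_, self.epoch, payload)
        self.next_ = (self.next_ + 1) % self.capacity
        self.id_ += 1

    def reset(self) -> None:
        self.epoch += 1
        self.id_ = 0

    def dump(self) -> list[Entry]:
        # next_..end, then 0..next_, keeping only the current epoch.
        ordered = self.buf[self.next_:] + self.buf[:self.next_]
        return [e for e in ordered if e is not None and e.epoch == self.epoch]

rec = ToyFlightRecorder()
for _ in range(10):
    rec.record("all_reduce")            # fill the buffer completely
rec.reset()
for _ in range(3):
    rec.record("all_reduce")            # partial overwrite after reset
assert [e.record_id for e in rec.dump()] == [0, 1, 2]
```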
def check_if_test_is_skipped(fn):
def wrapper(self, *args, **kwargs):

View File

@@ -8,11 +8,21 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
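For orientation, a hedged usage sketch of the two helpers just defined (the toy `fn` and the printed counts are illustrative, not part of the test file):

```python
import torch

def fn(x, y):
    return (x + y).relu()

# Compile once through the recording backend, then inspect what it captured.
result, graphs, fw_graphs = extract_graph(fn, torch.ones(2), torch.ones(2))
print(len(graphs), len(fw_graphs))   # typically one Dynamo graph, one AOT forward graph
print(graph_str(fw_graphs[0]))       # normalized, printable forward graph
```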
@@ -30,7 +40,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)[0:3]
return extract_graph(fn, *args, **kwargs)
def run_and_get_simple_graph(self):
def fn(x, y):

View File

@@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Callable, Sequence
from typing import Any, Union
from collections.abc import Sequence
from typing import Any, Callable, Union
import torch
import torch._dynamo

View File

@@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import NamedTuple, Optional, TYPE_CHECKING
from typing import Callable, NamedTuple, Optional
import torch
import torch._dynamo
@@ -7,10 +7,6 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@@ -1,13 +1,11 @@
# Owner(s): ["module: dynamo"]
import functools
import re
import unittest
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
@@ -17,14 +15,6 @@ requires_multigpu = functools.partial(
)
def remove_file_comment(gm_str: str) -> str:
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
def print_graph(graph: torch.fx.GraphModule) -> str:
return remove_file_comment(graph.print_readable())
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@@ -46,7 +36,9 @@ class TestStreams(torch._dynamo.test_case.TestCase):
@requires_cuda
def test_stream_enter_exit(self):
def fn(x, y, s1, s2):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
@@ -55,36 +47,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
@requires_cuda
@unittest.skip("Needs graph break support with annotation context")
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
@@ -101,16 +70,9 @@ class <lambda>(torch.nn.Module):
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
self.assertEqual(expected, actual)
self.assertEqual(len(fw_graphs), 2)
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
@requires_cuda
def test_stream_input(self):
@@ -193,248 +155,22 @@ class <lambda>(torch.nn.Module):
self.assertEqual(s_act, s_exp)
def test_nested_stream_enter_exit(self):
def fn(x, y, s0, s1, s2):
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
pass
return z0, y
inp = (
torch.ones(2, 2) + 1,
torch.ones(2, 2),
torch.Stream(),
torch.Stream(),
torch.Stream(),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
@unittest.skip("Needs graph break support with annotation context")
def test_stream_enter_exit_graph_break(self):
pass
@unittest.skip("Needs graph break support with annotation context")
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
pass
def test_local_stream_nested_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
pass
def test_stream_with_mutation(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
x.add_(y)
with s0:
z1 = torch.add(y, y)
z0 = torch.add(z1, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
# Annotation: {'stream': 2}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
#
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (add_2, add_3)
""",
)
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, requires_grad=True) + 1,
torch.ones(2, 2, requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
#
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
pass
@requires_cuda
def test_run_opcheck(self):

View File

@@ -14424,6 +14424,20 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))
@skip_if_halide
@requires_cuda_and_triton
def test_unbacked_float_item(self):
def fn(x, max_val):
return torch.clamp(x, 0, max_val.item())
self.common(
fn,
(
torch.randn(10, 20, 30, device=self.device),
torch.tensor(5.0, device=self.device),
),
)
# end of class CommonTemplate - add new tests here

View File

@@ -1864,8 +1864,6 @@ class TestFP8Matmul(TestCase):
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if torch.version.hip and recipe == "nvfp4":
raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")

View File

@@ -1914,7 +1914,6 @@ class TestSDPAFailureModes(NNTestCase):
q, k, v, None, 0.0, is_causal=True))
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
batch_size = 2**16
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@@ -1936,7 +1935,6 @@ class TestSDPAFailureModes(NNTestCase):
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@@ -1950,7 +1948,6 @@ class TestSDPAFailureModes(NNTestCase):
@largeTensorTest("15GB", "cuda")
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
device = torch.device("cuda")
dtype = torch.bfloat16

View File

@@ -1,5 +1,5 @@
from typing import TypeAlias, Union
from typing_extensions import assert_type
from typing import Union
from typing_extensions import assert_type, TypeAlias
from torch import randn, Tensor

View File

@@ -1,9 +1,8 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Optional, overload, Union
from typing import Any, Callable, Optional, overload, Union
import torch
from torch import Tensor

View File

@@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "add", [v], {})
def SET_UPDATE(self, inst: Instruction) -> None:
v = self.pop()
@@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg]
assert isinstance(obj, SetVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "update", [v], {})
def LIST_APPEND(self, inst: Instruction) -> None:
v = self.pop()
@@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
obj = self.stack[-inst.arg].realize()
assert isinstance(obj, ConstDictVariable)
assert obj.is_mutable()
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
obj.call_method(self, "update", [v], {})
DICT_UPDATE = DICT_MERGE

View File

@@ -87,12 +87,6 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
return gm.graph, region_tracker # type: ignore[union-attr]
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:

View File

@@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence, Sized
from collections.abc import Callable, Sequence
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard

View File

@@ -1,3 +1,5 @@
# mypy: ignore-errors
"""
Dictionary-related variable tracking classes for PyTorch Dynamo.
@@ -24,7 +26,7 @@ import inspect
import operator
import types
from collections.abc import Hashable as py_Hashable
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING
from torch._subclasses.fake_tensor import is_fake
@@ -57,13 +59,11 @@ if TYPE_CHECKING:
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
def was_instancecheck_override(obj: Any) -> bool:
def was_instancecheck_override(obj):
return type(obj).__dict__.get("__instancecheck__", False)
def raise_unhashable(
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
) -> None:
def raise_unhashable(arg, tx=None):
if tx is None:
from torch._dynamo.symbolic_convert import InstructionTranslator
@@ -75,7 +75,7 @@ def raise_unhashable(
)
def is_hashable(x: VariableTracker) -> bool:
def is_hashable(x):
# NB - performing an isinstance check on a LazyVT realizes the VT, accidentally
# inserting the guard. To avoid this, the LazyVT `is_hashable` method looks at
# the underlying value without realizing the VT. Consider updating the
@@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
"""
def __init__(self, vt: VariableTracker) -> None:
def __init__(self, vt) -> None:
# We specialize SymNodes
vt = specialize_symnode(vt)
# TODO Temporarily remove to figure out what keys are we breaking on
@@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
self.vt = vt
@property
def underlying_value(self) -> Any:
def underlying_value(self):
if (
isinstance(self.vt, variables.LazyVariableTracker)
and not self.vt.is_realized()
@@ -178,8 +177,7 @@ elif isinstance(self.vt, variables.FrozenDataClassVariable):
elif isinstance(self.vt, variables.FrozenDataClassVariable):
Hashable = ConstDictVariable._HashableTracker
fields_values = {
k: Hashable(v).underlying_value
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
}
return variables.FrozenDataClassVariable.HashWrapper(
self.vt.python_type(), fields_values
@@ -188,16 +187,16 @@ # The re module in Python 3.13+ has a dictionary (_cache2) with
# The re module in Python 3.13+ has a dictionary (_cache2) with
# an object as key (`class _ZeroSentinel(int): ...`):
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
return self.vt.value # type: ignore[attr-defined,union-attr]
return self.vt.value
else:
x = self.vt.as_python_constant()
return x
def __hash__(self) -> int:
def __hash__(self):
return hash(self.underlying_value)
@staticmethod
def _eq_impl(a: Any, b: Any) -> bool:
def _eq_impl(a, b):
# TODO: Put this in utils and share it between variables/builtin.py and here
type_a, type_b = type(a), type(b)
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
@@ -213,7 +212,7 @@ else:
else:
return a == b
def __eq__(self, other: object) -> bool:
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
Hashable = ConstDictVariable._HashableTracker
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
type(other)
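Setting aside the annotations being removed in this hunk, `_HashableTracker`'s job is to let dicts and sets key off a value derived from the wrapped tracker rather than off the wrapper's identity. A generic sketch of the pattern (the `value` fallback is illustrative, not Dynamo's exact unwrapping logic):

```python
class HashableWrapper:
    """Wrap an object so dict/set operations hash and compare a derived value."""

    def __init__(self, vt):
        self.vt = vt

    @property
    def underlying_value(self):
        # Dynamo peeks through lazy trackers and constants here; this sketch
        # just forwards to a `value` attribute when one exists.
        return getattr(self.vt, "value", self.vt)

    def __hash__(self):
        return hash(self.underlying_value)

    def __eq__(self, other):
        if not isinstance(other, HashableWrapper):
            return NotImplemented
        return self.underlying_value == other.underlying_value
```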
@@ -227,8 +226,8 @@ def __init__(
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type = dict,
**kwargs: Any,
user_cls=dict,
**kwargs,
) -> None:
# .clone() pass these arguments in kwargs but they're recreated a few
# lines below
@@ -248,22 +247,18 @@ for x, v in items.items()
for x, v in items.items()
)
def make_hashable(
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
) -> "ConstDictVariable._HashableTracker":
def make_hashable(key):
return key if isinstance(key, Hashable) else Hashable(key)
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
# need to reconstruct everything if the dictionary is an intermediate value
# or if a pop/delitem was executed
self.should_reconstruct_all = (
not is_from_local_source(self.source) if self.source else True
)
self.should_reconstruct_all = not is_from_local_source(self.source)
self.original_items = items.copy()
self.user_cls = user_cls
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
def _get_dict_cls_from_user_cls(self, user_cls):
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
# avoid executing user code if user_cls is a dict subclass
@@ -282,10 +277,10 @@ class ConstDictVariable(VariableTracker):
dict_cls = dict
return dict_cls
def as_proxy(self) -> dict[Any, Any]:
def as_proxy(self):
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
def debug_repr(self) -> str:
def debug_repr(self):
return (
"{"
+ ", ".join(
@@ -294,20 +289,20 @@ class ConstDictVariable(VariableTracker):
+ "}"
)
def as_python_constant(self) -> dict[Any, Any]:
def as_python_constant(self):
return {
k.vt.as_python_constant(): v.as_python_constant()
for k, v in self.items.items()
}
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
def keys_as_python_constant(self):
self.install_dict_keys_match_guard()
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
def python_type(self) -> type:
def python_type(self):
return self.user_cls
def __contains__(self, vt: VariableTracker) -> bool:
def __contains__(self, vt) -> bool:
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return (
@@ -327,15 +322,13 @@ class ConstDictVariable(VariableTracker):
for key, value in self.items.items()
)
def is_new_item(
self, value: Optional[VariableTracker], other: VariableTracker
) -> bool:
def is_new_item(self, value, other):
# compare the id of the realized values if both values are not lazy VTs
if value and value.is_realized() and other.is_realized():
return id(value.realize()) != id(other.realize())
return id(value) != id(other)
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
def reconstruct_kvs_into_new_dict(self, codegen):
# Build a dictionary that contains the keys and values.
num_args = 0
for key, value in self.items.items():
@@ -347,7 +340,7 @@ class ConstDictVariable(VariableTracker):
num_args += 1
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
if self.user_cls is collections.OrderedDict:
# emit `OrderedDict(constructed_dict)`
codegen.add_push_null(
@@ -365,21 +358,19 @@ class ConstDictVariable(VariableTracker):
def getitem_const_raise_exception_if_absent(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
raise_observed_exception(KeyError, tx)
return self.items[key]
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
msg = f"Dictionary key {arg.value} not found during tracing"
unimplemented_v2(
gb_type="key not found in dict",
context=f"Key {arg.value}", # type: ignore[attr-defined]
context=f"Key {arg.value}",
explanation=msg,
hints=[
"Check if the key exists in the dictionary before accessing it.",
@@ -388,13 +379,13 @@ class ConstDictVariable(VariableTracker):
)
return self.items[key]
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
def maybe_getitem_const(self, arg: VariableTracker):
key = ConstDictVariable._HashableTracker(arg)
if key not in self.items:
return None
return self.items[key]
def realize_key_vt(self, arg: VariableTracker) -> None:
def realize_key_vt(self, arg: VariableTracker):
# Realize the LazyVT on a particular index
assert arg in self
key = ConstDictVariable._HashableTracker(arg)
@@ -403,13 +394,11 @@ class ConstDictVariable(VariableTracker):
if isinstance(original_key_vt, variables.LazyVariableTracker):
original_key_vt.realize()
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
if self.source:
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
# Key guarding - These are the cases to consider
# 1) The dict has been mutated. In this case, we would have already
# inserted a DICT_KEYS_MATCH guard, so we can skip.
@@ -450,11 +439,11 @@ class ConstDictVariable(VariableTracker):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
# we have to insert guards when a dict method is accessed. For this to
# be simple, we are conservative and overguard. We skip guard only for
@@ -473,7 +462,7 @@ class ConstDictVariable(VariableTracker):
tx, *args, **kwargs
)
tx.output.side_effects.mutation(self)
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
self.items.update(temp_dict_vt.items)
return ConstantVariable.create(None)
elif name == "__getitem__":
# Key guarding - Nothing to do. LazyVT for value will take care.
@@ -537,7 +526,7 @@ class ConstDictVariable(VariableTracker):
return ConstantVariable.create(len(self.items))
elif name == "__setitem__" and self.is_mutable():
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_keys_match_guard()
if kwargs or len(args) != 2:
@@ -561,7 +550,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
if args[0] not in self:
self.install_dict_contains_guard(tx, args)
@@ -576,7 +565,7 @@ class ConstDictVariable(VariableTracker):
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
if args[0] not in self:
# missing item, return the default value. Install no DICT_CONTAINS guard.
@@ -610,7 +599,7 @@ class ConstDictVariable(VariableTracker):
last = v.value
else:
raise_args_mismatch(tx, name)
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
k, v = self.items.popitem(last=last)
else:
k, v = self.items.popitem()
@@ -643,17 +632,17 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard()
dict_vt: ConstDictVariable = args[0]
dict_vt = args[0]
else:
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
self.items.update(dict_vt.items) # type: ignore[attr-defined]
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
self.items.update(dict_vt.items)
if has_kwargs:
# Handle kwargs
kwargs_hashable = {
kwargs = {
Hashable(ConstantVariable.create(k)): v
for k, v in kwargs.items()
}
self.items.update(kwargs_hashable)
self.items.update(kwargs)
return ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
@@ -667,7 +656,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_contains_guard(tx, args)
contains = args[0] in self
@@ -682,7 +671,7 @@ class ConstDictVariable(VariableTracker):
)
if not arg_hashable:
raise_unhashable(args[0], tx)
raise_unhashable(args[0])
self.install_dict_keys_match_guard()
if kwargs or len(args) > 2:
@@ -718,7 +707,7 @@ class ConstDictVariable(VariableTracker):
and "last" in kwargs
and isinstance(kwargs["last"], ConstantVariable)
):
last = kwargs.get("last").value # type: ignore[union-attr]
last = kwargs.get("last").value
key = Hashable(args[0])
self.items.move_to_end(key, last=last)
@@ -734,7 +723,7 @@ class ConstDictVariable(VariableTracker):
)
elif name == "__ne__":
return ConstantVariable.create(
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
not self.call_method(tx, "__eq__", args, kwargs).value
)
elif name == "__or__":
if len(args) != 1:
@@ -761,14 +750,14 @@ class ConstDictVariable(VariableTracker):
if not istype(
other, (ConstDictVariable, variables.UserDefinedDictVariable)
):
err_msg = (
msg = (
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
f"and '{other.python_type().__name__}'"
)
raise_observed_exception(TypeError, tx, args=[err_msg])
raise_observed_exception(TypeError, tx, args=[msg])
# OrderedDict overloads __ror__
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
ts = {self.user_cls, other.user_cls}
user_cls = (
collections.OrderedDict
if any(issubclass(t, collections.OrderedDict) for t in ts)
@@ -785,8 +774,8 @@ class ConstDictVariable(VariableTracker):
# NB - Guard on all the keys of the other dict to ensure
# correctness.
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
args[0].install_dict_keys_match_guard()
new_dict_vt.items.update(args[0].items)
return new_dict_vt
elif name == "__ior__":
self.call_method(tx, "update", args, kwargs)
@@ -800,13 +789,11 @@ class ConstDictVariable(VariableTracker):
else:
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
self.install_dict_keys_match_guard()
return [x.vt for x in self.items.keys()]
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
def call_obj_hasattr(self, tx, name):
# dict not allow setting arbitrary attributes. OrderedDict and
# defaultdict allow arbitrary setattr, but not deletion of default attrs
if any(
@@ -829,25 +816,25 @@ class ConstDictVariable(VariableTracker):
],
)
def clone(self, **kwargs: Any) -> VariableTracker:
def clone(self, **kwargs):
self.install_dict_keys_match_guard()
return super().clone(**kwargs)
class MappingProxyVariable(VariableTracker):
# proxies to the original dict_vt
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
super().__init__(**kwargs)
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
def python_type(self) -> type:
def python_type(self):
return types.MappingProxyType
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
return self.dv_dict.unpack_var_sequence(tx)
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
# load types.MappingProxyType
if self.source:
msg = (
@@ -876,11 +863,11 @@ class MappingProxyVariable(VariableTracker):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if self.source and tx.output.side_effects.has_existing_dict_mutation():
msg = (
"A dict has been modified while we have an existing mappingproxy object. "
@ -905,7 +892,7 @@ class MappingProxyVariable(VariableTracker):
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
) -> "VariableTracker":
if self.python_type() is types.MappingProxyType:
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
return super().call_obj_hasattr(tx, name)
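The mutation check above reflects how a real `mappingproxy` behaves: it is a live read-only view, so changes to the backing dict show through it immediately, while writes to the proxy itself fail. The stdlib behavior being modeled:

```python
import types

d = {"a": 1}
proxy = types.MappingProxyType(d)

assert proxy["a"] == 1
d["b"] = 2                 # mutating the backing dict...
assert "b" in proxy        # ...is immediately visible through the proxy

try:
    proxy["c"] = 3         # the proxy itself rejects writes
except TypeError as exc:
    print(exc)             # 'mappingproxy' object does not support item assignment
```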
@@ -913,44 +900,35 @@ class MappingProxyVariable(VariableTracker):
class NNModuleHooksDictVariable(ConstDictVariable):
# Special class to avoid adding any guards on the nn module hook ids.
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
pass
class DefaultDictVariable(ConstDictVariable):
def __init__(
self,
items: dict[VariableTracker, VariableTracker],
user_cls: type,
default_factory: Optional[VariableTracker] = None,
**kwargs: Any,
) -> None:
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
super().__init__(items, user_cls, **kwargs)
assert user_cls is collections.defaultdict
if default_factory is None:
default_factory = ConstantVariable.create(None)
self.default_factory = default_factory
def is_python_constant(self) -> bool:
def is_python_constant(self):
# Return false for unsupported defaults. This ensures that a bad handler
# path is not taken in BuiltinVariable for getitem.
if self.default_factory not in [list, tuple, dict] and not self.items:
return False
return super().is_python_constant()
def debug_repr(self) -> str:
assert self.default_factory is not None
def debug_repr(self):
return (
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
)
@staticmethod
def is_supported_arg(arg: VariableTracker) -> bool:
def is_supported_arg(arg):
if isinstance(arg, variables.BuiltinVariable):
return arg.fn in (list, tuple, dict, set)
else:
@@ -964,11 +942,11 @@ class DefaultDictVariable(ConstDictVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: "list[VariableTracker]",
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__getitem__":
if len(args) != 1:
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
@@ -984,13 +962,13 @@ class DefaultDictVariable(ConstDictVariable):
else:
default_var = self.default_factory.call_function(tx, [], {})
super().call_method(
tx, "__setitem__", [args[0], default_var], kwargs
tx, "__setitem__", (args[0], default_var), kwargs
)
return default_var
else:
return super().call_method(tx, name, args, kwargs)
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen):
# emit `defaultdict(default_factory, new_dict)`
codegen.add_push_null(
lambda: codegen.extend_output(
@@ -1016,48 +994,40 @@ class SetVariable(ConstDictVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
# pyrefly: ignore[bad-assignment]
items = dict.fromkeys(items, SetVariable._default_value())
# pyrefly: ignore[bad-argument-type]
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "set()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
def set_items(self):
return set(self.items.keys())
@staticmethod
def _default_value() -> VariableTracker:
def _default_value():
# Variable to fill in the keys of the dictionary
return ConstantVariable.create(None)
def as_proxy(self) -> Any:
def as_proxy(self):
return {k.vt.as_proxy() for k in self.set_items}
def python_type(self) -> type:
def python_type(self):
return set
def as_python_constant(self) -> Any:
def as_python_constant(self):
return {k.vt.as_python_constant() for k in self.set_items}
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
codegen.foreach([x.vt for x in self.set_items])
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
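As the constructor above shows (`dict.fromkeys(items, ...)`), `SetVariable` reuses the dict machinery and treats the keys as the set's elements. A plain-Python illustration of why that representation works:

```python
items = ["a", "b", "a", "c"]
backing = dict.fromkeys(items, None)      # values are a dummy constant

assert list(backing) == ["a", "b", "c"]   # deduplicated, insertion-ordered keys
assert set(backing) == {"a", "b", "c"}    # same membership as a real set

backing.pop("b")                          # set.remove/discard map onto dict ops
assert list(backing) == ["a", "c"]
```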
def _fast_set_method(
self,
tx: "InstructionTranslator",
fn: Any,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
def _fast_set_method(self, tx, fn, args, kwargs):
try:
res = fn(
*[x.as_python_constant() for x in [self, *args]],
@@ -1067,16 +1037,15 @@ class SetVariable(ConstDictVariable):
raise_observed_exception(
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
)
# pyrefly: ignore[unbound-name]
return VariableTracker.build(tx, res)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
# We forward the calls to the dictionary model
from ..utils import check_constant_args
@@ -1096,10 +1065,10 @@ class SetVariable(ConstDictVariable):
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
if name == "__init__":
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
tx.output.side_effects.mutation(self)
self.items.clear()
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
self.items.update(temp_set_vt.items)
return ConstantVariable.create(None)
elif name == "add":
if kwargs or len(args) != 1:
@@ -1110,7 +1079,7 @@ class SetVariable(ConstDictVariable):
f"{len(args)} args and {len(kwargs)} kwargs",
)
name = "__setitem__"
args = [args[0], SetVariable._default_value()]
args = (args[0], SetVariable._default_value())
elif name == "pop":
if kwargs or args:
raise_args_mismatch(
@@ -1121,14 +1090,12 @@ class SetVariable(ConstDictVariable):
)
# Choose an item at random and pop it via the Dict.pop method
try:
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
result = self.set_items.pop().vt
except KeyError as e:
raise_observed_exception(
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, [result], kwargs)
# pyrefly: ignore[unbound-name]
super().call_method(tx, name, (result,), kwargs)
return result
elif name == "isdisjoint":
if kwargs or len(args) != 1:
@@ -1250,7 +1217,6 @@ class SetVariable(ConstDictVariable):
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
)
raise_observed_exception(TypeError, tx, args=[msg])
assert m is not None
return self.call_method(tx, m, args, kwargs)
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
@@ -1264,34 +1230,29 @@ class SetVariable(ConstDictVariable):
"__ixor__": "symmetric_difference_update",
"__isub__": "difference_update",
}.get(name)
assert m is not None
self.call_method(tx, m, args, kwargs)
return self
elif name == "__eq__":
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(False)
r = self.call_method(tx, "symmetric_difference", args, kwargs)
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
return ConstantVariable.create(len(r.set_items) == 0)
elif name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
)
return super().call_method(tx, name, args, kwargs)
def getitem_const(
self, tx: "InstructionTranslator", arg: VariableTracker
) -> VariableTracker:
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
raise RuntimeError("Illegal to getitem on a set")
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
super().install_dict_contains_guard(tx, args)
@@ -1299,27 +1260,27 @@ class FrozensetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "frozenset()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
@property
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
def set_items(self):
return self.items.keys()
def python_type(self) -> type:
def python_type(self):
return frozenset
def as_python_constant(self) -> Any:
def as_python_constant(self):
return frozenset({k.vt.as_python_constant() for k in self.set_items})
def reconstruct(self, codegen: "PyCodegen") -> None:
def reconstruct(self, codegen: "PyCodegen"):
codegen.foreach([x.vt for x in self.set_items])
codegen.add_push_null(
lambda: codegen.extend_output(
@@ -1332,11 +1293,11 @@ class FrozensetVariable(SetVariable):
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
elif name == "__init__":
@@ -1355,7 +1316,7 @@ class FrozensetVariable(SetVariable):
"symmetric_difference",
):
r = super().call_method(tx, name, args, kwargs)
return FrozensetVariable(r.items) # type: ignore[attr-defined]
return FrozensetVariable(r.items)
return super().call_method(tx, name, args, kwargs)
@@ -1363,11 +1324,11 @@ class DictKeySetVariable(SetVariable):
def __init__(
self,
items: list[VariableTracker],
**kwargs: Any,
**kwargs,
) -> None:
super().__init__(items, **kwargs)
def debug_repr(self) -> str:
def debug_repr(self):
if not self.items:
return "dict_keys([])"
else:
@@ -1377,35 +1338,33 @@ class DictKeySetVariable(SetVariable):
+ "])"
)
def install_dict_keys_match_guard(self) -> None:
def install_dict_keys_match_guard(self):
# Already EQUALS_MATCH guarded
pass
def install_dict_contains_guard(
self, tx: "InstructionTranslator", args: list[VariableTracker]
) -> None:
def install_dict_contains_guard(self, tx, args):
# Already EQUALS_MATCH guarded
pass
@property
def set_items(self) -> Any:
def set_items(self):
return self.items
def python_type(self) -> type:
def python_type(self):
return dict_keys
def as_python_constant(self) -> Any:
def as_python_constant(self):
return dict.fromkeys(
{k.vt.as_python_constant() for k in self.set_items}, None
).keys()
def call_method(
self,
tx: "InstructionTranslator",
name: str,
tx,
name,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
) -> "VariableTracker":
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
return super().call_method(tx, name, args, kwargs)
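`as_python_constant` above leans on a small CPython trick: `dict.fromkeys(...).keys()` is the direct way to build a real `dict_keys` object, and the resulting view is set-like:

```python
ks = dict.fromkeys({1, 2, 3}, None).keys()
assert type(ks).__name__ == "dict_keys"
assert ks == {1, 2, 3}    # dict_keys compares equal to a set of the same keys
```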
@ -1420,47 +1379,42 @@ class DictViewVariable(VariableTracker):
kv: Optional[str] = None
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
super().__init__(**kwargs)
assert self.kv in ("keys", "values", "items")
assert isinstance(dv_dict, ConstDictVariable)
self.dv_dict = dv_dict
@property
def view_items(self) -> Any:
assert self.kv is not None
def view_items(self):
return getattr(self.dv_dict.items, self.kv)()
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
# Implement in the subclasses
raise NotImplementedError
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
def unpack_var_sequence(self, tx):
return self.view_items_vt
def reconstruct(self, codegen: "PyCodegen") -> None:
assert self.kv is not None
def reconstruct(self, codegen: "PyCodegen"):
codegen(self.dv_dict)
codegen.load_method(self.kv)
codegen.call_method(0)
def call_obj_hasattr(
self, tx: "InstructionTranslator", name: str
) -> VariableTracker:
assert self.kv is not None
def call_obj_hasattr(self, tx, name):
if name in self.python_type().__dict__:
return ConstantVariable.create(True)
return ConstantVariable.create(False)
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__len__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name == "__iter__":
@ -1474,24 +1428,24 @@ class DictKeysVariable(DictViewVariable):
kv = "keys"
@property
def set_items(self) -> set[VariableTracker]:
def set_items(self):
return set(self.view_items)
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
return [x.vt for x in self.view_items]
def python_type(self) -> type:
def python_type(self):
return dict_keys
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
tx,
name,
args: list["VariableTracker"],
kwargs: dict[str, "VariableTracker"],
) -> "VariableTracker":
if name == "__contains__":
return self.dv_dict.call_method(tx, name, args, kwargs)
elif name in (
@ -1506,13 +1460,13 @@ class DictKeysVariable(DictViewVariable):
):
# These methods always return a set
m = getattr(self.set_items, name)
r = m(args[0].set_items) # type: ignore[attr-defined]
r = m(args[0].set_items)
return SetVariable(r)
if name in cmp_name_to_op_mapping:
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
return ConstantVariable.create(NotImplemented)
return ConstantVariable.create(
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
)
return super().call_method(tx, name, args, kwargs)
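Wrapping the result in `SetVariable` follows CPython, where set operations on a keys view always produce a plain set:

```python
d = {"a": 1, "b": 2}
r = d.keys() | {"c"}
assert isinstance(r, set) and r == {"a", "b", "c"}
assert d.keys() & {"a", "z"} == {"a"}
```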
@ -1522,10 +1476,10 @@ class DictValuesVariable(DictViewVariable):
kv = "values"
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
return list(self.view_items)
def python_type(self) -> type:
def python_type(self):
return dict_values
@ -1533,20 +1487,14 @@ class DictItemsVariable(DictViewVariable):
kv = "items"
@property
def view_items_vt(self) -> list[VariableTracker]:
def view_items_vt(self):
# Returns an iterable of the unpacked items
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
def python_type(self) -> type:
def python_type(self):
return dict_items
def call_method(
self,
tx: "InstructionTranslator",
name: str,
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
def call_method(self, tx, name, args, kwargs):
# TODO(guilhermeleobas): This should actually check if args[0]
# implements the mapping protocol.
if name == "__eq__":


@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
throw std::runtime_error("expected int arg");
return reinterpret_cast<uintptr_t>(result);
}}
template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
if(unlikely(result == -1.0 && PyErr_Occurred()))
throw std::runtime_error("expected float arg");
return static_cast<float>(result);
}}
{extra_parse_arg}
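For orientation, a rough Python analogue of what the new `parse_arg<float>` specialization does with the n-th positional argument (illustrative only; the real conversion goes through `PyFloat_AsDouble`):

```python
def parse_float_arg(args: tuple, n: int) -> float:
    try:
        return float(args[n])
    except TypeError:
        # maps to the same "expected float arg" failure as the C++ shim
        raise RuntimeError("expected float arg") from None

assert parse_float_arg((0.5, 2), 0) == 0.5
```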


@ -1732,9 +1732,15 @@ class KernelArgs:
call_args.append(self.wrap_ptr_arg(outer, dtype))
arg_types.append(f"{cpp_dtype}*")
for outer, inner in self.sizevars.items():
arg_defs.append(f"const {INDEX_TYPE} {inner}")
if isinstance(outer, sympy.Symbol) and symbol_is_type(
outer, (SymT.UNBACKED_FLOAT)
):
arg_defs.append(f"const float {inner}")
arg_types.append("const float")
else:
arg_defs.append(f"const {INDEX_TYPE} {inner}")
arg_types.append(f"const {INDEX_TYPE}")
call_args.append(self.wrap_size_arg(outer))
arg_types.append(f"const {INDEX_TYPE}")
if V.graph.wrapper_code:
V.graph.wrapper_code.ensure_size_computed(outer)
assert not self.workspace_args, "Workspace not supported on CPU "
@ -2353,6 +2359,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
SymT.UNBACKED_INT,
SymT.SIZE,
SymT.PRECOMPUTED_SIZE,
SymT.UNBACKED_FLOAT,
),
)
}


@ -2,7 +2,7 @@
from __future__ import annotations
import hashlib
from typing import Any, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
import sympy # noqa: TC002
@ -17,8 +17,6 @@ from .simd import SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ..ir import IRNode
from ..scheduler import BaseSchedulerNode


@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
filename=__file__,
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)


@ -4,6 +4,7 @@ from typing import Any, Optional
import sympy
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import config
from ..runtime.hints import AttrsDescriptorWrapper
@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
return "constexpr"
elif isinstance(arg.expr, (float, sympy.Float)):
return "fp32"
elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
arg.expr, (SymT.UNBACKED_FLOAT)
):
return "fp32"
elif isinstance(arg.expr, bool):
return "i1"


@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
fx_node: torch.fx.Node,
override_size: Optional[int] = None,
# TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix.
use_nccl_estimator: bool = True,
use_nccl_estimator: bool = False,
) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).


@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable
import torch
from torch._environment import is_fbcode
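Both `Callable` spellings behave the same at runtime on the Python versions PyTorch supports; this hunk simply moves back to the `typing` import. A minimal check:

```python
from collections.abc import Callable as AbcCallable
from typing import Callable as TypingCallable

f: TypingCallable[[int], int] = lambda x: x + 1
g: AbcCallable[[int], int] = f
assert f(1) == g(1) == 2
```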


@ -3586,24 +3586,13 @@ def user_autotune(
)
def foreach(triton_meta, filename=None, inductor_meta=None):
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
configs,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
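The net effect: the `num_warps` sweep moves out of this runtime heuristic, and the single value chosen at codegen time (`num_warps={self.num_warps}` in the combo-kernel hunk above) is baked into one config. A sketch of the before/after lists, assuming `triton` is importable:

```python
import triton

# before: runtime autotuning over a fixed sweep (under max_autotune)
old_configs = [triton.Config({}, num_stages=1, num_warps=w) for w in (1, 2, 4, 8)]

# after: exactly one config with the caller-provided value
num_warps = 8
new_configs = [triton.Config({}, num_stages=1, num_warps=num_warps)]
```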


@ -52,7 +52,26 @@ __all__ = [
"MemRecordsAcc",
]
from contextlib import ContextDecorator
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
import functools
class _ContextDecorator: # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapped
# global python state - whether profiler is currently enabled
@ -725,7 +744,8 @@ class profile:
return all_function_events
class record_function(ContextDecorator):
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.
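Because `record_function` inherits from `ContextDecorator` (or the restored fallback above), it can be used either way; a small illustration:

```python
import torch

with torch.autograd.profiler.record_function("my_block"):   # context manager
    y = torch.ones(2) + 1

@torch.autograd.profiler.record_function("my_fn")            # decorator
def double(x):
    return x * 2

double(torch.ones(2))
```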


@ -108,14 +108,12 @@ struct FlightRecorder {
capture_cpp_stack_ = getCvarBool(
{"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
enabled_ = max_entries_ > 0;
reset_epoch_start_idx_[0] = 0;
}
struct Entry {
size_t id_; // incremented id in the trace buffer
// used to figure out where in the circular entries
// buffer this entry will be located to
// update state information
size_t reset_epoch_; // epoch when this entry was created
size_t pg_id_;
std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
@ -185,34 +183,11 @@ struct FlightRecorder {
size_t max_entries_ = 0;
size_t next_ = 0;
size_t id_ = 0;
size_t reset_epoch_ = 0;
std::unordered_map<size_t, size_t>
reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
pg_name_to_ranks_;
std::string comm_lib_version_;
struct TraceIdentifier {
std::optional<size_t> id;
std::optional<size_t> reset_epoch;
};
TraceIdentifier recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P);
std::optional<size_t> record(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
@ -238,16 +213,8 @@ struct FlightRecorder {
std::vector<Entry> dump_entries();
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t getIdxFromId(size_t id, size_t reset_epoch) const;
// Returns the entry with the given id and reset_epoch, if it exists.
// Otherwise, returns std::nullopt.
TORCH_API std::optional<Entry> getEntry(
std::optional<size_t> id,
std::optional<size_t> reset_epoch);
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);
/*
@ -260,11 +227,6 @@ struct FlightRecorder {
never hang. (timing must also be enabled for compute_duration - see
TORCH_NCCL_ENABLE_TIMING).
*/
TORCH_API void retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration = true);
TORCH_API void retire_id(
std::optional<size_t> id,
bool compute_duration = true);


@ -53,41 +53,8 @@ std::optional<size_t> FlightRecorder<EventType>::record(
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
auto result = recordWithResetEnabled(
pg_id,
pg_name,
collective_seq_id,
p2p_seq_id,
op_id,
std::move(profiling_name),
inputs,
outputs,
start,
end,
timeout_ms,
std::move(pg_status),
isP2P);
return result.id;
}
template <typename EventType>
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
if (!enabled_) {
return TraceIdentifier{std::nullopt, std::nullopt};
return std::nullopt;
}
if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
// Current pg_status is not in FR.
@ -97,13 +64,8 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
std::lock_guard<std::mutex> guard(mutex_);
TORCH_CHECK(
reset_epoch_start_idx_.find(reset_epoch_) !=
reset_epoch_start_idx_.end());
auto te = Entry{
id_,
reset_epoch_,
pg_id,
pg_name,
collective_seq_id,
@ -142,20 +104,15 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
}
const auto next = next_++;
if (entries_.size() < max_entries_) {
entries_.emplace_back(std::move(te));
} else {
entries_[next] = std::move(te);
entries_[next_++] = std::move(te);
if (next_ == max_entries_) {
next_ = 0;
}
}
if (next_ == max_entries_) {
next_ = 0;
}
const auto id = id_++;
return TraceIdentifier{id, reset_epoch_};
return id_++;
}
template <typename EventType>
@ -206,20 +163,15 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
std::vector<Entry> result;
{
std::lock_guard<std::mutex> guard(mutex_);
// Filter entries during insertion - only keep entries from current epoch
auto filter = [this](const Entry& e) {
return e.reset_epoch_ == reset_epoch_;
};
std::copy_if(
result.reserve(entries_.size());
result.insert(
result.end(),
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
entries_.end(),
std::back_inserter(result),
filter);
std::copy_if(
entries_.end());
result.insert(
result.end(),
entries_.begin(),
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
std::back_inserter(result),
filter);
entries_.begin() + static_cast<std::ptrdiff_t>(next_));
}
// query any remaining events
for (auto& r : result) {
@ -230,47 +182,28 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
}
template <typename EventType>
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
const {
// Look up the starting idx for the given reset epoch
auto it = reset_epoch_start_idx_.find(reset_epoch);
TORCH_CHECK(it != reset_epoch_start_idx_.end());
// Calculate idx based on where the epoch started
return (it->second + id) % max_entries_;
}
template <typename EventType>
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
// returns std::nullopt.
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::
getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
if (!enabled_ || !id || !reset_epoch) {
EventType>::getEntry(std::optional<size_t> id) {
if (!enabled_ || !id) {
return std::nullopt;
}
std::unique_lock<std::mutex> guard(mutex_);
Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
Entry entry = entries_.at(*id % max_entries_);
if (entry.id_ == *id) {
return entry;
} else {
return std::nullopt;
}
return std::nullopt;
}
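The restored lookup identifies an entry purely by `id % max_entries_` plus an id equality check, so an overwritten slot is detected rather than misattributed. A minimal Python sketch of that invariant (names hypothetical):

```python
MAX_ENTRIES = 4
slots: dict[int, int] = {}          # slot index -> id stored there

def record(entry_id: int) -> None:
    slots[entry_id % MAX_ENTRIES] = entry_id   # may overwrite an older entry

def get_entry(entry_id: int):
    stored = slots.get(entry_id % MAX_ENTRIES)
    return stored if stored == entry_id else None  # None if overwritten

for i in range(6):
    record(i)
assert get_entry(5) == 5
assert get_entry(0) is None         # slot 0 was reused by id 4
```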
template <typename EventType>
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::getEntry(std::optional<size_t> id) {
return getEntry(id, 0);
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration) {
if (!enabled_ || !id || !reset_epoch) {
if (!enabled_ || !id) {
return;
}
@ -281,8 +214,8 @@ void FlightRecorder<EventType>::retire_id(
std::unique_lock<std::mutex> guard(mutex_);
Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
Entry* entry = &entries_.at(*id % max_entries_);
if (entry->id_ == *id) {
update_state(*entry);
if (compute_duration) {
@ -304,8 +237,8 @@ void FlightRecorder<EventType>::retire_id(
guard.lock();
// Refresh the entry pointer, see if the entry has been overwritten
entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
entry = &entries_.at(*id % max_entries_);
if (entry->id_ != *id) {
LOG(INFO) << "retire_id abandoned for id " << *id
<< ", event was overwritten while waiting to compute duration.";
return;
@ -316,23 +249,12 @@ void FlightRecorder<EventType>::retire_id(
}
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
bool compute_duration) {
retire_id(id, 0, compute_duration);
}
template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
std::lock_guard<std::mutex> guard(mutex_);
if (!entries_.empty()) {
// Soft delete: increment epoch to mark all existing entries as old
// Store where the new epoch starts in the circular buffer
reset_epoch_++;
reset_epoch_start_idx_[reset_epoch_] = next_;
id_ = 0;
}
next_ = 0;
id_ = 0;
entries_.clear();
}
template <typename EventType>


@ -708,8 +708,7 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
// TODO: We need to have numel of tensors for gloo as well.
pgStatus_->lastCompletedNumelIn = 0;
pgStatus_->lastCompletedNumelOut = 0;
FlightRecorder<c10::Event>::get()->retire_id(
work->trace_id_, work->trace_reset_epoch_, false);
FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
lock.lock();
workInProgress_[workerIndex].reset();
}
@ -781,7 +780,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
pgStatus_->lastEnqueuedNumelOut = 0;
// using c10d::FlightRecorder;
// TODO: We need to have a way to use c10::Event inside gloo as well.
auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
collectiveCounter_,
@ -796,8 +795,6 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
work->getTimeout(),
pgStatus_,
false);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
workQueue_.push_back(std::move(work));
lock.unlock();


@ -99,7 +99,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
std::shared_ptr<gloo::Context> context_;
const std::chrono::milliseconds timeout_;


@ -575,7 +575,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
futureWorkResult_(w.futureWorkResult_),
timingEnabled_(w.timingEnabled_),
trace_id_(w.trace_id_),
trace_reset_epoch_(w.trace_reset_epoch_),
distDebugLevel_(w.distDebugLevel_) {
exception_ = w.exception_;
}
@ -705,9 +704,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
// Print the traceback of the collective at call time
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
// First step we get the corresponding record entry from FR, based on work's
// trace_id_ and trace_reset_epoch_
// trace_id_
std::optional<FlightRecorderCUDA::Entry> entry =
FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
FlightRecorderCUDA::get()->getEntry(trace_id_);
if (entry.has_value()) {
auto entryVal = entry.value();
// Get stack trace from FR entry, in string format
@ -2395,8 +2394,7 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
FlightRecorderCUDA::get()->retire_id(
work.trace_id_, work.trace_reset_epoch_, true);
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
if (pg_->onCompletionHook_) {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
@ -3362,7 +3360,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
// these objects to the Work because it has implications for keeping those
// tensors alive longer and adds overhead when copying Work objects
// between threads
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
r->trace_id_ = FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -3376,8 +3374,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
options_->timeout,
pgStatus_,
isP2P);
r->trace_id_ = traceId.id;
r->trace_reset_epoch_ = traceId.reset_epoch;
}
return r;
}
@ -3597,7 +3593,6 @@ float ProcessGroupNCCL::endTimeEstimate() {
#ifdef NCCL_SIM_INFO_INITIALIZER
ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
--ncclActiveGroupCounter_;
return simInfo.estimatedTime;
#else
TORCH_CHECK(
@ -3681,7 +3676,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// later in endCoalescing we record a 'coalesced' Work which has
// timing/state updates via watchdog thread, but lacks op metadata such as
// input/output sizes and profilingTitle per-op in the group.
FlightRecorderCUDA::get()->recordWithResetEnabled(
FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4173,7 +4168,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// TODO(whc) because we don't pass output {tensor} to initWork, we tell
// initWork to not record, and then we manually call record passing all the
// information it wants.
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
work->trace_id_ = FlightRecorderCUDA::get()->record(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4187,8 +4182,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
options_->timeout,
pgStatus_,
/*isP2P=*/true);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
}
// Only check for NaN for send ops, for recv ops `tensor` can be a random


@ -505,7 +505,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
DebugLevel distDebugLevel_;
friend class ProcessGroupNCCL;
};


@ -4,7 +4,6 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -14,7 +13,6 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)
using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -95,32 +93,6 @@ class Tensor {
return numel;
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
// a Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef sizes() const {
int64_t* sizes;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
return IntHeaderOnlyArrayRef(sizes, dim());
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the strides of a
// Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef strides() const {
int64_t* strides;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
return IntHeaderOnlyArrayRef(strides, dim());
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.


@ -1,8 +1,9 @@
import functools
import math
import operator
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from datetime import timedelta
from typing import Callable
import torch
from torch._C import ScriptObject


@ -10,7 +10,6 @@ from ._context_parallel._attention import (
_enable_context_parallel_dispatcher,
_is_causal_behavior,
_RotateMethod,
_templated_ring_attention,
context_parallel,
context_parallel_unshard,
set_rotate_method,
@ -23,7 +22,6 @@ from ._context_parallel._load_balancer import (
)
# TODO(fegin): add deprecation message once the final interfaces are concluded.
__all__ = [
"_CausalBehavior",
"_context_parallel_shard",
@ -33,7 +31,6 @@ __all__ = [
"_enable_context_parallel_dispatcher",
"_is_causal_behavior",
"_RotateMethod",
"_templated_ring_attention",
"context_parallel",
"context_parallel_unshard",
"set_rotate_method",


@ -647,15 +647,6 @@ class CodeGen:
if verbose:
# override annotation with more detailed information
try:
from torch.distributed.tensor._api import DTensor, DTensorSpec
dtensorspec_format_shard_order_str = (
DTensorSpec.format_shard_order_str
)
except ModuleNotFoundError:
DTensor = None # type: ignore[assignment,misc]
dtensorspec_format_shard_order_str = None
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.passes.shape_prop import TensorMetadata
@ -686,16 +677,6 @@ class CodeGen:
core = _tensor_annotation(meta_val)
if is_plain:
maybe_type_annotation = f': "{core}"'
elif type(meta_val) is DTensor:
assert dtensorspec_format_shard_order_str is not None
dtensor_meta = dtensorspec_format_shard_order_str(
meta_val._spec.placements, # type: ignore[attr-defined]
meta_val._spec.shard_order, # type: ignore[attr-defined]
)
cls = meta_val.__class__.__name__
maybe_type_annotation = (
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
)
else:
cls = meta_val.__class__.__name__
maybe_type_annotation = f': "{cls}({core})"'