mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-13 09:55:10 +08:00
Compare commits
1 Commits
ngimel/hos
...
update_sub
| Author | SHA1 | Date | |
|---|---|---|---|
| a09b0c26a7 |
@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
|
||||
- [Python Unit Testing](#python-unit-testing)
|
||||
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
|
||||
- [Local linting](#local-linting)
|
||||
- [Running `pyrefly`](#running-pyrefly)
|
||||
- [Running `mypy`](#running-mypy)
|
||||
- [C++ Unit Testing](#c-unit-testing)
|
||||
- [Run Specific CI Jobs](#run-specific-ci-jobs)
|
||||
- [Merging your Change](#merging-your-change)
|
||||
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
|
||||
**Prerequisites**:
|
||||
The following packages should be installed with `pip`:
|
||||
- `expecttest` and `hypothesis` - required to run tests
|
||||
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
|
||||
- `mypy` - recommended for linting
|
||||
- `pytest` - recommended to run tests more selectively
|
||||
Running
|
||||
```
|
||||
@ -350,32 +350,15 @@ make lint
|
||||
|
||||
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
|
||||
|
||||
#### Running `pyrefly`
|
||||
#### Running `mypy`
|
||||
|
||||
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
|
||||
|
||||
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
|
||||
|
||||
**Getting Started with Pyrefly:**
|
||||
|
||||
To run type checking on the PyTorch codebase:
|
||||
```bash
|
||||
pyrefly check
|
||||
```
|
||||
|
||||
For more detailed error information with summaries:
|
||||
```bash
|
||||
pyrefly check --summarize-errors
|
||||
```
|
||||
|
||||
**Learn More:**
|
||||
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
|
||||
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
|
||||
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
|
||||
`mypy` is an optional static type checker for Python. We have multiple `mypy`
|
||||
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
|
||||
|
||||
See [Guide for adding type annotations to
|
||||
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
|
||||
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
|
||||
for more information on how to set up `mypy` and tackle type annotation
|
||||
tasks.
|
||||
|
||||
### C++ Unit Testing
|
||||
|
||||
|
||||
@ -226,8 +226,8 @@ template <
|
||||
typename B = HostBlock<S>>
|
||||
struct CachingHostAllocatorImpl {
|
||||
virtual ~CachingHostAllocatorImpl() {
|
||||
if (active_) {
|
||||
active_ = false;
|
||||
active_ = false;
|
||||
if (pinned_use_background_threads()) {
|
||||
getBackgroundThreadPool()->waitWorkComplete();
|
||||
}
|
||||
}
|
||||
@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
|
||||
if (pinned_use_background_threads()) {
|
||||
// Launch the background thread and process events in a loop.
|
||||
static bool background_thread_flag [[maybe_unused]] = [this] {
|
||||
active_ = true;
|
||||
getBackgroundThreadPool()->run([&]() {
|
||||
while (active_) {
|
||||
process_events();
|
||||
@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
|
||||
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
|
||||
std::deque<std::pair<E, B*>> events_; // event queue paired with block
|
||||
|
||||
// Indicates whether the event-processing thread pool is active.
|
||||
// Indicates whether the object is active.
|
||||
// Set to false in the destructor to signal background threads to stop.
|
||||
std::atomic<bool> active_{false};
|
||||
std::atomic<bool> active_{true};
|
||||
protected:
|
||||
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
|
||||
};
|
||||
|
||||
@ -24,13 +24,7 @@ namespace detail {
|
||||
// radix_sort_pairs doesn't interact with value_t other than to copy
|
||||
// the data, so we can save template instantiations by reinterpreting
|
||||
// it as an opaque type.
|
||||
// We use native integer types for 1/2/4/8-byte values to reduce
|
||||
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
|
||||
template <int N> struct alignas(N) OpaqueType { char data[N]; };
|
||||
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
|
||||
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
|
||||
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
|
||||
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
|
||||
|
||||
template<typename key_t, int value_size>
|
||||
void radix_sort_pairs_impl(
|
||||
|
||||
@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
|
||||
});
|
||||
}
|
||||
|
||||
template <typename func_t, typename vec_func_t, typename ident_t = double>
|
||||
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
|
||||
template <typename func_t, typename vec_func_t>
|
||||
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
|
||||
using traits = binary_function_traits<func_t>;
|
||||
static_assert(
|
||||
all_same<
|
||||
|
||||
@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
struct MinValuesOps: public at::native::MinOps<scalar_t> {
|
||||
using arg_t = typename MinOps<scalar_t>::arg_t;
|
||||
static scalar_t project(arg_t arg) {
|
||||
return arg.first;
|
||||
}
|
||||
};
|
||||
|
||||
void min_values_kernel_impl(TensorIterator& iter) {
|
||||
// This case is special because of Vectorized<int64_t> does not
|
||||
// handle upper_bound<int64_t>().
|
||||
// See: https://github.com/pytorch/pytorch/issues/43254
|
||||
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
|
||||
binary_kernel_reduce(
|
||||
iter,
|
||||
MinValuesOps<scalar_t>{},
|
||||
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
|
||||
}), kLong, kUInt64);
|
||||
return;
|
||||
}
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
|
||||
binary_kernel_reduce_vec(
|
||||
iter,
|
||||
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
|
||||
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
|
||||
upper_bound<scalar_t>());
|
||||
static_cast<double>(upper_bound<scalar_t>()));
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
|
||||
@ -47,10 +47,20 @@ Tensor sgd_out_of_place(
|
||||
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
|
||||
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
|
||||
|
||||
// testing Tensor strides + stride
|
||||
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
|
||||
int64_t *param_sizes;
|
||||
int64_t *param_strides;
|
||||
aoti_torch_get_sizes(param.get(), ¶m_sizes);
|
||||
aoti_torch_get_strides(param.get(), ¶m_strides);
|
||||
|
||||
auto out = new_empty(param, param.sizes());
|
||||
int32_t param_dtype;
|
||||
aoti_torch_get_dtype(param.get(), ¶m_dtype);
|
||||
|
||||
int32_t param_device_type;
|
||||
aoti_torch_get_device_type(param.get(), ¶m_device_type);
|
||||
|
||||
AtenTensorHandle out_ath;
|
||||
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
|
||||
auto out = Tensor(out_ath);
|
||||
|
||||
sgd_math(
|
||||
reinterpret_cast<float*>(param.data_ptr()),
|
||||
@ -334,8 +344,6 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
// Still using a std::vector below even though people can just pass in an
|
||||
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
|
||||
// directly.
|
||||
// This is to test that passing in a std::vector works for BC. (It gets
|
||||
// implicitly converted to HeaderOnlyArrayRef too!)
|
||||
std::vector<int64_t> sizes = {2, 5};
|
||||
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
|
||||
return new_empty(t, sizes, dtype);
|
||||
|
||||
@ -5,16 +5,8 @@ import contextlib
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
distribute_tensor,
|
||||
DTensor,
|
||||
Partial,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@ -434,31 +426,6 @@ class TestDTensorDebugMode(TestCase):
|
||||
][-1]
|
||||
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
|
||||
|
||||
def test_pretty_print_dtensor_make_fx(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
A = torch.randn(8, 32)
|
||||
B = torch.randn(32, 32)
|
||||
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
|
||||
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
|
||||
|
||||
def f(dA, dB):
|
||||
dy = dA @ dB
|
||||
loss = dy.sum()
|
||||
loss.backward()
|
||||
return dA.grad, dB.grad
|
||||
|
||||
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
|
||||
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
|
||||
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
|
||||
gm = make_fx(f, tracing_mode="fake")(dA, dB)
|
||||
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
|
||||
gm.graph.eliminate_dead_code()
|
||||
gm.recompile()
|
||||
# Colored is nice for actual viewing, not using in this test though
|
||||
gm_str = gm.print_readable(colored=False, print_output=False)
|
||||
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
||||
|
||||
@ -5789,229 +5789,6 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
else:
|
||||
self.assertTrue("duration_ms" not in t["entries"][0])
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
|
||||
"""
|
||||
Test that when the circular buffer in entries_ is full and we call reset,
|
||||
then fill the buffer with new entries, dump_entries returns only the new
|
||||
entries and not the old ones.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill the buffer completely with 10 entries
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify buffer is full with 10 entries
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 10)
|
||||
|
||||
# Now reset the flight recorder
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Add new entries after reset - fill the buffer completely again
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify we get exactly 10 new entries, not 20
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 10)
|
||||
|
||||
# Verify all entries have the expected properties (from after reset)
|
||||
# After reset, record IDs should start from 0 again
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("profiling_name", entry)
|
||||
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
|
||||
self.assertIn("record_id", entry)
|
||||
# Record IDs should be sequential starting from 0 after reset
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
|
||||
"""
|
||||
Test that when the circular buffer is full, we reset, and then add fewer
|
||||
entries than the buffer size, we only get the new entries.
|
||||
This tests that old entries at the end of the circular buffer are properly
|
||||
filtered out based on reset_epoch.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill the buffer completely
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Reset the flight recorder
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Add only 3 new entries (much less than buffer size)
|
||||
for _ in range(3):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify we only get the 3 new entries, not 10
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 3)
|
||||
|
||||
# Verify record IDs start from 0 after reset
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_wraparound(self, timing_enabled):
|
||||
"""
|
||||
Test that when we reset in the middle of the circular buffer and then
|
||||
wrap around, dump_entries correctly returns only entries from the current
|
||||
epoch in the correct order.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill half the buffer
|
||||
for _ in range(5):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Reset at this point (reset happens at index 5)
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Now add 8 entries, which will wrap around
|
||||
# (5->9 fills rest of buffer, then 0->2 wraps around)
|
||||
for _ in range(8):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Should get exactly 8 entries, properly ordered
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 8)
|
||||
|
||||
# Entries should be in chronological order
|
||||
# The dump_entries() method returns entries from next_ to end, then 0 to next_
|
||||
# After filtering old entries, we should have 8 entries in order
|
||||
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("profiling_name", entry)
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_multiple_resets(self, timing_enabled):
|
||||
"""
|
||||
Test multiple consecutive resets to ensure each reset properly increments
|
||||
the epoch and filters out entries from previous epochs.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# First batch: 2 entries
|
||||
for _ in range(2):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# First reset
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Second batch: 3 entries
|
||||
for _ in range(3):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Second reset
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Third batch: 4 entries
|
||||
for _ in range(4):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Should only see the last 4 entries
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 4)
|
||||
|
||||
# Verify record IDs start from 0 after the last reset
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def check_if_test_is_skipped(fn):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
|
||||
@ -8,11 +8,21 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
|
||||
from torch._dynamo.graph_utils import _detect_cycles
|
||||
from torch._dynamo.output_graph import FakeRootModule
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
|
||||
from torch._dynamo.testing import (
|
||||
AotEagerAndRecordGraphs,
|
||||
extract_graph_and_tracker,
|
||||
normalize_gm,
|
||||
)
|
||||
from torch.compiler import allow_in_graph
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
def extract_graph(fn, *args, **kwargs):
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
|
||||
return result, backend.graphs, backend.fw_graphs
|
||||
|
||||
|
||||
def graph_str(gm):
|
||||
return normalize_gm(gm.print_readable(print_output=False))
|
||||
|
||||
@ -30,7 +40,7 @@ class GraphDededuplicationTests(TestCase):
|
||||
super().tearDown()
|
||||
|
||||
def run_and_return_graphs(self, fn, *args, **kwargs):
|
||||
return extract_graph(fn, *args, **kwargs)[0:3]
|
||||
return extract_graph(fn, *args, **kwargs)
|
||||
|
||||
def run_and_get_simple_graph(self):
|
||||
def fn(x, y):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import unittest
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
from typing import NamedTuple, Optional, TYPE_CHECKING
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
@ -7,10 +7,6 @@ from torch._dynamo.test_case import run_tests, TestCase
|
||||
from torch._dynamo.testing import CompileCounter, same
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
"""
|
||||
This is an example of a pure-python version of autograd implemented by
|
||||
@zdevito. It represents a rather challenging test case for TorchDynamo
|
||||
|
||||
@ -1,13 +1,11 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import functools
|
||||
import re
|
||||
import unittest
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo.testing import extract_graph, remove_trailing_space
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_utils import requires_cuda
|
||||
|
||||
@ -17,14 +15,6 @@ requires_multigpu = functools.partial(
|
||||
)
|
||||
|
||||
|
||||
def remove_file_comment(gm_str: str) -> str:
|
||||
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
|
||||
|
||||
|
||||
def print_graph(graph: torch.fx.GraphModule) -> str:
|
||||
return remove_file_comment(graph.print_readable())
|
||||
|
||||
|
||||
class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -46,7 +36,9 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_enter_exit(self):
|
||||
def fn(x, y, s1, s2):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
with s1:
|
||||
z1 = torch.add(x, y)
|
||||
with s2:
|
||||
@ -55,36 +47,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
|
||||
return y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
fn_opt = torch.compile(fn, fullgraph=True)
|
||||
actual = fn_opt(*inp)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': None}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
|
||||
return (add_3,)
|
||||
""",
|
||||
)
|
||||
|
||||
@requires_cuda
|
||||
@unittest.skip("Needs graph break support with annotation context")
|
||||
def test_stream_context_graph_break(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
@ -101,16 +70,9 @@ class <lambda>(torch.nn.Module):
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
fn_opt = torch.compile(fn)
|
||||
actual = fn_opt(*inp)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertEqual(len(fw_graphs), 2)
|
||||
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
|
||||
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
|
||||
|
||||
@requires_cuda
|
||||
def test_stream_input(self):
|
||||
@ -193,248 +155,22 @@ class <lambda>(torch.nn.Module):
|
||||
self.assertEqual(s_act, s_exp)
|
||||
|
||||
def test_nested_stream_enter_exit(self):
|
||||
def fn(x, y, s0, s1, s2):
|
||||
with s1:
|
||||
with s2:
|
||||
z1 = torch.add(x, y)
|
||||
with s0:
|
||||
z0 = torch.add(x, y)
|
||||
with s2:
|
||||
y = 2 + z1
|
||||
pass
|
||||
|
||||
return z0, y
|
||||
|
||||
inp = (
|
||||
torch.ones(2, 2) + 1,
|
||||
torch.ones(2, 2),
|
||||
torch.Stream(),
|
||||
torch.Stream(),
|
||||
torch.Stream(),
|
||||
)
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': None}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': None}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
|
||||
return (add_1, add_2)
|
||||
""",
|
||||
)
|
||||
|
||||
@unittest.skip("Needs graph break support with annotation context")
|
||||
def test_stream_enter_exit_graph_break(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Needs graph break support with annotation context")
|
||||
def test_nested_stream_enter_exit_graph_break(self):
|
||||
pass
|
||||
|
||||
def test_local_stream_enter_exit(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
with s1:
|
||||
z1 = torch.add(x, y)
|
||||
with s2:
|
||||
z = torch.add(x, y)
|
||||
y = z + 2 + z1
|
||||
|
||||
return y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 1}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
|
||||
return (add_3,)
|
||||
""",
|
||||
)
|
||||
pass
|
||||
|
||||
def test_local_stream_nested_enter_exit(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
with s1:
|
||||
with s2:
|
||||
z1 = torch.add(x, y)
|
||||
with s0:
|
||||
z0 = torch.add(x, y)
|
||||
with s2:
|
||||
y = 2 + z1
|
||||
|
||||
return z0, y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 0}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': 2}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
|
||||
return (add_1, add_2)
|
||||
""",
|
||||
)
|
||||
pass
|
||||
|
||||
def test_stream_with_mutation(self):
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s1 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
with s1:
|
||||
with s2:
|
||||
x.add_(y)
|
||||
with s0:
|
||||
z1 = torch.add(y, y)
|
||||
z0 = torch.add(z1, y)
|
||||
with s2:
|
||||
y = 2 + z1
|
||||
|
||||
return z0, y
|
||||
|
||||
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
_,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 0}
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': 2}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
|
||||
|
||||
# Annotation: {'stream': 2}
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
|
||||
|
||||
#
|
||||
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
|
||||
return (add_2, add_3)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_stream_backward(self) -> None:
|
||||
def fn(x, y):
|
||||
s2 = torch.Stream()
|
||||
s0 = torch.Stream()
|
||||
with s0:
|
||||
y0 = 2 * x + y
|
||||
with s2:
|
||||
z = 2 * x + y
|
||||
|
||||
return y0, z
|
||||
|
||||
inp = (
|
||||
torch.ones(2, 2, requires_grad=True) + 1,
|
||||
torch.ones(2, 2, requires_grad=True),
|
||||
)
|
||||
expected = fn(*inp)
|
||||
(
|
||||
actual,
|
||||
_,
|
||||
fw_graphs,
|
||||
bw_graphs,
|
||||
) = extract_graph(fn, *inp)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertEqual(expected, actual)
|
||||
self.assertExpectedInline(
|
||||
print_graph(fw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 1}
|
||||
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
|
||||
|
||||
# Annotation: {'stream': 0}
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
|
||||
return (add, add_1)
|
||||
""",
|
||||
)
|
||||
|
||||
actual[1].sum().backward()
|
||||
self.assertExpectedInline(
|
||||
print_graph(bw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
|
||||
# Annotation: {'stream': 0}
|
||||
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
|
||||
|
||||
#
|
||||
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
|
||||
|
||||
# Annotation: {'stream': 1}
|
||||
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
|
||||
|
||||
#
|
||||
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
|
||||
return (add_3, add_2)
|
||||
""",
|
||||
)
|
||||
pass
|
||||
|
||||
@requires_cuda
|
||||
def test_run_opcheck(self):
|
||||
|
||||
@ -14424,6 +14424,20 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
|
||||
self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))
|
||||
|
||||
@skip_if_halide
|
||||
@requires_cuda_and_triton
|
||||
def test_unbacked_float_item(self):
|
||||
def fn(x, max_val):
|
||||
return torch.clamp(x, 0, max_val.item())
|
||||
|
||||
self.common(
|
||||
fn,
|
||||
(
|
||||
torch.randn(10, 20, 30, device=self.device),
|
||||
torch.tensor(5.0, device=self.device),
|
||||
),
|
||||
)
|
||||
|
||||
# end of class CommonTemplate - add new tests here
|
||||
|
||||
|
||||
|
||||
@ -1864,8 +1864,6 @@ class TestFP8Matmul(TestCase):
|
||||
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
||||
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
|
||||
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
|
||||
if torch.version.hip and recipe == "nvfp4":
|
||||
raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
|
||||
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
|
||||
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
|
||||
|
||||
@ -1914,7 +1914,6 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, is_causal=True))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
|
||||
batch_size = 2**16
|
||||
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
|
||||
@ -1936,7 +1935,6 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
|
||||
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
|
||||
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
|
||||
@ -1950,7 +1948,6 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
|
||||
@largeTensorTest("15GB", "cuda")
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
|
||||
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from typing import TypeAlias, Union
|
||||
from typing_extensions import assert_type
|
||||
from typing import Union
|
||||
from typing_extensions import assert_type, TypeAlias
|
||||
|
||||
from torch import randn, Tensor
|
||||
|
||||
|
||||
2
third_party/tensorpipe
vendored
2
third_party/tensorpipe
vendored
Submodule third_party/tensorpipe updated: af0118d13e...2b4cd91092
@ -1,9 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# mypy: disable-error-code="type-arg"
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, overload, Union
|
||||
from typing import Any, Callable, Optional, overload, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
@ -3320,7 +3320,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, SetVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
|
||||
obj.call_method(self, "add", [v], {})
|
||||
|
||||
def SET_UPDATE(self, inst: Instruction) -> None:
|
||||
v = self.pop()
|
||||
@ -3329,7 +3329,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg]
|
||||
assert isinstance(obj, SetVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
|
||||
obj.call_method(self, "update", [v], {})
|
||||
|
||||
def LIST_APPEND(self, inst: Instruction) -> None:
|
||||
v = self.pop()
|
||||
@ -3637,7 +3637,7 @@ class InstructionTranslatorBase(
|
||||
obj = self.stack[-inst.arg].realize()
|
||||
assert isinstance(obj, ConstDictVariable)
|
||||
assert obj.is_mutable()
|
||||
obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
|
||||
obj.call_method(self, "update", [v], {})
|
||||
|
||||
DICT_UPDATE = DICT_MERGE
|
||||
|
||||
|
||||
@ -87,12 +87,6 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
|
||||
return gm.graph, region_tracker # type: ignore[union-attr]
|
||||
|
||||
|
||||
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
|
||||
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
|
||||
|
||||
|
||||
def collect_results(
|
||||
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
|
||||
) -> list[Any]:
|
||||
|
||||
@ -21,9 +21,9 @@ restoring state changes.
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence, Sized
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
|
||||
|
||||
import torch._C
|
||||
from torch._guards import Guard
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
"""
|
||||
Dictionary-related variable tracking classes for PyTorch Dynamo.
|
||||
|
||||
@ -24,7 +26,7 @@ import inspect
|
||||
import operator
|
||||
import types
|
||||
from collections.abc import Hashable as py_Hashable
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
|
||||
@ -57,13 +59,11 @@ if TYPE_CHECKING:
|
||||
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
|
||||
|
||||
|
||||
def was_instancecheck_override(obj: Any) -> bool:
|
||||
def was_instancecheck_override(obj):
|
||||
return type(obj).__dict__.get("__instancecheck__", False)
|
||||
|
||||
|
||||
def raise_unhashable(
|
||||
arg: VariableTracker, tx: Optional["InstructionTranslator"] = None
|
||||
) -> None:
|
||||
def raise_unhashable(arg, tx=None):
|
||||
if tx is None:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
@ -75,7 +75,7 @@ def raise_unhashable(
|
||||
)
|
||||
|
||||
|
||||
def is_hashable(x: VariableTracker) -> bool:
|
||||
def is_hashable(x):
|
||||
# NB - performing isinstance check on a LazVT realizes the VT, accidentally
|
||||
# inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at
|
||||
# the underlying value without realizing the VT. Consider updating the
|
||||
@ -143,7 +143,7 @@ class ConstDictVariable(VariableTracker):
|
||||
Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
|
||||
"""
|
||||
|
||||
def __init__(self, vt: VariableTracker) -> None:
|
||||
def __init__(self, vt) -> None:
|
||||
# We specialize SymNodes
|
||||
vt = specialize_symnode(vt)
|
||||
# TODO Temporarily remove to figure out what keys are we breaking on
|
||||
@ -153,7 +153,7 @@ class ConstDictVariable(VariableTracker):
|
||||
self.vt = vt
|
||||
|
||||
@property
|
||||
def underlying_value(self) -> Any:
|
||||
def underlying_value(self):
|
||||
if (
|
||||
isinstance(self.vt, variables.LazyVariableTracker)
|
||||
and not self.vt.is_realized()
|
||||
@ -178,8 +178,7 @@ class ConstDictVariable(VariableTracker):
|
||||
elif isinstance(self.vt, variables.FrozenDataClassVariable):
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
fields_values = {
|
||||
k: Hashable(v).underlying_value
|
||||
for k, v in self.vt.fields.items() # type: ignore[attr-defined]
|
||||
k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
|
||||
}
|
||||
return variables.FrozenDataClassVariable.HashWrapper(
|
||||
self.vt.python_type(), fields_values
|
||||
@ -188,16 +187,16 @@ class ConstDictVariable(VariableTracker):
|
||||
# The re module in Python 3.13+ has a dictionary (_cache2) with
|
||||
# an object as key (`class _ZeroSentinel(int): ...`):
|
||||
# python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
|
||||
return self.vt.value # type: ignore[attr-defined,union-attr]
|
||||
return self.vt.value
|
||||
else:
|
||||
x = self.vt.as_python_constant()
|
||||
return x
|
||||
|
||||
def __hash__(self) -> int:
|
||||
def __hash__(self):
|
||||
return hash(self.underlying_value)
|
||||
|
||||
@staticmethod
|
||||
def _eq_impl(a: Any, b: Any) -> bool:
|
||||
def _eq_impl(a, b):
|
||||
# TODO: Put this in utils and share it between variables/builtin.py and here
|
||||
type_a, type_b = type(a), type(b)
|
||||
if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)):
|
||||
@ -213,7 +212,7 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return a == b
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
|
||||
type(other)
|
||||
@ -227,8 +226,8 @@ class ConstDictVariable(VariableTracker):
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls: type = dict,
|
||||
**kwargs: Any,
|
||||
user_cls=dict,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# .clone() pass these arguments in kwargs but they're recreated a few
|
||||
# lines below
|
||||
@ -248,22 +247,18 @@ class ConstDictVariable(VariableTracker):
|
||||
for x, v in items.items()
|
||||
)
|
||||
|
||||
def make_hashable(
|
||||
key: Union[VariableTracker, "ConstDictVariable._HashableTracker"],
|
||||
) -> "ConstDictVariable._HashableTracker":
|
||||
def make_hashable(key):
|
||||
return key if isinstance(key, Hashable) else Hashable(key)
|
||||
|
||||
dict_cls = self._get_dict_cls_from_user_cls(user_cls)
|
||||
self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
|
||||
# need to reconstruct everything if the dictionary is an intermediate value
|
||||
# or if a pop/delitem was executed
|
||||
self.should_reconstruct_all = (
|
||||
not is_from_local_source(self.source) if self.source else True
|
||||
)
|
||||
self.should_reconstruct_all = not is_from_local_source(self.source)
|
||||
self.original_items = items.copy()
|
||||
self.user_cls = user_cls
|
||||
|
||||
def _get_dict_cls_from_user_cls(self, user_cls: type) -> type:
|
||||
def _get_dict_cls_from_user_cls(self, user_cls):
|
||||
accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
|
||||
|
||||
# avoid executing user code if user_cls is a dict subclass
|
||||
@ -282,10 +277,10 @@ class ConstDictVariable(VariableTracker):
|
||||
dict_cls = dict
|
||||
return dict_cls
|
||||
|
||||
def as_proxy(self) -> dict[Any, Any]:
|
||||
def as_proxy(self):
|
||||
return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
def debug_repr(self):
|
||||
return (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
@ -294,20 +289,20 @@ class ConstDictVariable(VariableTracker):
|
||||
+ "}"
|
||||
)
|
||||
|
||||
def as_python_constant(self) -> dict[Any, Any]:
|
||||
def as_python_constant(self):
|
||||
return {
|
||||
k.vt.as_python_constant(): v.as_python_constant()
|
||||
for k, v in self.items.items()
|
||||
}
|
||||
|
||||
def keys_as_python_constant(self) -> dict[Any, VariableTracker]:
|
||||
def keys_as_python_constant(self):
|
||||
self.install_dict_keys_match_guard()
|
||||
return {k.vt.as_python_constant(): v for k, v in self.items.items()}
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return self.user_cls
|
||||
|
||||
def __contains__(self, vt: VariableTracker) -> bool:
|
||||
def __contains__(self, vt) -> bool:
|
||||
assert isinstance(vt, VariableTracker)
|
||||
Hashable = ConstDictVariable._HashableTracker
|
||||
return (
|
||||
@ -327,15 +322,13 @@ class ConstDictVariable(VariableTracker):
|
||||
for key, value in self.items.items()
|
||||
)
|
||||
|
||||
def is_new_item(
|
||||
self, value: Optional[VariableTracker], other: VariableTracker
|
||||
) -> bool:
|
||||
def is_new_item(self, value, other):
|
||||
# compare the id of the realized values if both values are not lazy VTs
|
||||
if value and value.is_realized() and other.is_realized():
|
||||
return id(value.realize()) != id(other.realize())
|
||||
return id(value) != id(other)
|
||||
|
||||
def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None:
|
||||
def reconstruct_kvs_into_new_dict(self, codegen):
|
||||
# Build a dictionary that contains the keys and values.
|
||||
num_args = 0
|
||||
for key, value in self.items.items():
|
||||
@ -347,7 +340,7 @@ class ConstDictVariable(VariableTracker):
|
||||
num_args += 1
|
||||
codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
if self.user_cls is collections.OrderedDict:
|
||||
# emit `OrderedDict(constructed_dict)`
|
||||
codegen.add_push_null(
|
||||
@ -365,21 +358,19 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def getitem_const_raise_exception_if_absent(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
):
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
raise_observed_exception(KeyError, tx)
|
||||
return self.items[key]
|
||||
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined]
|
||||
msg = f"Dictionary key {arg.value} not found during tracing"
|
||||
unimplemented_v2(
|
||||
gb_type="key not found in dict",
|
||||
context=f"Key {arg.value}", # type: ignore[attr-defined]
|
||||
context=f"Key {arg.value}",
|
||||
explanation=msg,
|
||||
hints=[
|
||||
"Check if the key exists in the dictionary before accessing it.",
|
||||
@ -388,13 +379,13 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
return self.items[key]
|
||||
|
||||
def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]:
|
||||
def maybe_getitem_const(self, arg: VariableTracker):
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
if key not in self.items:
|
||||
return None
|
||||
return self.items[key]
|
||||
|
||||
def realize_key_vt(self, arg: VariableTracker) -> None:
|
||||
def realize_key_vt(self, arg: VariableTracker):
|
||||
# Realize the LazyVT on a particular index
|
||||
assert arg in self
|
||||
key = ConstDictVariable._HashableTracker(arg)
|
||||
@ -403,13 +394,11 @@ class ConstDictVariable(VariableTracker):
|
||||
if isinstance(original_key_vt, variables.LazyVariableTracker):
|
||||
original_key_vt.realize()
|
||||
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
def install_dict_keys_match_guard(self):
|
||||
if self.source:
|
||||
install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
|
||||
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
# Key guarding - These are the cases to consider
|
||||
# 1) The dict has been mutated. In this case, we would have already
|
||||
# inserted a DICT_KEYS_MATCH guard, so we can skip.
|
||||
@ -450,11 +439,11 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
# NB - Both key and value are LazyVariableTrackers in the beginning. So,
|
||||
# we have to insert guards when a dict method is accessed. For this to
|
||||
# be simple, we are conservative and overguard. We skip guard only for
|
||||
@ -473,7 +462,7 @@ class ConstDictVariable(VariableTracker):
|
||||
tx, *args, **kwargs
|
||||
)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.update(temp_dict_vt.items) # type: ignore[attr-defined]
|
||||
self.items.update(temp_dict_vt.items)
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "__getitem__":
|
||||
# Key guarding - Nothing to do. LazyVT for value will take care.
|
||||
@ -537,7 +526,7 @@ class ConstDictVariable(VariableTracker):
|
||||
return ConstantVariable.create(len(self.items))
|
||||
elif name == "__setitem__" and self.is_mutable():
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0], tx)
|
||||
raise_unhashable(args[0])
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) != 2:
|
||||
@ -561,7 +550,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0], tx)
|
||||
raise_unhashable(args[0])
|
||||
|
||||
if args[0] not in self:
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
@ -576,7 +565,7 @@ class ConstDictVariable(VariableTracker):
|
||||
raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args")
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0], tx)
|
||||
raise_unhashable(args[0])
|
||||
|
||||
if args[0] not in self:
|
||||
# missing item, return the default value. Install no DICT_CONTAINS guard.
|
||||
@ -610,7 +599,7 @@ class ConstDictVariable(VariableTracker):
|
||||
last = v.value
|
||||
else:
|
||||
raise_args_mismatch(tx, name)
|
||||
k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined]
|
||||
k, v = self.items.popitem(last=last)
|
||||
else:
|
||||
k, v = self.items.popitem()
|
||||
|
||||
@ -643,17 +632,17 @@ class ConstDictVariable(VariableTracker):
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard()
|
||||
dict_vt: ConstDictVariable = args[0]
|
||||
dict_vt = args[0]
|
||||
else:
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment]
|
||||
self.items.update(dict_vt.items) # type: ignore[attr-defined]
|
||||
dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
|
||||
self.items.update(dict_vt.items)
|
||||
if has_kwargs:
|
||||
# Handle kwargs
|
||||
kwargs_hashable = {
|
||||
kwargs = {
|
||||
Hashable(ConstantVariable.create(k)): v
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
self.items.update(kwargs_hashable)
|
||||
self.items.update(kwargs)
|
||||
return ConstantVariable.create(None)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -667,7 +656,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0], tx)
|
||||
raise_unhashable(args[0])
|
||||
|
||||
self.install_dict_contains_guard(tx, args)
|
||||
contains = args[0] in self
|
||||
@ -682,7 +671,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
|
||||
if not arg_hashable:
|
||||
raise_unhashable(args[0], tx)
|
||||
raise_unhashable(args[0])
|
||||
|
||||
self.install_dict_keys_match_guard()
|
||||
if kwargs or len(args) > 2:
|
||||
@ -718,7 +707,7 @@ class ConstDictVariable(VariableTracker):
|
||||
and "last" in kwargs
|
||||
and isinstance(kwargs["last"], ConstantVariable)
|
||||
):
|
||||
last = kwargs.get("last").value # type: ignore[union-attr]
|
||||
last = kwargs.get("last").value
|
||||
|
||||
key = Hashable(args[0])
|
||||
self.items.move_to_end(key, last=last)
|
||||
@ -734,7 +723,7 @@ class ConstDictVariable(VariableTracker):
|
||||
)
|
||||
elif name == "__ne__":
|
||||
return ConstantVariable.create(
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined]
|
||||
not self.call_method(tx, "__eq__", args, kwargs).value
|
||||
)
|
||||
elif name == "__or__":
|
||||
if len(args) != 1:
|
||||
@ -761,14 +750,14 @@ class ConstDictVariable(VariableTracker):
|
||||
if not istype(
|
||||
other, (ConstDictVariable, variables.UserDefinedDictVariable)
|
||||
):
|
||||
err_msg = (
|
||||
msg = (
|
||||
f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
|
||||
f"and '{other.python_type().__name__}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[err_msg])
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
|
||||
# OrderedDict overloads __ror__
|
||||
ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined]
|
||||
ts = {self.user_cls, other.user_cls}
|
||||
user_cls = (
|
||||
collections.OrderedDict
|
||||
if any(issubclass(t, collections.OrderedDict) for t in ts)
|
||||
@ -785,8 +774,8 @@ class ConstDictVariable(VariableTracker):
|
||||
|
||||
# NB - Guard on all the keys of the other dict to ensure
|
||||
# correctness.
|
||||
args[0].install_dict_keys_match_guard() # type: ignore[attr-defined]
|
||||
new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined]
|
||||
args[0].install_dict_keys_match_guard()
|
||||
new_dict_vt.items.update(args[0].items)
|
||||
return new_dict_vt
|
||||
elif name == "__ior__":
|
||||
self.call_method(tx, "update", args, kwargs)
|
||||
@ -800,13 +789,11 @@ class ConstDictVariable(VariableTracker):
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
def unpack_var_sequence(self, tx):
|
||||
self.install_dict_keys_match_guard()
|
||||
return [x.vt for x in self.items.keys()]
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
# dict not allow setting arbitrary attributes. OrderedDict and
|
||||
# defaultdict allow arbitrary setattr, but not deletion of default attrs
|
||||
if any(
|
||||
@ -829,25 +816,25 @@ class ConstDictVariable(VariableTracker):
|
||||
],
|
||||
)
|
||||
|
||||
def clone(self, **kwargs: Any) -> VariableTracker:
|
||||
def clone(self, **kwargs):
|
||||
self.install_dict_keys_match_guard()
|
||||
return super().clone(**kwargs)
|
||||
|
||||
|
||||
class MappingProxyVariable(VariableTracker):
|
||||
# proxies to the original dict_vt
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return types.MappingProxyType
|
||||
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
def unpack_var_sequence(self, tx):
|
||||
return self.dv_dict.unpack_var_sequence(tx)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
# load types.MappingProxyType
|
||||
if self.source:
|
||||
msg = (
|
||||
@ -876,11 +863,11 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
if self.source and tx.output.side_effects.has_existing_dict_mutation():
|
||||
msg = (
|
||||
"A dict has been modified while we have an existing mappingproxy object. "
|
||||
@ -905,7 +892,7 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
) -> "VariableTracker":
|
||||
if self.python_type() is types.MappingProxyType:
|
||||
return ConstantVariable.create(name in types.MappingProxyType.__dict__)
|
||||
return super().call_obj_hasattr(tx, name)
|
||||
@ -913,44 +900,35 @@ class MappingProxyVariable(VariableTracker):
|
||||
|
||||
class NNModuleHooksDictVariable(ConstDictVariable):
|
||||
# Special class to avoid adding any guards on the nn module hook ids.
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
def install_dict_keys_match_guard(self):
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
pass
|
||||
|
||||
|
||||
class DefaultDictVariable(ConstDictVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: dict[VariableTracker, VariableTracker],
|
||||
user_cls: type,
|
||||
default_factory: Optional[VariableTracker] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
|
||||
super().__init__(items, user_cls, **kwargs)
|
||||
assert user_cls is collections.defaultdict
|
||||
if default_factory is None:
|
||||
default_factory = ConstantVariable.create(None)
|
||||
self.default_factory = default_factory
|
||||
|
||||
def is_python_constant(self) -> bool:
|
||||
def is_python_constant(self):
|
||||
# Return false for unsupported defaults. This ensures that a bad handler
|
||||
# path is not taken in BuiltinVariable for getitem.
|
||||
if self.default_factory not in [list, tuple, dict] and not self.items:
|
||||
return False
|
||||
return super().is_python_constant()
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
assert self.default_factory is not None
|
||||
def debug_repr(self):
|
||||
return (
|
||||
f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_supported_arg(arg: VariableTracker) -> bool:
|
||||
def is_supported_arg(arg):
|
||||
if isinstance(arg, variables.BuiltinVariable):
|
||||
return arg.fn in (list, tuple, dict, set)
|
||||
else:
|
||||
@ -964,11 +942,11 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
if name == "__getitem__":
|
||||
if len(args) != 1:
|
||||
raise_args_mismatch(tx, name, "1 args", f"{len(args)} args")
|
||||
@ -984,13 +962,13 @@ class DefaultDictVariable(ConstDictVariable):
|
||||
else:
|
||||
default_var = self.default_factory.call_function(tx, [], {})
|
||||
super().call_method(
|
||||
tx, "__setitem__", [args[0], default_var], kwargs
|
||||
tx, "__setitem__", (args[0], default_var), kwargs
|
||||
)
|
||||
return default_var
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
def reconstruct(self, codegen):
|
||||
# emit `defaultdict(default_factory, new_dict)`
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -1016,48 +994,40 @@ class SetVariable(ConstDictVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# pyrefly: ignore[bad-assignment]
|
||||
items = dict.fromkeys(items, SetVariable._default_value())
|
||||
# pyrefly: ignore[bad-argument-type]
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
def debug_repr(self):
|
||||
if not self.items:
|
||||
return "set()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
def set_items(self):
|
||||
return set(self.items.keys())
|
||||
|
||||
@staticmethod
|
||||
def _default_value() -> VariableTracker:
|
||||
def _default_value():
|
||||
# Variable to fill in he keys of the dictionary
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
def as_proxy(self) -> Any:
|
||||
def as_proxy(self):
|
||||
return {k.vt.as_proxy() for k in self.set_items}
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return set
|
||||
|
||||
def as_python_constant(self) -> Any:
|
||||
def as_python_constant(self):
|
||||
return {k.vt.as_python_constant() for k in self.set_items}
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
|
||||
|
||||
def _fast_set_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
fn: Any,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
def _fast_set_method(self, tx, fn, args, kwargs):
|
||||
try:
|
||||
res = fn(
|
||||
*[x.as_python_constant() for x in [self, *args]],
|
||||
@ -1067,16 +1037,15 @@ class SetVariable(ConstDictVariable):
|
||||
raise_observed_exception(
|
||||
type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
|
||||
)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
return VariableTracker.build(tx, res)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
tx,
|
||||
name,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
) -> "VariableTracker":
|
||||
# We forward the calls to the dictionary model
|
||||
from ..utils import check_constant_args
|
||||
|
||||
@ -1096,10 +1065,10 @@ class SetVariable(ConstDictVariable):
|
||||
return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
|
||||
|
||||
if name == "__init__":
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs)
|
||||
temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items.clear()
|
||||
self.items.update(temp_set_vt.items) # type: ignore[attr-defined]
|
||||
self.items.update(temp_set_vt.items)
|
||||
return ConstantVariable.create(None)
|
||||
elif name == "add":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1110,7 +1079,7 @@ class SetVariable(ConstDictVariable):
|
||||
f"{len(args)} args and {len(kwargs)} kwargs",
|
||||
)
|
||||
name = "__setitem__"
|
||||
args = [args[0], SetVariable._default_value()]
|
||||
args = (args[0], SetVariable._default_value())
|
||||
elif name == "pop":
|
||||
if kwargs or args:
|
||||
raise_args_mismatch(
|
||||
@ -1121,14 +1090,12 @@ class SetVariable(ConstDictVariable):
|
||||
)
|
||||
# Choose an item at random and pop it via the Dict.pop method
|
||||
try:
|
||||
result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment]
|
||||
result = self.set_items.pop().vt
|
||||
except KeyError as e:
|
||||
raise_observed_exception(
|
||||
KeyError, tx, args=list(map(ConstantVariable.create, e.args))
|
||||
)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
super().call_method(tx, name, [result], kwargs)
|
||||
# pyrefly: ignore[unbound-name]
|
||||
super().call_method(tx, name, (result,), kwargs)
|
||||
return result
|
||||
elif name == "isdisjoint":
|
||||
if kwargs or len(args) != 1:
|
||||
@ -1250,7 +1217,6 @@ class SetVariable(ConstDictVariable):
|
||||
f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
|
||||
)
|
||||
raise_observed_exception(TypeError, tx, args=[msg])
|
||||
assert m is not None
|
||||
return self.call_method(tx, m, args, kwargs)
|
||||
elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
@ -1264,34 +1230,29 @@ class SetVariable(ConstDictVariable):
|
||||
"__ixor__": "symmetric_difference_update",
|
||||
"__isub__": "difference_update",
|
||||
}.get(name)
|
||||
assert m is not None
|
||||
self.call_method(tx, m, args, kwargs)
|
||||
return self
|
||||
elif name == "__eq__":
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(False)
|
||||
r = self.call_method(tx, "symmetric_difference", args, kwargs)
|
||||
return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined]
|
||||
return ConstantVariable.create(len(r.set_items) == 0)
|
||||
elif name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def getitem_const(
|
||||
self, tx: "InstructionTranslator", arg: VariableTracker
|
||||
) -> VariableTracker:
|
||||
def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
|
||||
raise RuntimeError("Illegal to getitem on a set")
|
||||
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
def install_dict_keys_match_guard(self):
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
super().install_dict_contains_guard(tx, args)
|
||||
|
||||
|
||||
@ -1299,27 +1260,27 @@ class FrozensetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
def debug_repr(self):
|
||||
if not self.items:
|
||||
return "frozenset()"
|
||||
else:
|
||||
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
|
||||
|
||||
@property
|
||||
def set_items(self) -> set["ConstDictVariable._HashableTracker"]:
|
||||
def set_items(self):
|
||||
return self.items.keys()
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return frozenset
|
||||
|
||||
def as_python_constant(self) -> Any:
|
||||
def as_python_constant(self):
|
||||
return frozenset({k.vt.as_python_constant() for k in self.set_items})
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen.foreach([x.vt for x in self.set_items])
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
@ -1332,11 +1293,11 @@ class FrozensetVariable(SetVariable):
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
tx,
|
||||
name,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
) -> "VariableTracker":
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
|
||||
elif name == "__init__":
|
||||
@ -1355,7 +1316,7 @@ class FrozensetVariable(SetVariable):
|
||||
"symmetric_difference",
|
||||
):
|
||||
r = super().call_method(tx, name, args, kwargs)
|
||||
return FrozensetVariable(r.items) # type: ignore[attr-defined]
|
||||
return FrozensetVariable(r.items)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
@ -1363,11 +1324,11 @@ class DictKeySetVariable(SetVariable):
|
||||
def __init__(
|
||||
self,
|
||||
items: list[VariableTracker],
|
||||
**kwargs: Any,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def debug_repr(self) -> str:
|
||||
def debug_repr(self):
|
||||
if not self.items:
|
||||
return "dict_keys([])"
|
||||
else:
|
||||
@ -1377,35 +1338,33 @@ class DictKeySetVariable(SetVariable):
|
||||
+ "])"
|
||||
)
|
||||
|
||||
def install_dict_keys_match_guard(self) -> None:
|
||||
def install_dict_keys_match_guard(self):
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
def install_dict_contains_guard(
|
||||
self, tx: "InstructionTranslator", args: list[VariableTracker]
|
||||
) -> None:
|
||||
def install_dict_contains_guard(self, tx, args):
|
||||
# Already EQUALS_MATCH guarded
|
||||
pass
|
||||
|
||||
@property
|
||||
def set_items(self) -> Any:
|
||||
def set_items(self):
|
||||
return self.items
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return dict_keys
|
||||
|
||||
def as_python_constant(self) -> Any:
|
||||
def as_python_constant(self):
|
||||
return dict.fromkeys(
|
||||
{k.vt.as_python_constant() for k in self.set_items}, None
|
||||
).keys()
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
tx,
|
||||
name,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
) -> "VariableTracker":
|
||||
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
|
||||
raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
@ -1420,47 +1379,42 @@ class DictViewVariable(VariableTracker):
|
||||
|
||||
kv: Optional[str] = None
|
||||
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None:
|
||||
def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
assert self.kv in ("keys", "values", "items")
|
||||
assert isinstance(dv_dict, ConstDictVariable)
|
||||
self.dv_dict = dv_dict
|
||||
|
||||
@property
|
||||
def view_items(self) -> Any:
|
||||
assert self.kv is not None
|
||||
def view_items(self):
|
||||
return getattr(self.dv_dict.items, self.kv)()
|
||||
|
||||
@property
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
def view_items_vt(self):
|
||||
# Returns an iterable of the unpacked items
|
||||
# Implement in the subclasses
|
||||
raise NotImplementedError
|
||||
|
||||
def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]:
|
||||
def unpack_var_sequence(self, tx):
|
||||
return self.view_items_vt
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert self.kv is not None
|
||||
def reconstruct(self, codegen: "PyCodegen"):
|
||||
codegen(self.dv_dict)
|
||||
codegen.load_method(self.kv)
|
||||
codegen.call_method(0)
|
||||
|
||||
def call_obj_hasattr(
|
||||
self, tx: "InstructionTranslator", name: str
|
||||
) -> VariableTracker:
|
||||
assert self.kv is not None
|
||||
def call_obj_hasattr(self, tx, name):
|
||||
if name in self.python_type().__dict__:
|
||||
return ConstantVariable.create(True)
|
||||
return ConstantVariable.create(False)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
if name == "__len__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name == "__iter__":
|
||||
@ -1474,24 +1428,24 @@ class DictKeysVariable(DictViewVariable):
|
||||
kv = "keys"
|
||||
|
||||
@property
|
||||
def set_items(self) -> set[VariableTracker]:
|
||||
def set_items(self):
|
||||
return set(self.view_items)
|
||||
|
||||
@property
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
def view_items_vt(self):
|
||||
# Returns an iterable of the unpacked items
|
||||
return [x.vt for x in self.view_items]
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return dict_keys
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
if name == "__contains__":
|
||||
return self.dv_dict.call_method(tx, name, args, kwargs)
|
||||
elif name in (
|
||||
@ -1506,13 +1460,13 @@ class DictKeysVariable(DictViewVariable):
|
||||
):
|
||||
# These methods always return a set
|
||||
m = getattr(self.set_items, name)
|
||||
r = m(args[0].set_items) # type: ignore[attr-defined]
|
||||
r = m(args[0].set_items)
|
||||
return SetVariable(r)
|
||||
if name in cmp_name_to_op_mapping:
|
||||
if not isinstance(args[0], (SetVariable, DictKeysVariable)):
|
||||
return ConstantVariable.create(NotImplemented)
|
||||
return ConstantVariable.create(
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined]
|
||||
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
|
||||
)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
@ -1522,10 +1476,10 @@ class DictValuesVariable(DictViewVariable):
|
||||
kv = "values"
|
||||
|
||||
@property
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
def view_items_vt(self):
|
||||
return list(self.view_items)
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return dict_values
|
||||
|
||||
|
||||
@ -1533,20 +1487,14 @@ class DictItemsVariable(DictViewVariable):
|
||||
kv = "items"
|
||||
|
||||
@property
|
||||
def view_items_vt(self) -> list[VariableTracker]:
|
||||
def view_items_vt(self):
|
||||
# Returns an iterable of the unpacked items
|
||||
return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
|
||||
|
||||
def python_type(self) -> type:
|
||||
def python_type(self):
|
||||
return dict_items
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx: "InstructionTranslator",
|
||||
name: str,
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
def call_method(self, tx, name, args, kwargs):
|
||||
# TODO(guilhermeleobas): This should actually check if args[0]
|
||||
# implements the mapping protocol.
|
||||
if name == "__eq__":
|
||||
|
||||
@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
throw std::runtime_error("expected int arg");
|
||||
return reinterpret_cast<uintptr_t>(result);
|
||||
}}
|
||||
template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
|
||||
auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
|
||||
if(unlikely(result == -1.0 && PyErr_Occurred()))
|
||||
throw std::runtime_error("expected float arg");
|
||||
return static_cast<float>(result);
|
||||
}}
|
||||
|
||||
{extra_parse_arg}
|
||||
|
||||
|
||||
@ -1732,9 +1732,15 @@ class KernelArgs:
|
||||
call_args.append(self.wrap_ptr_arg(outer, dtype))
|
||||
arg_types.append(f"{cpp_dtype}*")
|
||||
for outer, inner in self.sizevars.items():
|
||||
arg_defs.append(f"const {INDEX_TYPE} {inner}")
|
||||
if isinstance(outer, sympy.Symbol) and symbol_is_type(
|
||||
outer, (SymT.UNBACKED_FLOAT)
|
||||
):
|
||||
arg_defs.append(f"const float {inner}")
|
||||
arg_types.append("const float")
|
||||
else:
|
||||
arg_defs.append(f"const {INDEX_TYPE} {inner}")
|
||||
arg_types.append(f"const {INDEX_TYPE}")
|
||||
call_args.append(self.wrap_size_arg(outer))
|
||||
arg_types.append(f"const {INDEX_TYPE}")
|
||||
if V.graph.wrapper_code:
|
||||
V.graph.wrapper_code.ensure_size_computed(outer)
|
||||
assert not self.workspace_args, "Workspace not supported on CPU "
|
||||
@ -2353,6 +2359,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
||||
SymT.UNBACKED_INT,
|
||||
SymT.SIZE,
|
||||
SymT.PRECOMPUTED_SIZE,
|
||||
SymT.UNBACKED_FLOAT,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
|
||||
|
||||
import sympy # noqa: TC002
|
||||
|
||||
@ -17,8 +17,6 @@ from .simd import SIMDKernel, SIMDScheduling
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from ..ir import IRNode
|
||||
from ..scheduler import BaseSchedulerNode
|
||||
|
||||
|
||||
@ -627,7 +627,7 @@ class ComboKernel(Kernel):
|
||||
if heuristics == "foreach":
|
||||
heuristics_line = f"""
|
||||
@triton_heuristics.foreach(
|
||||
filename=__file__,
|
||||
num_warps={self.num_warps},
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
)
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Optional
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
|
||||
from .. import config
|
||||
from ..runtime.hints import AttrsDescriptorWrapper
|
||||
@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
|
||||
return "constexpr"
|
||||
elif isinstance(arg.expr, (float, sympy.Float)):
|
||||
return "fp32"
|
||||
elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
|
||||
arg.expr, (SymT.UNBACKED_FLOAT)
|
||||
):
|
||||
return "fp32"
|
||||
elif isinstance(arg.expr, bool):
|
||||
return "i1"
|
||||
|
||||
|
||||
@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
|
||||
fx_node: torch.fx.Node,
|
||||
override_size: Optional[int] = None,
|
||||
# TODO(ivankobzarev): The NCCL estimator sometimes fails unexpectedly; re-enable it after the fix.
|
||||
use_nccl_estimator: bool = True,
|
||||
use_nccl_estimator: bool = False,
|
||||
) -> float:
|
||||
"""
|
||||
Returns estimated NCCL collective runtime in nanoseconds (ns).
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from functools import cache, partial
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch._environment import is_fbcode
|
||||
|
||||
@ -3586,24 +3586,13 @@ def user_autotune(
|
||||
)
|
||||
|
||||
|
||||
def foreach(triton_meta, filename=None, inductor_meta=None):
|
||||
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
|
||||
"""
|
||||
Compile a triton foreach kernel
|
||||
"""
|
||||
configs = []
|
||||
|
||||
# Naive autotuning path for num_warps
|
||||
if not (
|
||||
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
|
||||
):
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=8))
|
||||
else:
|
||||
for warps in [1, 2, 4, 8]:
|
||||
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
|
||||
|
||||
return cached_autotune(
|
||||
None,
|
||||
configs,
|
||||
[triton.Config({}, num_stages=1, num_warps=num_warps)],
|
||||
triton_meta=triton_meta,
|
||||
inductor_meta=inductor_meta,
|
||||
heuristic_type=HeuristicType.TEMPLATE,
|
||||
|
||||
@ -52,7 +52,26 @@ __all__ = [
|
||||
"MemRecordsAcc",
|
||||
]
|
||||
|
||||
from contextlib import ContextDecorator
|
||||
try:
|
||||
# Available in Python >= 3.2
|
||||
from contextlib import ContextDecorator as _ContextDecorator
|
||||
except ImportError:
|
||||
import functools
|
||||
|
||||
class _ContextDecorator: # type: ignore[no-redef]
|
||||
def __enter__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, func):
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
# global python state - whether profiler is currently enabled
|
||||
@ -725,7 +744,8 @@ class profile:
|
||||
return all_function_events
|
||||
|
||||
|
||||
class record_function(ContextDecorator):
|
||||
# pyrefly: ignore [invalid-inheritance]
|
||||
class record_function(_ContextDecorator):
|
||||
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
|
||||
Label will only appear if CPU activity tracing is enabled.
|
||||
|
||||
|
||||
@ -108,14 +108,12 @@ struct FlightRecorder {
|
||||
capture_cpp_stack_ = getCvarBool(
|
||||
{"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
|
||||
enabled_ = max_entries_ > 0;
|
||||
reset_epoch_start_idx_[0] = 0;
|
||||
}
|
||||
struct Entry {
|
||||
size_t id_; // incremented id in the trace buffer
|
||||
// used to figure out where in the circular entries
|
||||
// buffer this entry will be located to
|
||||
// update state information
|
||||
size_t reset_epoch_; // epoch when this entry was created
|
||||
size_t pg_id_;
|
||||
std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
|
||||
|
||||
@ -185,34 +183,11 @@ struct FlightRecorder {
|
||||
size_t max_entries_ = 0;
|
||||
size_t next_ = 0;
|
||||
size_t id_ = 0;
|
||||
size_t reset_epoch_ = 0;
|
||||
std::unordered_map<size_t, size_t>
|
||||
reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
|
||||
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
|
||||
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
|
||||
pg_name_to_ranks_;
|
||||
std::string comm_lib_version_;
|
||||
|
||||
struct TraceIdentifier {
|
||||
std::optional<size_t> id;
|
||||
std::optional<size_t> reset_epoch;
|
||||
};
|
||||
|
||||
TraceIdentifier recordWithResetEnabled(
|
||||
size_t pg_id,
|
||||
const std::tuple<std::string, std::string>& pg_name,
|
||||
size_t collective_seq_id,
|
||||
size_t p2p_seq_id,
|
||||
size_t op_id,
|
||||
std::string profiling_name,
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
const std::vector<at::Tensor>& outputs,
|
||||
EventType* start,
|
||||
EventType* end,
|
||||
std::chrono::milliseconds timeout_ms,
|
||||
std::shared_ptr<ProcessGroupStatus> pg_status,
|
||||
bool isP2P);
|
||||
|
||||
std::optional<size_t> record(
|
||||
size_t pg_id,
|
||||
const std::tuple<std::string, std::string>& pg_name,
|
||||
@ -238,16 +213,8 @@ struct FlightRecorder {
|
||||
|
||||
std::vector<Entry> dump_entries();
|
||||
|
||||
// Returns the index in entries_ for the given id and reset_epoch.
|
||||
// Caller must hold mutex_lock before calling this method.
|
||||
size_t getIdxFromId(size_t id, size_t reset_epoch) const;
|
||||
|
||||
// Returns the entry with the given id and reset_epoch, if it exists.
|
||||
// Otherwise, returns std::nullopt.
|
||||
TORCH_API std::optional<Entry> getEntry(
|
||||
std::optional<size_t> id,
|
||||
std::optional<size_t> reset_epoch);
|
||||
|
||||
// Returns the entry with the given id, if it exists. Otherwise, returns
|
||||
// std::nullopt.
|
||||
TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);
|
||||
|
||||
/*
|
||||
@ -260,11 +227,6 @@ struct FlightRecorder {
|
||||
never hang. (timing must also be enabled for compute_duration - see
|
||||
TORCH_NCCL_ENABLE_TIMING).
|
||||
*/
|
||||
TORCH_API void retire_id(
|
||||
std::optional<size_t> id,
|
||||
std::optional<size_t> reset_epoch,
|
||||
bool compute_duration = true);
|
||||
|
||||
TORCH_API void retire_id(
|
||||
std::optional<size_t> id,
|
||||
bool compute_duration = true);
|
||||
|
||||
@ -53,41 +53,8 @@ std::optional<size_t> FlightRecorder<EventType>::record(
|
||||
std::chrono::milliseconds timeout_ms,
|
||||
std::shared_ptr<ProcessGroupStatus> pg_status,
|
||||
bool isP2P) {
|
||||
auto result = recordWithResetEnabled(
|
||||
pg_id,
|
||||
pg_name,
|
||||
collective_seq_id,
|
||||
p2p_seq_id,
|
||||
op_id,
|
||||
std::move(profiling_name),
|
||||
inputs,
|
||||
outputs,
|
||||
start,
|
||||
end,
|
||||
timeout_ms,
|
||||
std::move(pg_status),
|
||||
isP2P);
|
||||
return result.id;
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
|
||||
recordWithResetEnabled(
|
||||
size_t pg_id,
|
||||
const std::tuple<std::string, std::string>& pg_name,
|
||||
size_t collective_seq_id,
|
||||
size_t p2p_seq_id,
|
||||
size_t op_id,
|
||||
std::string profiling_name,
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
const std::vector<at::Tensor>& outputs,
|
||||
EventType* start,
|
||||
EventType* end,
|
||||
std::chrono::milliseconds timeout_ms,
|
||||
std::shared_ptr<ProcessGroupStatus> pg_status,
|
||||
bool isP2P) {
|
||||
if (!enabled_) {
|
||||
return TraceIdentifier{std::nullopt, std::nullopt};
|
||||
return std::nullopt;
|
||||
}
|
||||
if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
|
||||
// Current pg_status is not in FR.
|
||||
@ -97,13 +64,8 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
|
||||
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
|
||||
TORCH_CHECK(
|
||||
reset_epoch_start_idx_.find(reset_epoch_) !=
|
||||
reset_epoch_start_idx_.end());
|
||||
|
||||
auto te = Entry{
|
||||
id_,
|
||||
reset_epoch_,
|
||||
pg_id,
|
||||
pg_name,
|
||||
collective_seq_id,
|
||||
@ -142,20 +104,15 @@ typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
|
||||
te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
|
||||
}
|
||||
|
||||
const auto next = next_++;
|
||||
|
||||
if (entries_.size() < max_entries_) {
|
||||
entries_.emplace_back(std::move(te));
|
||||
} else {
|
||||
entries_[next] = std::move(te);
|
||||
entries_[next_++] = std::move(te);
|
||||
if (next_ == max_entries_) {
|
||||
next_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (next_ == max_entries_) {
|
||||
next_ = 0;
|
||||
}
|
||||
|
||||
const auto id = id_++;
|
||||
return TraceIdentifier{id, reset_epoch_};
|
||||
return id_++;
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
@ -206,20 +163,15 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
|
||||
std::vector<Entry> result;
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
// Filter entries during insertion - only keep entries from current epoch
|
||||
auto filter = [this](const Entry& e) {
|
||||
return e.reset_epoch_ == reset_epoch_;
|
||||
};
|
||||
std::copy_if(
|
||||
result.reserve(entries_.size());
|
||||
result.insert(
|
||||
result.end(),
|
||||
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
|
||||
entries_.end(),
|
||||
std::back_inserter(result),
|
||||
filter);
|
||||
std::copy_if(
|
||||
entries_.end());
|
||||
result.insert(
|
||||
result.end(),
|
||||
entries_.begin(),
|
||||
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
|
||||
std::back_inserter(result),
|
||||
filter);
|
||||
entries_.begin() + static_cast<std::ptrdiff_t>(next_));
|
||||
}
|
||||
// query any remaining events
|
||||
for (auto& r : result) {
|
||||
@ -230,47 +182,28 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
// Returns the index in entries_ for the given id and reset_epoch.
|
||||
// Caller must hold mutex_lock before calling this method.
|
||||
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
|
||||
const {
|
||||
// Look up the starting idx for the given reset epoch
|
||||
auto it = reset_epoch_start_idx_.find(reset_epoch);
|
||||
TORCH_CHECK(it != reset_epoch_start_idx_.end());
|
||||
// Calculate idx based on where the epoch started
|
||||
return (it->second + id) % max_entries_;
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
|
||||
// returns std::nullopt.
|
||||
// Returns the entry with the given id, if it exists. Otherwise, returns
|
||||
// std::nullopt.
|
||||
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
|
||||
EventType>::
|
||||
getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
|
||||
if (!enabled_ || !id || !reset_epoch) {
|
||||
EventType>::getEntry(std::optional<size_t> id) {
|
||||
if (!enabled_ || !id) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
|
||||
if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
|
||||
Entry entry = entries_.at(*id % max_entries_);
|
||||
if (entry.id_ == *id) {
|
||||
return entry;
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
|
||||
EventType>::getEntry(std::optional<size_t> id) {
|
||||
return getEntry(id, 0);
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
void FlightRecorder<EventType>::retire_id(
|
||||
std::optional<size_t> id,
|
||||
std::optional<size_t> reset_epoch,
|
||||
bool compute_duration) {
|
||||
if (!enabled_ || !id || !reset_epoch) {
|
||||
if (!enabled_ || !id) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -281,8 +214,8 @@ void FlightRecorder<EventType>::retire_id(
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
|
||||
Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
|
||||
if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
|
||||
Entry* entry = &entries_.at(*id % max_entries_);
|
||||
if (entry->id_ == *id) {
|
||||
update_state(*entry);
|
||||
|
||||
if (compute_duration) {
|
||||
@ -304,8 +237,8 @@ void FlightRecorder<EventType>::retire_id(
|
||||
guard.lock();
|
||||
|
||||
// Refresh the entry pointer, see if the entry has been overwritten
|
||||
entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
|
||||
if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
|
||||
entry = &entries_.at(*id % max_entries_);
|
||||
if (entry->id_ != *id) {
|
||||
LOG(INFO) << "retire_id abandoned for id " << *id
|
||||
<< ", event was overwritten while waiting to compute duration.";
|
||||
return;
|
||||
@ -316,23 +249,12 @@ void FlightRecorder<EventType>::retire_id(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
void FlightRecorder<EventType>::retire_id(
|
||||
std::optional<size_t> id,
|
||||
bool compute_duration) {
|
||||
retire_id(id, 0, compute_duration);
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
void FlightRecorder<EventType>::reset_all() {
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
if (!entries_.empty()) {
|
||||
// Soft delete: increment epoch to mark all existing entries as old
|
||||
// Store where the new epoch starts in the circular buffer
|
||||
reset_epoch_++;
|
||||
reset_epoch_start_idx_[reset_epoch_] = next_;
|
||||
id_ = 0;
|
||||
}
|
||||
next_ = 0;
|
||||
id_ = 0;
|
||||
entries_.clear();
|
||||
}
|
||||
|
||||
template <typename EventType>
|
||||
|
||||
@ -708,8 +708,7 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
|
||||
// TODO: We need to have numel of tensors for gloo as well.
|
||||
pgStatus_->lastCompletedNumelIn = 0;
|
||||
pgStatus_->lastCompletedNumelOut = 0;
|
||||
FlightRecorder<c10::Event>::get()->retire_id(
|
||||
work->trace_id_, work->trace_reset_epoch_, false);
|
||||
FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
|
||||
lock.lock();
|
||||
workInProgress_[workerIndex].reset();
|
||||
}
|
||||
@ -781,7 +780,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
|
||||
pgStatus_->lastEnqueuedNumelOut = 0;
|
||||
// using c10d::FlightRecorder;
|
||||
// TODO: We need to have a way to use c10::Event inside gloo as well.
|
||||
auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
|
||||
work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
collectiveCounter_,
|
||||
@ -796,8 +795,6 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
|
||||
work->getTimeout(),
|
||||
pgStatus_,
|
||||
false);
|
||||
work->trace_id_ = traceId.id;
|
||||
work->trace_reset_epoch_ = traceId.reset_epoch;
|
||||
workQueue_.push_back(std::move(work));
|
||||
lock.unlock();
|
||||
|
||||
|
||||
@ -99,7 +99,6 @@ class TORCH_API ProcessGroupGloo : public Backend {
|
||||
// unique id used to tell the trace buffer that this
|
||||
// work has completed
|
||||
std::optional<uint64_t> trace_id_;
|
||||
std::optional<uint64_t> trace_reset_epoch_;
|
||||
std::shared_ptr<gloo::Context> context_;
|
||||
const std::chrono::milliseconds timeout_;
|
||||
|
||||
|
||||
@ -575,7 +575,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
|
||||
futureWorkResult_(w.futureWorkResult_),
|
||||
timingEnabled_(w.timingEnabled_),
|
||||
trace_id_(w.trace_id_),
|
||||
trace_reset_epoch_(w.trace_reset_epoch_),
|
||||
distDebugLevel_(w.distDebugLevel_) {
|
||||
exception_ = w.exception_;
|
||||
}
|
||||
@ -705,9 +704,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
|
||||
// Print the traceback of the collective at call time
|
||||
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
|
||||
// First step we get the corresponding record entry from FR, based on work's
|
||||
// trace_id_ and trace_reset_epoch_
|
||||
// trace_id_
|
||||
std::optional<FlightRecorderCUDA::Entry> entry =
|
||||
FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
|
||||
FlightRecorderCUDA::get()->getEntry(trace_id_);
|
||||
if (entry.has_value()) {
|
||||
auto entryVal = entry.value();
|
||||
// Get stack trace from FR entry, in string format
|
||||
@ -2395,8 +2394,7 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
|
||||
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
|
||||
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
|
||||
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
|
||||
FlightRecorderCUDA::get()->retire_id(
|
||||
work.trace_id_, work.trace_reset_epoch_, true);
|
||||
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
|
||||
if (pg_->onCompletionHook_) {
|
||||
// Move Work object to completedWorkList_ to be consumed by the hook
|
||||
// thread
|
||||
@ -3362,7 +3360,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
|
||||
// these objects to the Work because it has implications for keeping those
|
||||
// tensors alive longer and adds overhead when copying Work objects
|
||||
// between threads
|
||||
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
|
||||
r->trace_id_ = FlightRecorderCUDA::get()->record(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
seqCollective_,
|
||||
@ -3376,8 +3374,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
|
||||
options_->timeout,
|
||||
pgStatus_,
|
||||
isP2P);
|
||||
r->trace_id_ = traceId.id;
|
||||
r->trace_reset_epoch_ = traceId.reset_epoch;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
@ -3597,7 +3593,6 @@ float ProcessGroupNCCL::endTimeEstimate() {
|
||||
#ifdef NCCL_SIM_INFO_INITIALIZER
|
||||
ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
|
||||
C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
|
||||
--ncclActiveGroupCounter_;
|
||||
return simInfo.estimatedTime;
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
@ -3681,7 +3676,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
|
||||
// later in endCoalescing we record a 'coalesced' Work which has
|
||||
// timing/state updates via watchdog thread, but lacks op metadata such as
|
||||
// input/output sizes and profilingTitle per-op in the group.
|
||||
FlightRecorderCUDA::get()->recordWithResetEnabled(
|
||||
FlightRecorderCUDA::get()->record(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
seqCollective_,
|
||||
@ -4173,7 +4168,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
|
||||
// TODO(whc) because we don't pass output {tensor} to initWork, we tell
|
||||
// initWork to not record, and then we manually call record passing all the
|
||||
// information it wants.
|
||||
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
|
||||
work->trace_id_ = FlightRecorderCUDA::get()->record(
|
||||
local_id_,
|
||||
std::make_tuple(pg_uid_, pg_desc_),
|
||||
seqCollective_,
|
||||
@ -4187,8 +4182,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
|
||||
options_->timeout,
|
||||
pgStatus_,
|
||||
/*isP2P=*/true);
|
||||
work->trace_id_ = traceId.id;
|
||||
work->trace_reset_epoch_ = traceId.reset_epoch;
|
||||
}
|
||||
|
||||
// Only check for NaN for send ops, for recv ops `tensor` can be a random
|
||||
|
||||
@ -505,7 +505,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// unique id used to tell the trace buffer that this
|
||||
// work has completed
|
||||
std::optional<uint64_t> trace_id_;
|
||||
std::optional<uint64_t> trace_reset_epoch_;
|
||||
DebugLevel distDebugLevel_;
|
||||
friend class ProcessGroupNCCL;
|
||||
};
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
#include <climits>
|
||||
#include <memory>
|
||||
@ -14,7 +13,6 @@
|
||||
HIDDEN_NAMESPACE_BEGIN(torch, stable)
|
||||
|
||||
using accelerator::DeviceIndex;
|
||||
using torch::headeronly::IntHeaderOnlyArrayRef;
|
||||
using torch::headeronly::ScalarType;
|
||||
|
||||
// The torch::stable::Tensor class is a highlevel C++ wrapper around
|
||||
@ -95,32 +93,6 @@ class Tensor {
|
||||
return numel;
|
||||
}
|
||||
|
||||
// note: this API is, for all intents and purposes, the same as the one in
|
||||
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
|
||||
// a Tensor.
|
||||
//
|
||||
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
|
||||
// which has slightly less functionality than a regular IntArrayRef. See
|
||||
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
|
||||
IntHeaderOnlyArrayRef sizes() const {
|
||||
int64_t* sizes;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
|
||||
return IntHeaderOnlyArrayRef(sizes, dim());
|
||||
}
|
||||
|
||||
// note: this API is, for all intents and purposes, the same as the one in
|
||||
// TensorBase.h: it returns a borrowed reference of the strides of a
|
||||
// Tensor.
|
||||
//
|
||||
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
|
||||
// which has slightly less functionality than a regular IntArrayRef. See
|
||||
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
|
||||
IntHeaderOnlyArrayRef strides() const {
|
||||
int64_t* strides;
|
||||
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
|
||||
return IntHeaderOnlyArrayRef(strides, dim());
|
||||
}
|
||||
|
||||
// note: this is a subset of the original TensorBase API. It takes no
|
||||
// arguments whereas the original API takes in a kwarg of memory format.
|
||||
// Here, we assume the default contiguous memory format.
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import functools
|
||||
import math
|
||||
import operator
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from datetime import timedelta
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch._C import ScriptObject
|
||||
|
||||
@ -10,7 +10,6 @@ from ._context_parallel._attention import (
|
||||
_enable_context_parallel_dispatcher,
|
||||
_is_causal_behavior,
|
||||
_RotateMethod,
|
||||
_templated_ring_attention,
|
||||
context_parallel,
|
||||
context_parallel_unshard,
|
||||
set_rotate_method,
|
||||
@ -23,7 +22,6 @@ from ._context_parallel._load_balancer import (
|
||||
)
|
||||
|
||||
|
||||
# TODO(fegin): add deprecation message once the final interfaces are concluded.
|
||||
__all__ = [
|
||||
"_CausalBehavior",
|
||||
"_context_parallel_shard",
|
||||
@ -33,7 +31,6 @@ __all__ = [
|
||||
"_enable_context_parallel_dispatcher",
|
||||
"_is_causal_behavior",
|
||||
"_RotateMethod",
|
||||
"_templated_ring_attention",
|
||||
"context_parallel",
|
||||
"context_parallel_unshard",
|
||||
"set_rotate_method",
|
||||
|
||||
@ -647,15 +647,6 @@ class CodeGen:
|
||||
|
||||
if verbose:
|
||||
# override annotation with more detailed information
|
||||
try:
|
||||
from torch.distributed.tensor._api import DTensor, DTensorSpec
|
||||
|
||||
dtensorspec_format_shard_order_str = (
|
||||
DTensorSpec.format_shard_order_str
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
DTensor = None # type: ignore[assignment,misc]
|
||||
dtensorspec_format_shard_order_str = None
|
||||
from torch.fx.experimental.proxy_tensor import py_sym_types
|
||||
from torch.fx.passes.shape_prop import TensorMetadata
|
||||
|
||||
@ -686,16 +677,6 @@ class CodeGen:
|
||||
core = _tensor_annotation(meta_val)
|
||||
if is_plain:
|
||||
maybe_type_annotation = f': "{core}"'
|
||||
elif type(meta_val) is DTensor:
|
||||
assert dtensorspec_format_shard_order_str is not None
|
||||
dtensor_meta = dtensorspec_format_shard_order_str(
|
||||
meta_val._spec.placements, # type: ignore[attr-defined]
|
||||
meta_val._spec.shard_order, # type: ignore[attr-defined]
|
||||
)
|
||||
cls = meta_val.__class__.__name__
|
||||
maybe_type_annotation = (
|
||||
f': "{cls}({core}, {dim_green(dtensor_meta)})"'
|
||||
)
|
||||
else:
|
||||
cls = meta_val.__class__.__name__
|
||||
maybe_type_annotation = f': "{cls}({core})"'
|
||||
|
||||
Reference in New Issue
Block a user