Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-06 17:24:59 +08:00

Compare commits
20 commits — lucaskabel … ciflow/tru
| SHA1 |
|---|
| 94f210d947 |
| a344069f2a |
| af829c0dad |
| 3869aa115b |
| 47eb34b7ac |
| 08200280ce |
| ad7a57262c |
| 711a775878 |
| e9a688f02e |
| e69aaaf45a |
| fd8f368d31 |
| 13d2cc7bd2 |
| c6c913d18e |
| ef3f953966 |
| ea44f12bce |
| a74fe75c45 |
| 6d30666bc1 |
| 8e8cbb85ee |
| fbd70fb84e |
| 6c5db82584 |
@ -1 +1 @@
-bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
+40eb62cb371b4c2b350c0d735dd65d4f905ee0fe
|
||||
@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
-  - [Running `mypy`](#running-mypy)
+  - [Running `pyrefly`](#running-pyrefly)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
-- `mypy` - recommended for linting
+- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,15 +350,32 @@ make lint

Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)

-#### Running `mypy`
+#### Running `pyrefly`

-`mypy` is an optional static type checker for Python. We have multiple `mypy`
-configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
+[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
+
+PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
+
+**Getting Started with Pyrefly:**
+
+To run type checking on the PyTorch codebase:
+```bash
+pyrefly check
+```
+
+For more detailed error information with summaries:
+```bash
+pyrefly check --summarize-errors
+```
+
+**Learn More:**
+- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
+- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
+- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations

See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
-for more information on how to set up `mypy` and tackle type annotation
-tasks.
+for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
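If you only want to type check the files you are touching, Pyrefly can also be pointed at specific paths. The invocation below is a sketch and assumes the default `pyrefly.toml` at the repository root is picked up automatically; the paths are illustrative:

```bash
# Check a single file or a subdirectory instead of the whole codebase
pyrefly check torch/nn/modules/activation.py
pyrefly check torch/_dynamo/
```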

### C++ Unit Testing
|
||||
|
||||
@ -24,7 +24,13 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
+// We use native integer types for 1/2/4/8-byte values to reduce
+// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
+template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
+template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
+template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
+template <> struct alignas(8) OpaqueType<8> { uint64_t data; };

template<typename key_t, int value_size>
void radix_sort_pairs_impl(
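For context, these specializations save instantiations because the public radix-sort entry point erases the value type before dispatching on its size. A minimal sketch of that pattern, with names and the exact signature being illustrative rather than the real ones in ATen's cub wrapper:

```cpp
// Sketch: dispatch on sizeof(value_t) so every value type of the same width
// shares one radix_sort_pairs_impl instantiation. The primary OpaqueType<N>
// (char array) covers widths other than 1/2/4/8.
template <typename key_t, typename value_t>
void radix_sort_pairs(const key_t* keys_in, key_t* keys_out,
                      const value_t* values_in, value_t* values_out,
                      int64_t n) {
  using opaque_t = OpaqueType<sizeof(value_t)>;
  radix_sort_pairs_impl<key_t, sizeof(value_t)>(
      keys_in, keys_out,
      reinterpret_cast<const opaque_t*>(values_in),
      reinterpret_cast<opaque_t*>(values_out),
      n);
}
```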
|
||||
@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
  });
}

-template <typename func_t, typename vec_func_t>
-void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
+template <typename func_t, typename vec_func_t, typename ident_t = double>
+void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
  using traits = binary_function_traits<func_t>;
  static_assert(
    all_same<
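One apparent reason for templating the identity, visible in the min_values change below, is that `upper_bound<int64_t>()` is INT64_MAX, which is not exactly representable as a double, so forcing the identity through `static_cast<double>(...)` perturbs it for 64-bit integer reductions. A small self-contained illustration (the values are for demonstration only, not code from this patch):

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  const int64_t ident = std::numeric_limits<int64_t>::max(); // 9223372036854775807
  const double as_double = static_cast<double>(ident);       // nearest double is 2^63
  std::printf("%lld -> %.0f\n", static_cast<long long>(ident), as_double);
  // With ident_t deduced as int64_t, the identity stays an int64_t end to end
  // and no such rounding occurs.
  return 0;
}
```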
|
||||
@ -339,33 +339,13 @@ void or_kernel_impl(TensorIterator& iter) {
  }
}

-template<typename scalar_t>
-struct MinValuesOps: public at::native::MinOps<scalar_t> {
-  using arg_t = typename MinOps<scalar_t>::arg_t;
-  static scalar_t project(arg_t arg) {
-    return arg.first;
-  }
-};
-
void min_values_kernel_impl(TensorIterator& iter) {
-  // This case is special because of Vectorized<int64_t> does not
-  // handle upper_bound<int64_t>().
-  // See: https://github.com/pytorch/pytorch/issues/43254
-  if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
-    AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
-      binary_kernel_reduce(
-        iter,
-        MinValuesOps<scalar_t>{},
-        std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
-    }), kLong, kUInt64);
-    return;
-  }
  AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
    binary_kernel_reduce_vec(
      iter,
      [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
      [](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
-      static_cast<double>(upper_bound<scalar_t>()));
+      upper_bound<scalar_t>());
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}
|
||||
|
||||
@ -47,20 +47,10 @@ Tensor sgd_out_of_place(
-  STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
+  STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");

-  int64_t *param_sizes;
-  int64_t *param_strides;
-  aoti_torch_get_sizes(param.get(), &param_sizes);
-  aoti_torch_get_strides(param.get(), &param_strides);
+  // testing Tensor strides + stride
+  STD_TORCH_CHECK(param.strides()[0] == param.stride(0));

-  int32_t param_dtype;
-  aoti_torch_get_dtype(param.get(), &param_dtype);
-
-  int32_t param_device_type;
-  aoti_torch_get_device_type(param.get(), &param_device_type);
-
-  AtenTensorHandle out_ath;
-  aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
-  auto out = Tensor(out_ath);
+  auto out = new_empty(param, param.sizes());

  sgd_math(
    reinterpret_cast<float*>(param.data_ptr()),
@ -344,6 +334,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
  // Still using a std::vector below even though people can just pass in an
  // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
  // directly.
+  // This is to test that passing in a std::vector works for BC. (It gets
+  // implicitly converted to HeaderOnlyArrayRef too!)
  std::vector<int64_t> sizes = {2, 5};
  auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
  return new_empty(t, sizes, dtype);
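As the comment notes, the same call also accepts a braced initializer list because it converts implicitly to HeaderOnlyArrayRef. A hedged sketch of the two equivalent spellings, assuming the `new_empty` helper used above:

```cpp
// Both forms should produce a 2x5 BFloat16 tensor with the same device as t.
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
auto a = new_empty(t, sizes, dtype);   // std::vector, kept here for BC coverage
auto b = new_empty(t, {2, 5}, dtype);  // initializer list, converted implicitly
```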
|
||||
@ -5789,6 +5789,229 @@ class NCCLTraceTest(NCCLTraceTestBase):
        else:
            self.assertTrue("duration_ms" not in t["entries"][0])

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
        """
        Test that when the circular buffer in entries_ is full and we call reset,
        then fill the buffer with new entries, dump_entries returns only the new
        entries and not the old ones.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill the buffer completely with 10 entries
        for _ in range(10):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify buffer is full with 10 entries
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 10)

        # Now reset the flight recorder
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Add new entries after reset - fill the buffer completely again
        for _ in range(10):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify we get exactly 10 new entries, not 20
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 10)

        # Verify all entries have the expected properties (from after reset)
        # After reset, record IDs should start from 0 again
        for i, entry in enumerate(t["entries"]):
            self.assertIn("profiling_name", entry)
            self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
            self.assertIn("record_id", entry)
            # Record IDs should be sequential starting from 0 after reset
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_partial_overwrite(self, timing_enabled):
        """
        Test that when the circular buffer is full, we reset, and then add fewer
        entries than the buffer size, we only get the new entries.
        This tests that old entries at the end of the circular buffer are properly
        filtered out based on reset_epoch.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill the buffer completely
        for _ in range(10):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Reset the flight recorder
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Add only 3 new entries (much less than buffer size)
        for _ in range(3):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Verify we only get the 3 new entries, not 10
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 3)

        # Verify record IDs start from 0 after reset
        for i, entry in enumerate(t["entries"]):
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_reset_wraparound(self, timing_enabled):
        """
        Test that when we reset in the middle of the circular buffer and then
        wrap around, dump_entries correctly returns only entries from the current
        epoch in the correct order.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # Fill half the buffer
        for _ in range(5):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Reset at this point (reset happens at index 5)
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Now add 8 entries, which will wrap around
        # (5->9 fills rest of buffer, then 0->2 wraps around)
        for _ in range(8):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Should get exactly 8 entries, properly ordered
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 8)

        # Entries should be in chronological order
        # The dump_entries() method returns entries from next_ to end, then 0 to next_
        # After filtering old entries, we should have 8 entries in order
        # Verify record IDs start from 0 after reset (id_ is reset in reset_all())
        for i, entry in enumerate(t["entries"]):
            self.assertIn("profiling_name", entry)
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_fr_record_multiple_resets(self, timing_enabled):
        """
        Test multiple consecutive resets to ensure each reset properly increments
        the epoch and filters out entries from previous epochs.
        """
        if self.rank == self.MAIN_PROCESS_RANK:
            return

        # Override buffer size to 10 for faster testing
        os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"

        pg = self._create_process_group_nccl()
        if timing_enabled:
            pg._enable_collectives_timing()
        device = self.local_device
        self.set_thread_name("fr_test_thread")
        a = torch.full((3, 4), float(self.rank), device=device)

        # First batch: 2 entries
        for _ in range(2):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # First reset
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Second batch: 3 entries
        for _ in range(3):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Second reset
        torch._C._distributed_c10d._reset_fr_recording_nccl()

        # Third batch: 4 entries
        for _ in range(4):
            f = pg.allreduce(a)
        f.wait()
        torch.cuda.synchronize(device=device)
        time.sleep(1)

        # Should only see the last 4 entries
        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        self.assertEqual(len(t["entries"]), 4)

        # Verify record IDs start from 0 after the last reset
        for i, entry in enumerate(t["entries"]):
            self.assertIn("record_id", entry)
            self.assertEqual(entry["record_id"], i)

        dist.destroy_process_group()


def check_if_test_is_skipped(fn):
    def wrapper(self, *args, **kwargs):
|
||||
@ -8,21 +8,11 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
-from torch._dynamo.testing import (
-    AotEagerAndRecordGraphs,
-    extract_graph_and_tracker,
-    normalize_gm,
-)
+from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet


-def extract_graph(fn, *args, **kwargs):
-    backend = AotEagerAndRecordGraphs()
-    result = torch.compile(backend=backend)(fn)(*args, **kwargs)
-    return result, backend.graphs, backend.fw_graphs
-
-
def graph_str(gm):
    return normalize_gm(gm.print_readable(print_output=False))

@ -40,7 +30,7 @@ class GraphDededuplicationTests(TestCase):
        super().tearDown()

    def run_and_return_graphs(self, fn, *args, **kwargs):
-        return extract_graph(fn, *args, **kwargs)
+        return extract_graph(fn, *args, **kwargs)[0:3]

    def run_and_get_simple_graph(self):
        def fn(x, y):
|
||||
@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
-from collections.abc import Sequence
-from typing import Any, Callable, Union
+from collections.abc import Callable, Sequence
+from typing import Any, Union

import torch
import torch._dynamo
|
||||
@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
-from typing import Callable, NamedTuple, Optional
+from typing import NamedTuple, Optional, TYPE_CHECKING

import torch
import torch._dynamo
@ -7,6 +7,10 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same


+if TYPE_CHECKING:
+    from collections.abc import Callable


"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo
|
||||
@ -1,11 +1,13 @@
# Owner(s): ["module: dynamo"]
import functools
+import re
import unittest
import weakref

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
+from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
|
||||
@ -15,6 +17,14 @@ requires_multigpu = functools.partial(
)


+def remove_file_comment(gm_str: str) -> str:
+    return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
+
+
+def print_graph(graph: torch.fx.GraphModule) -> str:
+    return remove_file_comment(graph.print_readable())
+
+
class TestStreams(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
@ -36,9 +46,7 @@ class TestStreams(torch._dynamo.test_case.TestCase):

    @requires_cuda
    def test_stream_enter_exit(self):
-        def fn(x, y):
-            s2 = torch.Stream()
-            s1 = torch.Stream()
+        def fn(x, y, s1, s2):
            with s1:
                z1 = torch.add(x, y)
                with s2:
@ -47,13 +55,36 @@ class TestStreams(torch._dynamo.test_case.TestCase):

            return y

-        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
        expected = fn(*inp)
-        fn_opt = torch.compile(fn, fullgraph=True)
-        actual = fn_opt(*inp)
+        (
+            actual,
+            _,
+            fw_graphs,
+            _,
+        ) = extract_graph(fn, *inp)
+        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
+        self.assertExpectedInline(
+            print_graph(fw_graphs[0]),
+            """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
+        # Annotation: {'stream': None}
+        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
+
+        # Annotation: {'stream': None}
+        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
+
+        # Annotation: {'stream': None}
+        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
+        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
+        return (add_3,)
+""",
+        )

    @requires_cuda
+    @unittest.skip("Needs graph break support with annotation context")
    def test_stream_context_graph_break(self):
        def fn(x, y):
            s2 = torch.Stream()
@ -70,9 +101,16 @@ class TestStreams(torch._dynamo.test_case.TestCase):

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
-        fn_opt = torch.compile(fn)
-        actual = fn_opt(*inp)
+        (
+            actual,
+            _,
+            fw_graphs,
+            _,
+        ) = extract_graph(fn, *inp)
        self.assertEqual(expected, actual)
+        self.assertEqual(len(fw_graphs), 2)
+        self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
+        self.assertExpectedInline(print_graph(fw_graphs[1]), """""")

    @requires_cuda
    def test_stream_input(self):
@ -155,22 +193,248 @@ class TestStreams(torch._dynamo.test_case.TestCase):
        self.assertEqual(s_act, s_exp)

    def test_nested_stream_enter_exit(self):
-        pass
+        def fn(x, y, s0, s1, s2):
            with s1:
                with s2:
                    z1 = torch.add(x, y)
                with s0:
                    z0 = torch.add(x, y)
                with s2:
                    y = 2 + z1

            return z0, y

        inp = (
            torch.ones(2, 2) + 1,
            torch.ones(2, 2),
            torch.Stream(),
            torch.Stream(),
            torch.Stream(),
        )
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            _,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
        # Annotation: {'stream': None}
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

        # Annotation: {'stream': None}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

        # Annotation: {'stream': None}
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
        return (add_1, add_2)
""",
        )

    @unittest.skip("Needs graph break support with annotation context")
    def test_stream_enter_exit_graph_break(self):
        pass

    @unittest.skip("Needs graph break support with annotation context")
    def test_nested_stream_enter_exit_graph_break(self):
        pass

    def test_local_stream_enter_exit(self):
-        pass
+        def fn(x, y):
            s2 = torch.Stream()
            s1 = torch.Stream()
            with s1:
                z1 = torch.add(x, y)
                with s2:
                    z = torch.add(x, y)
                y = z + 2 + z1

            return y

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            _,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
        # Annotation: {'stream': 1}
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

        # Annotation: {'stream': 0}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

        # Annotation: {'stream': 0}
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
        return (add_3,)
""",
        )

    def test_local_stream_nested_enter_exit(self):
-        pass
+        def fn(x, y):
            s2 = torch.Stream()
            s1 = torch.Stream()
            s0 = torch.Stream()
            with s1:
                with s2:
                    z1 = torch.add(x, y)
                with s0:
                    z0 = torch.add(x, y)
                with s2:
                    y = 2 + z1

            return z0, y

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            _,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

        # Annotation: {'stream': 2}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

        # Annotation: {'stream': 0}
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
        return (add_1, add_2)
""",
        )

    def test_stream_with_mutation(self):
-        pass
+        def fn(x, y):
            s2 = torch.Stream()
            s1 = torch.Stream()
            s0 = torch.Stream()
            with s1:
                with s2:
                    x.add_(y)
                with s0:
                    z1 = torch.add(y, y)
                    z0 = torch.add(z1, y)
                with s2:
                    y = 2 + z1

            return z0, y

        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            _,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)

        # Annotation: {'stream': 2}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)

        # Annotation: {'stream': 2}
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None

        # Annotation: {'stream': 0}
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None

        #
        copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
        return (add_2, add_3)
""",
        )

    def test_stream_backward(self) -> None:
        def fn(x, y):
            s2 = torch.Stream()
            s0 = torch.Stream()
            with s0:
                y0 = 2 * x + y
            with s2:
                z = 2 * x + y

            return y0, z

        inp = (
            torch.ones(2, 2, requires_grad=True) + 1,
            torch.ones(2, 2, requires_grad=True),
        )
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            bw_graphs,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
        # Annotation: {'stream': 1}
        mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)

        # Annotation: {'stream': 0}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
        return (add, add_1)
""",
        )

        actual[1].sum().backward()
        self.assertExpectedInline(
            print_graph(bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)

        #
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None

        # Annotation: {'stream': 1}
        mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None

        #
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
        return (add_3, add_2)
""",
        )

    @requires_cuda
    def test_run_opcheck(self):
|
||||
@ -14424,20 +14424,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar

        self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))

-    @skip_if_halide
-    @requires_cuda_and_triton
-    def test_unbacked_float_item(self):
-        def fn(x, max_val):
-            return torch.clamp(x, 0, max_val.item())
-
-        self.common(
-            fn,
-            (
-                torch.randn(10, 20, 30, device=self.device),
-                torch.tensor(5.0, device=self.device),
-            ),
-        )
-
    # end of class CommonTemplate - add new tests here
|
||||
|
||||
|
||||
@ -1864,6 +1864,8 @@ class TestFP8Matmul(TestCase):
    ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
    @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
    def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
+        if torch.version.hip and recipe == "nvfp4":
+            raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
        if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
            raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")
|
||||
|
||||
@ -1914,6 +1914,7 @@ class TestSDPAFailureModes(NNTestCase):
                q, k, v, None, 0.0, is_causal=True))

    @onlyCUDA
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
        batch_size = 2**16
        query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@ -1935,6 +1936,7 @@ class TestSDPAFailureModes(NNTestCase):
        self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)

    @onlyCUDA
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
        query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
        key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@ -1948,6 +1950,7 @@ class TestSDPAFailureModes(NNTestCase):

    @largeTensorTest("15GB", "cuda")
    @onlyCUDA
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
    def test_mem_eff_attention_large_seq_len_uniform_attention(self):
        device = torch.device("cuda")
        dtype = torch.bfloat16
|
||||
@ -1,5 +1,5 @@
-from typing import Union
-from typing_extensions import assert_type, TypeAlias
+from typing import TypeAlias, Union
+from typing_extensions import assert_type

from torch import randn, Tensor
|
||||
|
||||
@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
+from collections.abc import Callable
from datetime import timedelta
from enum import Enum
-from typing import Any, Callable, Optional, overload, Union
+from typing import Any, Optional, overload, Union

import torch
from torch import Tensor
|
||||
@ -87,6 +87,12 @@ def extract_graph_and_tracker(fn, *args, **kwargs):  # type: ignore[no-untyped-d
    return gm.graph, region_tracker  # type: ignore[union-attr]


+def extract_graph(fn, *args, **kwargs):  # type: ignore[no-untyped-def]
+    backend = AotEagerAndRecordGraphs()
+    result = torch.compile(backend=backend)(fn)(*args, **kwargs)
+    return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
+
+
def collect_results(
    model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:
|
||||
@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Sequence, Sized
from contextlib import ExitStack
-from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
+from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union

import torch._C
from torch._guards import Guard
|
||||
@ -2970,12 +2970,6 @@ class CppPythonBindingsCodeCache(CppCodeCache):
            throw std::runtime_error("expected int arg");
        return reinterpret_cast<uintptr_t>(result);
    }}
-    template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
-        auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
-        if(unlikely(result == -1.0 && PyErr_Occurred()))
-            throw std::runtime_error("expected float arg");
-        return static_cast<float>(result);
-    }}

    {extra_parse_arg}
|
||||
|
||||
@ -1732,15 +1732,9 @@ class KernelArgs:
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
-            if isinstance(outer, sympy.Symbol) and symbol_is_type(
-                outer, (SymT.UNBACKED_FLOAT)
-            ):
-                arg_defs.append(f"const float {inner}")
-                arg_types.append("const float")
-            else:
-                arg_defs.append(f"const {INDEX_TYPE} {inner}")
-                arg_types.append(f"const {INDEX_TYPE}")
+            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert not self.workspace_args, "Workspace not supported on CPU "
@ -2359,7 +2353,6 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                    SymT.UNBACKED_INT,
                    SymT.SIZE,
                    SymT.PRECOMPUTED_SIZE,
-                    SymT.UNBACKED_FLOAT,
                ),
            )
        }
|
||||
@ -2,7 +2,7 @@
from __future__ import annotations

import hashlib
-from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING

import sympy  # noqa: TC002
|
||||
@ -17,6 +17,8 @@ from .simd import SIMDKernel, SIMDScheduling


if TYPE_CHECKING:
+    from collections.abc import Callable, Sequence
+
    from ..ir import IRNode
    from ..scheduler import BaseSchedulerNode
|
||||
|
||||
@ -627,7 +627,7 @@ class ComboKernel(Kernel):
        if heuristics == "foreach":
            heuristics_line = f"""
@triton_heuristics.foreach(
-    num_warps={self.num_warps},
    filename=__file__,
    triton_meta={triton_meta!r},
    inductor_meta={inductor_meta!r},
)
|
||||
@ -4,7 +4,6 @@ from typing import Any, Optional
import sympy

import torch
-from torch.utils._sympy.symbol import symbol_is_type, SymT

from .. import config
from ..runtime.hints import AttrsDescriptorWrapper
@ -72,10 +71,6 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
        return "constexpr"
    elif isinstance(arg.expr, (float, sympy.Float)):
        return "fp32"
-    elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
-        arg.expr, (SymT.UNBACKED_FLOAT)
-    ):
-        return "fp32"
    elif isinstance(arg.expr, bool):
        return "i1"
|
||||
|
||||
@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
    fx_node: torch.fx.Node,
    override_size: Optional[int] = None,
    # TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix.
-    use_nccl_estimator: bool = False,
+    use_nccl_estimator: bool = True,
) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).
|
||||
@ -1,6 +1,6 @@
import os
+from collections.abc import Callable
from functools import cache, partial
-from typing import Callable

import torch
from torch._environment import is_fbcode
|
||||
@ -3586,13 +3586,24 @@ def user_autotune(
    )


-def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
+def foreach(triton_meta, filename=None, inductor_meta=None):
    """
    Compile a triton foreach kernel
    """
+    configs = []
+
+    # Naive autotuning path for num_warps
+    if not (
+        inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+    ):
+        configs.append(triton.Config({}, num_stages=1, num_warps=8))
+    else:
+        for warps in [1, 2, 4, 8]:
+            configs.append(triton.Config({}, num_stages=1, num_warps=warps))
+
    return cached_autotune(
        None,
-        [triton.Config({}, num_stages=1, num_warps=num_warps)],
+        configs,
        triton_meta=triton_meta,
        inductor_meta=inductor_meta,
        heuristic_type=HeuristicType.TEMPLATE,
|
||||
@ -52,26 +52,7 @@ __all__ = [
    "MemRecordsAcc",
]

-try:
-    # Available in Python >= 3.2
-    from contextlib import ContextDecorator as _ContextDecorator
-except ImportError:
-    import functools
-
-    class _ContextDecorator:  # type: ignore[no-redef]
-        def __enter__(self):
-            raise NotImplementedError
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            raise NotImplementedError
-
-        def __call__(self, func):
-            @functools.wraps(func)
-            def wrapped(*args, **kwargs):
-                with self:
-                    return func(*args, **kwargs)
-
-            return wrapped
+from contextlib import ContextDecorator


# global python state - whether profiler is currently enabled
@ -744,8 +725,7 @@ class profile:
        return all_function_events


-# pyrefly: ignore [invalid-inheritance]
-class record_function(_ContextDecorator):
+class record_function(ContextDecorator):
    """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
    Label will only appear if CPU activity tracing is enabled.
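Nothing about the public behavior changes here: record_function still works both as a context manager and as a decorator (the decorator form is what ContextDecorator provides). A short usage sketch for reference:

```python
import torch
from torch.profiler import profile, record_function

@record_function("scale")            # decorator form, enabled by ContextDecorator
def scale(x):
    return x * 2

with profile() as prof:
    with record_function("matmul"):  # context-manager form
        y = torch.randn(8, 8) @ torch.randn(8, 8)
    y = scale(y)

print(prof.key_averages().table(sort_by="cpu_time_total"))
```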
|
||||
|
||||
@ -108,12 +108,14 @@ struct FlightRecorder {
    capture_cpp_stack_ = getCvarBool(
        {"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
    enabled_ = max_entries_ > 0;
+    reset_epoch_start_idx_[0] = 0;
  }
  struct Entry {
    size_t id_; // incremented id in the trace buffer
                // used to figure out where in the circular entries
                // buffer this entry will be located to
                // update state information
+    size_t reset_epoch_; // epoch when this entry was created
    size_t pg_id_;
    std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
|
||||
@ -183,11 +185,34 @@ struct FlightRecorder {
  size_t max_entries_ = 0;
  size_t next_ = 0;
  size_t id_ = 0;
+  size_t reset_epoch_ = 0;
+  std::unordered_map<size_t, size_t>
+      reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
  std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
  std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
      pg_name_to_ranks_;
  std::string comm_lib_version_;

+  struct TraceIdentifier {
+    std::optional<size_t> id;
+    std::optional<size_t> reset_epoch;
+  };
+
+  TraceIdentifier recordWithResetEnabled(
+      size_t pg_id,
+      const std::tuple<std::string, std::string>& pg_name,
+      size_t collective_seq_id,
+      size_t p2p_seq_id,
+      size_t op_id,
+      std::string profiling_name,
+      const std::vector<at::Tensor>& inputs,
+      const std::vector<at::Tensor>& outputs,
+      EventType* start,
+      EventType* end,
+      std::chrono::milliseconds timeout_ms,
+      std::shared_ptr<ProcessGroupStatus> pg_status,
+      bool isP2P);
+
  std::optional<size_t> record(
      size_t pg_id,
      const std::tuple<std::string, std::string>& pg_name,
@ -213,8 +238,16 @@ struct FlightRecorder {

  std::vector<Entry> dump_entries();

-  // Returns the entry with the given id, if it exists. Otherwise, returns
-  // std::nullopt.
+  // Returns the index in entries_ for the given id and reset_epoch.
+  // Caller must hold mutex_lock before calling this method.
+  size_t getIdxFromId(size_t id, size_t reset_epoch) const;
+
+  // Returns the entry with the given id and reset_epoch, if it exists.
+  // Otherwise, returns std::nullopt.
+  TORCH_API std::optional<Entry> getEntry(
+      std::optional<size_t> id,
+      std::optional<size_t> reset_epoch);
+
  TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);

  /*
@ -227,6 +260,11 @@ struct FlightRecorder {
    never hang. (timing must also be enabled for compute_duration - see
    TORCH_NCCL_ENABLE_TIMING).
  */
+  TORCH_API void retire_id(
+      std::optional<size_t> id,
+      std::optional<size_t> reset_epoch,
+      bool compute_duration = true);
+
  TORCH_API void retire_id(
      std::optional<size_t> id,
      bool compute_duration = true);
|
||||
@ -53,8 +53,41 @@ std::optional<size_t> FlightRecorder<EventType>::record(
    std::chrono::milliseconds timeout_ms,
    std::shared_ptr<ProcessGroupStatus> pg_status,
    bool isP2P) {
+  auto result = recordWithResetEnabled(
+      pg_id,
+      pg_name,
+      collective_seq_id,
+      p2p_seq_id,
+      op_id,
+      std::move(profiling_name),
+      inputs,
+      outputs,
+      start,
+      end,
+      timeout_ms,
+      std::move(pg_status),
+      isP2P);
+  return result.id;
+}
+
+template <typename EventType>
+typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
+    recordWithResetEnabled(
+        size_t pg_id,
+        const std::tuple<std::string, std::string>& pg_name,
+        size_t collective_seq_id,
+        size_t p2p_seq_id,
+        size_t op_id,
+        std::string profiling_name,
+        const std::vector<at::Tensor>& inputs,
+        const std::vector<at::Tensor>& outputs,
+        EventType* start,
+        EventType* end,
+        std::chrono::milliseconds timeout_ms,
+        std::shared_ptr<ProcessGroupStatus> pg_status,
+        bool isP2P) {
  if (!enabled_) {
-    return std::nullopt;
+    return TraceIdentifier{std::nullopt, std::nullopt};
  }
  if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
    // Current pg_status is not in FR.
@ -64,8 +97,13 @@ std::optional<size_t> FlightRecorder<EventType>::record(
      torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
  std::lock_guard<std::mutex> guard(mutex_);

+  TORCH_CHECK(
+      reset_epoch_start_idx_.find(reset_epoch_) !=
+      reset_epoch_start_idx_.end());
+
  auto te = Entry{
      id_,
+      reset_epoch_,
      pg_id,
      pg_name,
      collective_seq_id,
@ -104,15 +142,20 @@ std::optional<size_t> FlightRecorder<EventType>::record(
    te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
  }

+  const auto next = next_++;
+
  if (entries_.size() < max_entries_) {
    entries_.emplace_back(std::move(te));
  } else {
-    entries_[next_++] = std::move(te);
-    if (next_ == max_entries_) {
-      next_ = 0;
-    }
+    entries_[next] = std::move(te);
  }
-  return id_++;
+
+  if (next_ == max_entries_) {
+    next_ = 0;
+  }
+
+  const auto id = id_++;
+  return TraceIdentifier{id, reset_epoch_};
}

template <typename EventType>
@ -163,15 +206,20 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
  std::vector<Entry> result;
  {
    std::lock_guard<std::mutex> guard(mutex_);
    result.reserve(entries_.size());
-    result.insert(
-        result.end(),
-        entries_.begin() + static_cast<std::ptrdiff_t>(next_),
-        entries_.end());
-    result.insert(
-        result.end(),
-        entries_.begin(),
-        entries_.begin() + static_cast<std::ptrdiff_t>(next_));
+    // Filter entries during insertion - only keep entries from current epoch
+    auto filter = [this](const Entry& e) {
+      return e.reset_epoch_ == reset_epoch_;
+    };
+    std::copy_if(
+        entries_.begin() + static_cast<std::ptrdiff_t>(next_),
+        entries_.end(),
+        std::back_inserter(result),
+        filter);
+    std::copy_if(
+        entries_.begin(),
+        entries_.begin() + static_cast<std::ptrdiff_t>(next_),
+        std::back_inserter(result),
+        filter);
  }
  // query any remaining events
  for (auto& r : result) {
@ -182,28 +230,47 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
}

template <typename EventType>
-// Returns the entry with the given id, if it exists. Otherwise, returns
-// std::nullopt.
+// Returns the index in entries_ for the given id and reset_epoch.
+// Caller must hold mutex_lock before calling this method.
+size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
+    const {
+  // Look up the starting idx for the given reset epoch
+  auto it = reset_epoch_start_idx_.find(reset_epoch);
+  TORCH_CHECK(it != reset_epoch_start_idx_.end());
+  // Calculate idx based on where the epoch started
+  return (it->second + id) % max_entries_;
+}

template <typename EventType>
+// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
+// returns std::nullopt.
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
-    EventType>::getEntry(std::optional<size_t> id) {
-  if (!enabled_ || !id) {
+    EventType>::
+    getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
+  if (!enabled_ || !id || !reset_epoch) {
    return std::nullopt;
  }

  std::unique_lock<std::mutex> guard(mutex_);
-  Entry entry = entries_.at(*id % max_entries_);
-  if (entry.id_ == *id) {
+  Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
+  if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
    return entry;
-  } else {
-    return std::nullopt;
  }
  return std::nullopt;
}

+template <typename EventType>
+std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
+    EventType>::getEntry(std::optional<size_t> id) {
+  return getEntry(id, 0);
+}
+
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
    std::optional<size_t> id,
+    std::optional<size_t> reset_epoch,
    bool compute_duration) {
-  if (!enabled_ || !id) {
+  if (!enabled_ || !id || !reset_epoch) {
    return;
  }
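To make the getIdxFromId index math above concrete, here is a worked example with assumed values (buffer size, reset position, and id are made up for illustration): with max_entries_ = 10 and a reset issued when next_ = 5, reset_all records reset_epoch_start_idx_[1] = 5. The third entry of the new epoch (id = 2) then lives at (5 + 2) % 10 = 7, and once the epoch wraps, id = 7 maps to (5 + 7) % 10 = 2 — a slot that still physically holds an old-epoch entry, which is exactly why dump_entries filters on reset_epoch_.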
|
||||
@ -214,8 +281,8 @@ void FlightRecorder<EventType>::retire_id(

  std::unique_lock<std::mutex> guard(mutex_);

-  Entry* entry = &entries_.at(*id % max_entries_);
-  if (entry->id_ == *id) {
+  Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
+  if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
    update_state(*entry);

    if (compute_duration) {
@ -237,8 +304,8 @@ void FlightRecorder<EventType>::retire_id(
    guard.lock();

    // Refresh the entry pointer, see if the entry has been overwritten
-    entry = &entries_.at(*id % max_entries_);
-    if (entry->id_ != *id) {
+    entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
+    if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
      LOG(INFO) << "retire_id abandoned for id " << *id
                << ", event was overwritten while waiting to compute duration.";
      return;
@ -249,12 +316,23 @@ void FlightRecorder<EventType>::retire_id(
  }
}

+template <typename EventType>
+void FlightRecorder<EventType>::retire_id(
+    std::optional<size_t> id,
+    bool compute_duration) {
+  retire_id(id, 0, compute_duration);
+}
+
template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
  std::lock_guard<std::mutex> guard(mutex_);
-  next_ = 0;
-  id_ = 0;
-  entries_.clear();
+  if (!entries_.empty()) {
+    // Soft delete: increment epoch to mark all existing entries as old
+    // Store where the new epoch starts in the circular buffer
+    reset_epoch_++;
+    reset_epoch_start_idx_[reset_epoch_] = next_;
+    id_ = 0;
+  }
}

template <typename EventType>
|
||||
@ -708,7 +708,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
      // TODO: We need to have numel of tensors for gloo as well.
      pgStatus_->lastCompletedNumelIn = 0;
      pgStatus_->lastCompletedNumelOut = 0;
-      FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
+      FlightRecorder<c10::Event>::get()->retire_id(
+          work->trace_id_, work->trace_reset_epoch_, false);
      lock.lock();
      workInProgress_[workerIndex].reset();
    }
@ -780,7 +781,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
  pgStatus_->lastEnqueuedNumelOut = 0;
  // using c10d::FlightRecorder;
  // TODO: We need to have a way to use c10::Event inside gloo as well.
-  work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
+  auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
      local_id_,
      std::make_tuple(pg_uid_, pg_desc_),
      collectiveCounter_,
@ -795,6 +796,8 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
      work->getTimeout(),
      pgStatus_,
      false);
+  work->trace_id_ = traceId.id;
+  work->trace_reset_epoch_ = traceId.reset_epoch;
  workQueue_.push_back(std::move(work));
  lock.unlock();
|
||||
|
||||
@ -99,6 +99,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
    // unique id used to tell the trace buffer that this
    // work has completed
    std::optional<uint64_t> trace_id_;
+    std::optional<uint64_t> trace_reset_epoch_;
    std::shared_ptr<gloo::Context> context_;
    const std::chrono::milliseconds timeout_;
|
||||
|
||||
@ -575,6 +575,7 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
      futureWorkResult_(w.futureWorkResult_),
      timingEnabled_(w.timingEnabled_),
      trace_id_(w.trace_id_),
+      trace_reset_epoch_(w.trace_reset_epoch_),
      distDebugLevel_(w.distDebugLevel_) {
  exception_ = w.exception_;
}
@ -704,9 +705,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
// Print the traceback of the collective at call time
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
  // First step we get the corresponding record entry from FR, based on work's
-  // trace_id_
+  // trace_id_ and trace_reset_epoch_
  std::optional<FlightRecorderCUDA::Entry> entry =
-      FlightRecorderCUDA::get()->getEntry(trace_id_);
+      FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
  if (entry.has_value()) {
    auto entryVal = entry.value();
    // Get stack trace from FR entry, in string format
@ -2394,7 +2395,8 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
      pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
      pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
      pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
-      FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
+      FlightRecorderCUDA::get()->retire_id(
+          work.trace_id_, work.trace_reset_epoch_, true);
      if (pg_->onCompletionHook_) {
        // Move Work object to completedWorkList_ to be consumed by the hook
        // thread
@ -3360,7 +3362,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
    // these objects to the Work because it has implications for keeping those
    // tensors alive longer and adds overhead when copying Work objects
    // between threads
-    r->trace_id_ = FlightRecorderCUDA::get()->record(
+    auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
        local_id_,
        std::make_tuple(pg_uid_, pg_desc_),
        seqCollective_,
@ -3374,6 +3376,8 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
        options_->timeout,
        pgStatus_,
        isP2P);
+    r->trace_id_ = traceId.id;
+    r->trace_reset_epoch_ = traceId.reset_epoch;
  }
  return r;
}
@ -3593,6 +3597,7 @@ float ProcessGroupNCCL::endTimeEstimate() {
#ifdef NCCL_SIM_INFO_INITIALIZER
  ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
  C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
+  --ncclActiveGroupCounter_;
  return simInfo.estimatedTime;
#else
  TORCH_CHECK(
@ -3676,7 +3681,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
    // later in endCoalescing we record a 'coalesced' Work which has
    // timing/state updates via watchdog thread, but lacks op metadata such as
    // input/output sizes and profilingTitle per-op in the group.
-    FlightRecorderCUDA::get()->record(
+    FlightRecorderCUDA::get()->recordWithResetEnabled(
        local_id_,
        std::make_tuple(pg_uid_, pg_desc_),
        seqCollective_,
@ -4168,7 +4173,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
    // TODO(whc) because we don't pass output {tensor} to initWork, we tell
    // initWork to not record, and then we manually call record passing all the
    // information it wants.
-    work->trace_id_ = FlightRecorderCUDA::get()->record(
+    auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
        local_id_,
        std::make_tuple(pg_uid_, pg_desc_),
        seqCollective_,
@ -4182,6 +4187,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
        options_->timeout,
        pgStatus_,
        /*isP2P=*/true);
+    work->trace_id_ = traceId.id;
+    work->trace_reset_epoch_ = traceId.reset_epoch;
  }

  // Only check for NaN for send ops, for recv ops `tensor` can be a random
|
||||
@ -505,6 +505,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
    // unique id used to tell the trace buffer that this
    // work has completed
    std::optional<uint64_t> trace_id_;
+    std::optional<uint64_t> trace_reset_epoch_;
    DebugLevel distDebugLevel_;
    friend class ProcessGroupNCCL;
  };
|
||||
@ -4,6 +4,7 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
+#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -13,6 +14,7 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)

using accelerator::DeviceIndex;
+using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;

// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -93,6 +95,32 @@ class Tensor {
    return numel;
  }

+  // note: this API is, for all intents and purposes, the same as the one in
+  // TensorBase.h: it returns a borrowed reference of the dimension sizes of
+  // a Tensor.
+  //
+  // The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
+  // which has slightly less functionality than a regular IntArrayRef. See
+  // [HeaderOnlyArrayRef vs ArrayRef note] for more details.
+  IntHeaderOnlyArrayRef sizes() const {
+    int64_t* sizes;
+    TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
+    return IntHeaderOnlyArrayRef(sizes, dim());
+  }
+
+  // note: this API is, for all intents and purposes, the same as the one in
+  // TensorBase.h: it returns a borrowed reference of the strides of a
+  // Tensor.
+  //
+  // The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
+  // which has slightly less functionality than a regular IntArrayRef. See
+  // [HeaderOnlyArrayRef vs ArrayRef note] for more details.
+  IntHeaderOnlyArrayRef strides() const {
+    int64_t* strides;
+    TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
+    return IntHeaderOnlyArrayRef(strides, dim());
+  }
+
  // note: this is a subset of the original TensorBase API. It takes no
  // arguments whereas the original API takes in a kwarg of memory format.
  // Here, we assume the default contiguous memory format.
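A short hedged sketch of how the new accessors line up with the per-dimension accessors already exercised in the test extension (the helper name and header path are illustrative assumptions, not part of this patch):

```cpp
// Hypothetical helper using only APIs shown in this diff: sizes(), strides(),
// dim(), and the existing stride(i) accessor.
bool metadata_is_consistent(const torch::stable::Tensor& t) {
  auto sizes = t.sizes();      // borrowed IntHeaderOnlyArrayRef of length t.dim()
  auto strides = t.strides();  // borrowed IntHeaderOnlyArrayRef of length t.dim()
  return sizes.size() == static_cast<size_t>(t.dim()) &&
         (t.dim() == 0 || strides[0] == t.stride(0));
}
```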
|
||||
@ -1,9 +1,8 @@
import functools
import math
import operator
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
from datetime import timedelta
-from typing import Callable

import torch
from torch._C import ScriptObject
|
||||
@ -10,6 +10,7 @@ from ._context_parallel._attention import (
    _enable_context_parallel_dispatcher,
    _is_causal_behavior,
    _RotateMethod,
+    _templated_ring_attention,
    context_parallel,
    context_parallel_unshard,
    set_rotate_method,
@ -22,6 +23,7 @@ from ._context_parallel._load_balancer import (
)


+# TODO(fegin): add deprecation message once the final interfaces are concluded.
__all__ = [
    "_CausalBehavior",
    "_context_parallel_shard",
@ -31,6 +33,7 @@ __all__ = [
    "_enable_context_parallel_dispatcher",
    "_is_causal_behavior",
    "_RotateMethod",
+    "_templated_ring_attention",
    "context_parallel",
    "context_parallel_unshard",
    "set_rotate_method",
|
||||