Compare commits

...

20 Commits

Author SHA1 Message Date
94f210d947 update triton commit hash 2025-11-06 00:27:41 +00:00
a344069f2a Add missing skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) to test/test_transformers.py (#166969)
This PR adds missing skips for efficient attention tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166969
Approved by: https://github.com/jeffdaily
2025-11-05 23:16:51 +00:00
af829c0dad [ROCm] Skip nvfp4 tests on ROCm (#167066)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167066
Approved by: https://github.com/jeffdaily, https://github.com/slayton58
2025-11-05 23:15:17 +00:00
3869aa115b fix fr reset api (#166970)
Summary:
- There are various places that access FR's `entries_` field.
- If we empty `entries_` on reset, those accesses can result in an error.
- So we perform only a soft delete instead of clearing out the entries completely:
  - only reset `id_` on reset
  - keep track of a `reset_epoch` that increments every time reset is called
  - `dump_entries` only returns entries from the latest epoch
  - APIs that access entries also check that the reset epoch matches
- Make `next_` always track the index in the circular buffer - this change was needed to make the soft delete's implementation easier. A minimal sketch of the epoch-based filtering follows below.
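
A minimal Python sketch of the epoch-based soft delete (illustrative only: `FlightRecorderBuffer` and its members are hypothetical stand-ins for the C++ `FlightRecorder` changed in this PR):

```python
# Illustrative sketch only; FlightRecorderBuffer is a hypothetical stand-in
# for the C++ FlightRecorder in the diffs below.
class FlightRecorderBuffer:
    def __init__(self, max_entries):
        self.max_entries = max_entries
        self.entries = []              # circular buffer of (epoch, id, payload)
        self.next = 0                  # always the next index in the circular buffer
        self.id = 0                    # per-epoch record id
        self.reset_epoch = 0
        self.epoch_start_idx = {0: 0}  # reset_epoch -> index where that epoch starts

    def record(self, payload):
        entry = (self.reset_epoch, self.id, payload)
        if len(self.entries) < self.max_entries:
            self.entries.append(entry)
        else:
            self.entries[self.next] = entry
        self.next = (self.next + 1) % self.max_entries
        self.id += 1

    def reset(self):
        # Soft delete: bump the epoch instead of clearing the entries.
        self.reset_epoch += 1
        self.epoch_start_idx[self.reset_epoch] = self.next
        self.id = 0

    def dump_entries(self):
        # Chronological order (next..end, then 0..next), keeping only
        # entries from the current epoch.
        ordered = self.entries[self.next:] + self.entries[:self.next]
        return [e for e in ordered if e[0] == self.reset_epoch]
```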

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/166970).
* #166972
* #166971
* __->__ #166970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166970
Approved by: https://github.com/fduwjj
2025-11-05 23:06:00 +00:00
47eb34b7ac [ATEN][CUDA] Reduce register pressure in radix_sort_pairs to improve torch.sort performance (#167094)
# Summary
This PR improves `torch.sort` and `torch.unique` performance by **15% to 50%** on NVIDIA GPUs by optimizing CUDA register allocation in radix sort operations.

The key change: specialize `OpaqueType<N>` to use native integer types (`uint8_t`, `uint16_t`, `uint32_t`, `uint64_t`) for common sizes (1, 2, 4, 8 bytes) instead of `char data[N]`. This enables more efficient register allocation while preserving the template deduplication strategy.

The following table shows the speedup on various input shapes and GPUs. Sorting is performed on the last dimension, and the baseline torch version is 2.9.0.

| GPU  | input shape | input dtype | **Before** **(ms)** | After (ms) | Speedup |
| ---- | ----------- | ----------- | ------------------- | ---------- | ------- |
| H100 | (16, 1e6)   | int32       | 1.61                | 1.37       | 1.18×   |
| H100 | (1, 1e8)    | int32       | 6.6                 | 5.0        | 1.3×    |
| H20  | (16, 1e6)   | int64       | 3.57                | 3.03       | 1.18×   |
| H20  | (1, 1e8)    | int64       | 19.3                | 13.0       | 1.48×   |

# Analysis

`torch.sort` and `torch.unique` use `radix_sort_pairs`, which internally calls `cub::DeviceRadixSort::SortPairs`. Since values are only copied (never compared), we cast them to `OpaqueType<sizeof(value_t)>` to minimize template instantiations. For example, both `int32` and `float32` values map to the same `OpaqueType<4>`.

## The Problem

The previous `char data[N]` implementation causes inefficient register allocation. Here is one cause I found in the SASS code. For 8-byte types:

- `char data[8]`: the compiler may allocate 8 registers (one per byte)

- `uint64_t data`: the compiler allocates 2 registers (standard 64-bit handling)

This happens because the compiler doesn't recognize `char[8]` as a cohesive 64-bit value and treats each byte independently, which increases register pressure and reduces GPU occupancy.

In Nsight Compute, the `char data[8]` version uses 166 registers per thread, for a theoretical occupancy of 18.75%; the native `uint64_t` version uses 80 registers per thread, for a theoretical occupancy of 37.5%.
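
As a back-of-envelope check of those occupancy figures (a sketch assuming an H100-class SM with 65,536 32-bit registers, a 64-warp limit, and a warp-allocation granularity of 4; exact limits vary by GPU):

```python
# Hedged back-of-envelope occupancy check; the SM limits below are
# assumptions for an H100-class GPU, not values taken from this PR.
REGS_PER_SM, MAX_WARPS, WARP_GRAN = 65536, 64, 4

def theoretical_occupancy(regs_per_thread):
    warps = REGS_PER_SM // (regs_per_thread * 32)  # 32 threads per warp
    warps -= warps % WARP_GRAN                     # round down to granularity
    return warps / MAX_WARPS

print(theoretical_occupancy(166))  # 0.1875 -> 18.75%
print(theoretical_occupancy(80))   # 0.375  -> 37.5%
```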

## The Solution

Specialize `OpaqueType<N>` for common sizes using native integer types:

```
// Before
template <int N> struct alignas(N) OpaqueType { char data[N]; };

// After
template <int N> struct alignas(N) OpaqueType { char data[N]; }; // fallback
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
```

This preserves the template deduplication strategy (all 8-byte types still use the same `OpaqueType<8>` instantiation) while enabling better register allocation.

# Testing & Compatibility
## Testing
- Correctness tests pass for various input types (bfloat16, int32, float32, int64), shapes, and dimensions (1, 2, 3)
- Register usage reduction verified with Nsight Compute
- Linter passes
## Compatibility
- No API/ABI changes
- Template instantiation count unchanged

# Reference
For a detailed analysis, please refer to my earlier blog post: [Performance Optimization of torch.sort on GPU](https://yywangcs.notion.site/Performance-Optimization-of-torch-sort-on-GPU-192fc9f5d8058018a1bec1efa35da3f9)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167094
Approved by: https://github.com/ngimel, https://github.com/Skylion007
2025-11-05 22:34:19 +00:00
08200280ce [CP][BE][3/N] Add _templated_ring_attention to the backward compatibility stub (#166991)
While `_templated_ring_attention` is a private API, it is unfortunately used by some packages.
Add it to `__all__` so that people can still use it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166991
Approved by: https://github.com/XilunWu
ghstack dependencies: #166456, #166501
2025-11-05 22:22:55 +00:00
ad7a57262c [12/N] Apply ruff UP035 rule (#166929)
This PR continues applying the ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929
Approved by: https://github.com/Lucaskabela
2025-11-05 22:06:19 +00:00
711a775878 fix nccl estimations (#167093)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167093
Approved by: https://github.com/kwen2501, https://github.com/eellison
2025-11-05 22:01:49 +00:00
e9a688f02e [DebugMode] output, tensor id annotations for DebugMode (#165076)
Adds an optional "node" id for tensors and output info annotations to DebugMode, enabled with `DebugMode(record_output=True, record_ids=True)`.
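
A minimal usage sketch (assuming the private module path `torch.utils._debug_mode` and the `debug_string()` accessor; both may change, since DebugMode is not a public API):

```python
# Sketch; the import path and accessor below are assumptions about a
# private module and may differ between releases.
import torch
from torch.utils._debug_mode import DebugMode

with DebugMode(record_output=True, record_ids=True) as debug_mode:
    x = torch.randn(8, 8)
    y = torch.randn(8, 32)
    torch.mm(x, y).sum()

# Prints calls annotated with tensor ids (t$N) and their outputs,
# like the trace shown below.
print(debug_mode.debug_string())
```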

Example output for `test_debug_mode_mm`, with both enabled:
```
  torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))  ->  dt$12: f32[8, 32]| S(0)
    aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0))
      redistribute_input(1, S(0) -> R)
        redistribute_input(t$4: f32[1, 32], trace: S(0)->R)
          _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0)  ->  t$6: f32[8, 32]
          _c10d_functional::wait_tensor(t$7: f32[8, 32])  ->  t$8: f32[8, 32]
      aten::mm(t$9: f32[1, 8], t$10: f32[8, 32])  ->  t$11: f32[1, 32]
  <method 'sum' of 'torch._C.TensorBase' objects>(dt$13: f32[8, 32]| S(0))  ->  dt$17: f32[]| P
    aten::sum(dt$14: f32[8, 32]| S(0))
      aten::sum(t$15: f32[1, 32])  ->  t$16: f32[]
```

Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076
Approved by: https://github.com/zpcore
2025-11-05 22:00:11 +00:00
e69aaaf45a [user-streams] Add backward test (#167021)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167021
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167019
2025-11-05 21:24:44 +00:00
fd8f368d31 [user-streams] Add graph annotation checks (#167019)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167019
Approved by: https://github.com/Lucaskabela
2025-11-05 21:24:44 +00:00
13d2cc7bd2 Remove python workaround for ContextDecorator (#167049)
This PR removes the import workaround for ContextDecorator because the import always succeeds on Python 3.10+.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049
Approved by: https://github.com/Skylion007
2025-11-05 20:56:04 +00:00
c6c913d18e Add torch::stable::Tensor sizes and strides (#165153)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165153
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991, #165152
2025-11-05 20:55:34 +00:00
ef3f953966 Revert "[DebugMode] output, tensor id annotations for DebugMode (#165076)"
This reverts commit a64c7d740428010d700b4bcd395af8a7b2d5c21f.

Reverted https://github.com/pytorch/pytorch/pull/165076 on behalf of https://github.com/wdvr due to Sorry but this is breaking internally. See diff [D86245252](https://www.internalfb.com/diff/D86245252) for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/165076#issuecomment-3493358159))
2025-11-05 20:52:43 +00:00
ea44f12bce [13/N] Apply ruff UP035 rule (#167048)
This PR continues applying the ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048
Approved by: https://github.com/Skylion007
2025-11-05 20:51:53 +00:00
a74fe75c45 Don't hardcode double argument for reduction base (#166951)
Fixes https://github.com/pytorch/pytorch/issues/43254

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166951
Approved by: https://github.com/ngimel, https://github.com/Skylion007
ghstack dependencies: #166813
2025-11-05 20:34:15 +00:00
6d30666bc1 Revert "[12/N] Apply ruff UP035 rule (#166929)"
This reverts commit 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52.

Reverted https://github.com/pytorch/pytorch/pull/166929 on behalf of https://github.com/donigian due to Temporarily need to revert this to continue a revert for #165076. @cyyever Please re-merge after revert of #165076. ([comment](https://github.com/pytorch/pytorch/pull/166929#issuecomment-3493090596))
2025-11-05 20:02:47 +00:00
8e8cbb85ee Revert "[Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)"
This reverts commit 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23.

Reverted https://github.com/pytorch/pytorch/pull/166890 on behalf of https://github.com/malfet due to Looks like it broke torchfuzz tests, see fbd70fb84e/1 and same test on slow ([comment](https://github.com/pytorch/pytorch/pull/166890#issuecomment-3493011038))
2025-11-05 19:42:39 +00:00
fbd70fb84e Update typing docs to reference pyrefly (#166883)
Replaces the mypy documentation in the CONTRIBUTING.md file with pyrefly references. I have made initial changes to the https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch documentation and will replace the script at the bottom with one tailored to the pyrefly tool as a follow-up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166883
Approved by: https://github.com/malfet
2025-11-05 19:35:38 +00:00
6c5db82584 [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, giving a 4x improvement for some kernels in an internal workload. More improvements can surely be made here in the future. This change removes num_warps from the kernel definition to enable autotune support in the generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

The num_warps=8 default comes from https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374
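
A condensed sketch of the resulting heuristic (mirroring the `triton_heuristics.foreach` change in the diff below; `build_foreach_configs` is an illustrative name, not the actual function):

```python
import triton

def build_foreach_configs(inductor_meta):
    # Without max-autotune, keep the previous fixed default of num_warps=8;
    # otherwise sweep a small set of candidates and let autotuning pick.
    if not (
        inductor_meta.get("max_autotune")
        or inductor_meta.get("max_autotune_pointwise")
    ):
        return [triton.Config({}, num_stages=1, num_warps=8)]
    return [triton.Config({}, num_stages=1, num_warps=w) for w in (1, 2, 4, 8)]
```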

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2025-11-05 19:27:23 +00:00
36 changed files with 786 additions and 179 deletions

View File

@ -1 +1 @@
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
40eb62cb371b4c2b350c0d735dd65d4f905ee0fe

View File

@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `mypy`](#running-mypy)
- [Running `pyrefly`](#running-pyrefly)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `mypy` - recommended for linting
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,15 +350,32 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `mypy`
#### Running `pyrefly`
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for more information on how to set up `mypy` and tackle type annotation
tasks.
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
### C++ Unit Testing

View File

@ -24,7 +24,13 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
template<typename key_t, int value_size>
void radix_sort_pairs_impl(

View File

@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
});
}
template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
template <typename func_t, typename vec_func_t, typename ident_t = double>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
using traits = binary_function_traits<func_t>;
static_assert(
all_same<

View File

@ -339,33 +339,13 @@ void or_kernel_impl(TensorIterator& iter) {
}
}
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};
void min_values_kernel_impl(TensorIterator& iter) {
// This case is special because of Vectorized<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
}), kLong, kUInt64);
return;
}
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
static_cast<double>(upper_bound<scalar_t>()));
upper_bound<scalar_t>());
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}

View File

@ -47,20 +47,10 @@ Tensor sgd_out_of_place(
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
// testing Tensor strides + stride
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
aoti_torch_get_device_type(param.get(), &param_device_type);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);
auto out = new_empty(param, param.sizes());
sgd_math(
reinterpret_cast<float*>(param.data_ptr()),
@ -344,6 +334,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
// This is to test that passing in a std::vector works for BC. (It gets
// implicitly converted to HeaderOnlyArrayRef too!)
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);

View File

@ -5789,6 +5789,229 @@ class NCCLTraceTest(NCCLTraceTestBase):
else:
self.assertTrue("duration_ms" not in t["entries"][0])
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
"""
Test that when the circular buffer in entries_ is full and we call reset,
then fill the buffer with new entries, dump_entries returns only the new
entries and not the old ones.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely with 10 entries
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify buffer is full with 10 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Now reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add new entries after reset - fill the buffer completely again
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we get exactly 10 new entries, not 20
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 10)
# Verify all entries have the expected properties (from after reset)
# After reset, record IDs should start from 0 again
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
self.assertIn("record_id", entry)
# Record IDs should be sequential starting from 0 after reset
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
"""
Test that when the circular buffer is full, we reset, and then add fewer
entries than the buffer size, we only get the new entries.
This tests that old entries at the end of the circular buffer are properly
filtered out based on reset_epoch.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill the buffer completely
for _ in range(10):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset the flight recorder
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Add only 3 new entries (much less than buffer size)
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Verify we only get the 3 new entries, not 10
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 3)
# Verify record IDs start from 0 after reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_reset_wraparound(self, timing_enabled):
"""
Test that when we reset in the middle of the circular buffer and then
wrap around, dump_entries correctly returns only entries from the current
epoch in the correct order.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# Fill half the buffer
for _ in range(5):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Reset at this point (reset happens at index 5)
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Now add 8 entries, which will wrap around
# (5->9 fills rest of buffer, then 0->2 wraps around)
for _ in range(8):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should get exactly 8 entries, properly ordered
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 8)
# Entries should be in chronological order
# The dump_entries() method returns entries from next_ to end, then 0 to next_
# After filtering old entries, we should have 8 entries in order
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
for i, entry in enumerate(t["entries"]):
self.assertIn("profiling_name", entry)
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_fr_record_multiple_resets(self, timing_enabled):
"""
Test multiple consecutive resets to ensure each reset properly increments
the epoch and filters out entries from previous epochs.
"""
if self.rank == self.MAIN_PROCESS_RANK:
return
# Override buffer size to 10 for faster testing
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
pg = self._create_process_group_nccl()
if timing_enabled:
pg._enable_collectives_timing()
device = self.local_device
self.set_thread_name("fr_test_thread")
a = torch.full((3, 4), float(self.rank), device=device)
# First batch: 2 entries
for _ in range(2):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# First reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Second batch: 3 entries
for _ in range(3):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Second reset
torch._C._distributed_c10d._reset_fr_recording_nccl()
# Third batch: 4 entries
for _ in range(4):
f = pg.allreduce(a)
f.wait()
torch.cuda.synchronize(device=device)
time.sleep(1)
# Should only see the last 4 entries
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
self.assertEqual(len(t["entries"]), 4)
# Verify record IDs start from 0 after the last reset
for i, entry in enumerate(t["entries"]):
self.assertIn("record_id", entry)
self.assertEqual(entry["record_id"], i)
dist.destroy_process_group()
def check_if_test_is_skipped(fn):
def wrapper(self, *args, **kwargs):

View File

@ -8,21 +8,11 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
@ -40,7 +30,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)
return extract_graph(fn, *args, **kwargs)[0:3]
def run_and_get_simple_graph(self):
def fn(x, y):

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union
from collections.abc import Callable, Sequence
from typing import Any, Union
import torch
import torch._dynamo

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, NamedTuple, Optional
from typing import NamedTuple, Optional, TYPE_CHECKING
import torch
import torch._dynamo
@ -7,6 +7,10 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -1,11 +1,13 @@
# Owner(s): ["module: dynamo"]
import functools
import re
import unittest
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
@ -15,6 +17,14 @@ requires_multigpu = functools.partial(
)
def remove_file_comment(gm_str: str) -> str:
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
def print_graph(graph: torch.fx.GraphModule) -> str:
return remove_file_comment(graph.print_readable())
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -36,9 +46,7 @@ class TestStreams(torch._dynamo.test_case.TestCase):
@requires_cuda
def test_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
def fn(x, y, s1, s2):
with s1:
z1 = torch.add(x, y)
with s2:
@ -47,13 +55,36 @@ class TestStreams(torch._dynamo.test_case.TestCase):
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
@requires_cuda
@unittest.skip("Needs graph break support with annotation context")
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
@ -70,9 +101,16 @@ class TestStreams(torch._dynamo.test_case.TestCase):
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(expected, actual)
self.assertEqual(len(fw_graphs), 2)
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
@requires_cuda
def test_stream_input(self):
@ -155,22 +193,248 @@ class TestStreams(torch._dynamo.test_case.TestCase):
self.assertEqual(s_act, s_exp)
def test_nested_stream_enter_exit(self):
pass
def fn(x, y, s0, s1, s2):
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (
torch.ones(2, 2) + 1,
torch.ones(2, 2),
torch.Stream(),
torch.Stream(),
torch.Stream(),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
@unittest.skip("Needs graph break support with annotation context")
def test_stream_enter_exit_graph_break(self):
pass
@unittest.skip("Needs graph break support with annotation context")
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
def test_local_stream_nested_enter_exit(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
def test_stream_with_mutation(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
x.add_(y)
with s0:
z1 = torch.add(y, y)
z0 = torch.add(z1, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
# Annotation: {'stream': 2}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
#
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (add_2, add_3)
""",
)
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, requires_grad=True) + 1,
torch.ones(2, 2, requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
#
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
@requires_cuda
def test_run_opcheck(self):

View File

@ -14424,20 +14424,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))
@skip_if_halide
@requires_cuda_and_triton
def test_unbacked_float_item(self):
def fn(x, max_val):
return torch.clamp(x, 0, max_val.item())
self.common(
fn,
(
torch.randn(10, 20, 30, device=self.device),
torch.tensor(5.0, device=self.device),
),
)
# end of class CommonTemplate - add new tests here

View File

@ -1864,6 +1864,8 @@ class TestFP8Matmul(TestCase):
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
@parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"])
def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None:
if torch.version.hip and recipe == "nvfp4":
raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping")
if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum:
raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping")

View File

@ -1914,6 +1914,7 @@ class TestSDPAFailureModes(NNTestCase):
q, k, v, None, 0.0, is_causal=True))
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536(self):
batch_size = 2**16
query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True)
@ -1935,6 +1936,7 @@ class TestSDPAFailureModes(NNTestCase):
self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self):
query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16)
@ -1948,6 +1950,7 @@ class TestSDPAFailureModes(NNTestCase):
@largeTensorTest("15GB", "cuda")
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention")
def test_mem_eff_attention_large_seq_len_uniform_attention(self):
device = torch.device("cuda")
dtype = torch.bfloat16

View File

@ -1,5 +1,5 @@
from typing import Union
from typing_extensions import assert_type, TypeAlias
from typing import TypeAlias, Union
from typing_extensions import assert_type
from torch import randn, Tensor

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Callable, Optional, overload, Union
from typing import Any, Optional, overload, Union
import torch
from torch import Tensor

View File

@ -87,6 +87,12 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
return gm.graph, region_tracker # type: ignore[union-attr]
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:

View File

@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence
from collections.abc import Callable, Sequence, Sized
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard

View File

@ -2970,12 +2970,6 @@ class CppPythonBindingsCodeCache(CppCodeCache):
throw std::runtime_error("expected int arg");
return reinterpret_cast<uintptr_t>(result);
}}
template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
if(unlikely(result == -1.0 && PyErr_Occurred()))
throw std::runtime_error("expected float arg");
return static_cast<float>(result);
}}
{extra_parse_arg}

View File

@ -1732,15 +1732,9 @@ class KernelArgs:
call_args.append(self.wrap_ptr_arg(outer, dtype))
arg_types.append(f"{cpp_dtype}*")
for outer, inner in self.sizevars.items():
if isinstance(outer, sympy.Symbol) and symbol_is_type(
outer, (SymT.UNBACKED_FLOAT)
):
arg_defs.append(f"const float {inner}")
arg_types.append("const float")
else:
arg_defs.append(f"const {INDEX_TYPE} {inner}")
arg_types.append(f"const {INDEX_TYPE}")
arg_defs.append(f"const {INDEX_TYPE} {inner}")
call_args.append(self.wrap_size_arg(outer))
arg_types.append(f"const {INDEX_TYPE}")
if V.graph.wrapper_code:
V.graph.wrapper_code.ensure_size_computed(outer)
assert not self.workspace_args, "Workspace not supported on CPU "
@ -2359,7 +2353,6 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
SymT.UNBACKED_INT,
SymT.SIZE,
SymT.PRECOMPUTED_SIZE,
SymT.UNBACKED_FLOAT,
),
)
}

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import hashlib
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import sympy # noqa: TC002
@ -17,6 +17,8 @@ from .simd import SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ..ir import IRNode
from ..scheduler import BaseSchedulerNode

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
num_warps={self.num_warps},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -4,7 +4,6 @@ from typing import Any, Optional
import sympy
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import config
from ..runtime.hints import AttrsDescriptorWrapper
@ -72,10 +71,6 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
return "constexpr"
elif isinstance(arg.expr, (float, sympy.Float)):
return "fp32"
elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
arg.expr, (SymT.UNBACKED_FLOAT)
):
return "fp32"
elif isinstance(arg.expr, bool):
return "i1"

View File

@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
fx_node: torch.fx.Node,
override_size: Optional[int] = None,
# TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix.
use_nccl_estimator: bool = False,
use_nccl_estimator: bool = True,
) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable
import torch
from torch._environment import is_fbcode

View File

@ -3586,13 +3586,24 @@ def user_autotune(
)
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
def foreach(triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -52,26 +52,7 @@ __all__ = [
"MemRecordsAcc",
]
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
import functools
class _ContextDecorator: # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapped
from contextlib import ContextDecorator
# global python state - whether profiler is currently enabled
@ -744,8 +725,7 @@ class profile:
return all_function_events
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
class record_function(ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.

View File

@ -108,12 +108,14 @@ struct FlightRecorder {
capture_cpp_stack_ = getCvarBool(
{"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false);
enabled_ = max_entries_ > 0;
reset_epoch_start_idx_[0] = 0;
}
struct Entry {
size_t id_; // incremented id in the trace buffer
// used to figure out where in the circular entries
// buffer this entry will be located to
// update state information
size_t reset_epoch_; // epoch when this entry was created
size_t pg_id_;
std::tuple<std::string, std::string> pg_name_; // <group_name, group_desc>
@ -183,11 +185,34 @@ struct FlightRecorder {
size_t max_entries_ = 0;
size_t next_ = 0;
size_t id_ = 0;
size_t reset_epoch_ = 0;
std::unordered_map<size_t, size_t>
reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_;
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
pg_name_to_ranks_;
std::string comm_lib_version_;
struct TraceIdentifier {
std::optional<size_t> id;
std::optional<size_t> reset_epoch;
};
TraceIdentifier recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P);
std::optional<size_t> record(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
@ -213,8 +238,16 @@ struct FlightRecorder {
std::vector<Entry> dump_entries();
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t getIdxFromId(size_t id, size_t reset_epoch) const;
// Returns the entry with the given id and reset_epoch, if it exists.
// Otherwise, returns std::nullopt.
TORCH_API std::optional<Entry> getEntry(
std::optional<size_t> id,
std::optional<size_t> reset_epoch);
TORCH_API std::optional<Entry> getEntry(std::optional<size_t> id);
/*
@ -227,6 +260,11 @@ struct FlightRecorder {
never hang. (timing must also be enabled for compute_duration - see
TORCH_NCCL_ENABLE_TIMING).
*/
TORCH_API void retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration = true);
TORCH_API void retire_id(
std::optional<size_t> id,
bool compute_duration = true);

View File

@ -53,8 +53,41 @@ std::optional<size_t> FlightRecorder<EventType>::record(
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
auto result = recordWithResetEnabled(
pg_id,
pg_name,
collective_seq_id,
p2p_seq_id,
op_id,
std::move(profiling_name),
inputs,
outputs,
start,
end,
timeout_ms,
std::move(pg_status),
isP2P);
return result.id;
}
template <typename EventType>
typename FlightRecorder<EventType>::TraceIdentifier FlightRecorder<EventType>::
recordWithResetEnabled(
size_t pg_id,
const std::tuple<std::string, std::string>& pg_name,
size_t collective_seq_id,
size_t p2p_seq_id,
size_t op_id,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventType* start,
EventType* end,
std::chrono::milliseconds timeout_ms,
std::shared_ptr<ProcessGroupStatus> pg_status,
bool isP2P) {
if (!enabled_) {
return std::nullopt;
return TraceIdentifier{std::nullopt, std::nullopt};
}
if (all_pg_status_.find(pg_id) == all_pg_status_.end()) {
// Current pg_status is not in FR.
@ -64,8 +97,13 @@ std::optional<size_t> FlightRecorder<EventType>::record(
torch::CapturedTraceback::gather(true, true, capture_cpp_stack_);
std::lock_guard<std::mutex> guard(mutex_);
TORCH_CHECK(
reset_epoch_start_idx_.find(reset_epoch_) !=
reset_epoch_start_idx_.end());
auto te = Entry{
id_,
reset_epoch_,
pg_id,
pg_name,
collective_seq_id,
@ -104,15 +142,20 @@ std::optional<size_t> FlightRecorder<EventType>::record(
te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end());
}
const auto next = next_++;
if (entries_.size() < max_entries_) {
entries_.emplace_back(std::move(te));
} else {
entries_[next_++] = std::move(te);
if (next_ == max_entries_) {
next_ = 0;
}
entries_[next] = std::move(te);
}
return id_++;
if (next_ == max_entries_) {
next_ = 0;
}
const auto id = id_++;
return TraceIdentifier{id, reset_epoch_};
}
template <typename EventType>
@ -163,15 +206,20 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
std::vector<Entry> result;
{
std::lock_guard<std::mutex> guard(mutex_);
result.reserve(entries_.size());
result.insert(
result.end(),
// Filter entries during insertion - only keep entries from current epoch
auto filter = [this](const Entry& e) {
return e.reset_epoch_ == reset_epoch_;
};
std::copy_if(
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
entries_.end());
result.insert(
result.end(),
entries_.end(),
std::back_inserter(result),
filter);
std::copy_if(
entries_.begin(),
entries_.begin() + static_cast<std::ptrdiff_t>(next_));
entries_.begin() + static_cast<std::ptrdiff_t>(next_),
std::back_inserter(result),
filter);
}
// query any remaining events
for (auto& r : result) {
@ -182,28 +230,47 @@ std::vector<typename FlightRecorder<EventType>::Entry> FlightRecorder<
}
template <typename EventType>
// Returns the entry with the given id, if it exists. Otherwise, returns
// std::nullopt.
// Returns the index in entries_ for the given id and reset_epoch.
// Caller must hold mutex_lock before calling this method.
size_t FlightRecorder<EventType>::getIdxFromId(size_t id, size_t reset_epoch)
const {
// Look up the starting idx for the given reset epoch
auto it = reset_epoch_start_idx_.find(reset_epoch);
TORCH_CHECK(it != reset_epoch_start_idx_.end());
// Calculate idx based on where the epoch started
return (it->second + id) % max_entries_;
}
template <typename EventType>
// Returns the entry with the given id and reset_epoch, if it exists. Otherwise,
// returns std::nullopt.
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::getEntry(std::optional<size_t> id) {
if (!enabled_ || !id) {
EventType>::
getEntry(std::optional<size_t> id, std::optional<size_t> reset_epoch) {
if (!enabled_ || !id || !reset_epoch) {
return std::nullopt;
}
std::unique_lock<std::mutex> guard(mutex_);
Entry entry = entries_.at(*id % max_entries_);
if (entry.id_ == *id) {
Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) {
return entry;
} else {
return std::nullopt;
}
return std::nullopt;
}
template <typename EventType>
std::optional<typename FlightRecorder<EventType>::Entry> FlightRecorder<
EventType>::getEntry(std::optional<size_t> id) {
return getEntry(id, 0);
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
std::optional<size_t> reset_epoch,
bool compute_duration) {
if (!enabled_ || !id) {
if (!enabled_ || !id || !reset_epoch) {
return;
}
@ -214,8 +281,8 @@ void FlightRecorder<EventType>::retire_id(
std::unique_lock<std::mutex> guard(mutex_);
Entry* entry = &entries_.at(*id % max_entries_);
if (entry->id_ == *id) {
Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) {
update_state(*entry);
if (compute_duration) {
@ -237,8 +304,8 @@ void FlightRecorder<EventType>::retire_id(
guard.lock();
// Refresh the entry pointer, see if the entry has been overwritten
entry = &entries_.at(*id % max_entries_);
if (entry->id_ != *id) {
entry = &entries_.at(getIdxFromId(*id, *reset_epoch));
if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) {
LOG(INFO) << "retire_id abandoned for id " << *id
<< ", event was overwritten while waiting to compute duration.";
return;
@ -249,12 +316,23 @@ void FlightRecorder<EventType>::retire_id(
}
}
template <typename EventType>
void FlightRecorder<EventType>::retire_id(
std::optional<size_t> id,
bool compute_duration) {
retire_id(id, 0, compute_duration);
}
template <typename EventType>
void FlightRecorder<EventType>::reset_all() {
std::lock_guard<std::mutex> guard(mutex_);
next_ = 0;
id_ = 0;
entries_.clear();
if (!entries_.empty()) {
// Soft delete: increment epoch to mark all existing entries as old
// Store where the new epoch starts in the circular buffer
reset_epoch_++;
reset_epoch_start_idx_[reset_epoch_] = next_;
id_ = 0;
}
}
template <typename EventType>

View File

@ -708,7 +708,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) {
// TODO: We need to have numel of tensors for gloo as well.
pgStatus_->lastCompletedNumelIn = 0;
pgStatus_->lastCompletedNumelOut = 0;
FlightRecorder<c10::Event>::get()->retire_id(work->trace_id_, false);
FlightRecorder<c10::Event>::get()->retire_id(
work->trace_id_, work->trace_reset_epoch_, false);
lock.lock();
workInProgress_[workerIndex].reset();
}
@ -780,7 +781,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
pgStatus_->lastEnqueuedNumelOut = 0;
// using c10d::FlightRecorder;
// TODO: We need to have a way to use c10::Event inside gloo as well.
work->trace_id_ = FlightRecorder<c10::Event>::get()->record(
auto traceId = FlightRecorder<c10::Event>::get()->recordWithResetEnabled(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
collectiveCounter_,
@ -795,6 +796,8 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) {
work->getTimeout(),
pgStatus_,
false);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
workQueue_.push_back(std::move(work));
lock.unlock();

View File

@ -99,6 +99,7 @@ class TORCH_API ProcessGroupGloo : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
std::shared_ptr<gloo::Context> context_;
const std::chrono::milliseconds timeout_;

View File

@ -575,6 +575,7 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
futureWorkResult_(w.futureWorkResult_),
timingEnabled_(w.timingEnabled_),
trace_id_(w.trace_id_),
trace_reset_epoch_(w.trace_reset_epoch_),
distDebugLevel_(w.distDebugLevel_) {
exception_ = w.exception_;
}
@ -704,9 +705,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout(
// Print the traceback of the collective at call time
std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const {
// First step we get the corresponding record entry from FR, based on work's
// trace_id_
// trace_id_ and trace_reset_epoch_
std::optional<FlightRecorderCUDA::Entry> entry =
FlightRecorderCUDA::get()->getEntry(trace_id_);
FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_);
if (entry.has_value()) {
auto entryVal = entry.value();
// Get stack trace from FR entry, in string format
@ -2394,7 +2395,8 @@ void ProcessGroupNCCL::Watchdog::runLoop() {
pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_);
pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_;
pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_;
FlightRecorderCUDA::get()->retire_id(work.trace_id_, true);
FlightRecorderCUDA::get()->retire_id(
work.trace_id_, work.trace_reset_epoch_, true);
if (pg_->onCompletionHook_) {
// Move Work object to completedWorkList_ to be consumed by the hook
// thread
@ -3360,7 +3362,7 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
// these objects to the Work because it has implications for keeping those
// tensors alive longer and adds overhead when copying Work objects
// between threads
r->trace_id_ = FlightRecorderCUDA::get()->record(
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -3374,6 +3376,8 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
options_->timeout,
pgStatus_,
isP2P);
r->trace_id_ = traceId.id;
r->trace_reset_epoch_ = traceId.reset_epoch;
}
return r;
}
@ -3593,6 +3597,7 @@ float ProcessGroupNCCL::endTimeEstimate() {
#ifdef NCCL_SIM_INFO_INITIALIZER
ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
--ncclActiveGroupCounter_;
return simInfo.estimatedTime;
#else
TORCH_CHECK(
@ -3676,7 +3681,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
// later in endCoalescing we record a 'coalesced' Work which has
// timing/state updates via watchdog thread, but lacks op metadata such as
// input/output sizes and profilingTitle per-op in the group.
FlightRecorderCUDA::get()->record(
FlightRecorderCUDA::get()->recordWithResetEnabled(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4168,7 +4173,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
// TODO(whc) because we don't pass output {tensor} to initWork, we tell
// initWork to not record, and then we manually call record passing all the
// information it wants.
work->trace_id_ = FlightRecorderCUDA::get()->record(
auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled(
local_id_,
std::make_tuple(pg_uid_, pg_desc_),
seqCollective_,
@ -4182,6 +4187,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
options_->timeout,
pgStatus_,
/*isP2P=*/true);
work->trace_id_ = traceId.id;
work->trace_reset_epoch_ = traceId.reset_epoch;
}
// Only check for NaN for send ops, for recv ops `tensor` can be a random

View File

@ -505,6 +505,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
std::optional<uint64_t> trace_reset_epoch_;
DebugLevel distDebugLevel_;
friend class ProcessGroupNCCL;
};

View File

@ -4,6 +4,7 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -13,6 +14,7 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)
using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -93,6 +95,32 @@ class Tensor {
return numel;
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
// a Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef sizes() const {
int64_t* sizes;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
return IntHeaderOnlyArrayRef(sizes, dim());
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the strides of a
// Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef strides() const {
int64_t* strides;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
return IntHeaderOnlyArrayRef(strides, dim());
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.

View File

@ -1,9 +1,8 @@
import functools
import math
import operator
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from datetime import timedelta
from typing import Callable
import torch
from torch._C import ScriptObject

View File

@ -10,6 +10,7 @@ from ._context_parallel._attention import (
_enable_context_parallel_dispatcher,
_is_causal_behavior,
_RotateMethod,
_templated_ring_attention,
context_parallel,
context_parallel_unshard,
set_rotate_method,
@ -22,6 +23,7 @@ from ._context_parallel._load_balancer import (
)
# TODO(fegin): add deprecation message once the final interfaces are concluded.
__all__ = [
"_CausalBehavior",
"_context_parallel_shard",
@ -31,6 +33,7 @@ __all__ = [
"_enable_context_parallel_dispatcher",
"_is_causal_behavior",
"_RotateMethod",
"_templated_ring_attention",
"context_parallel",
"context_parallel_unshard",
"set_rotate_method",