Commit Graph

27 Commits

Author SHA1 Message Date
d7e275d4b4 [CI][CUDA] Add periodic b200 distributed job (#159323)
1. Run distributed job with B200 runner, periodically.
2. discovered generic distributed test issue that certain unit test hard-coded ranks, calling for require_exact_world_size(world_size) API instead of require_world_size(world_size).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159323
Approved by: https://github.com/eqy

Co-authored-by: Aidyn-A <aidyn.b.aitzhan@gmail.com>
2025-10-16 21:54:04 +00:00
ebd0707578 [SymmMem] Add get_nbi the nonblocking version (#163540)
```Py
@triton.jit
def foo(dest, src):
    nvshmem.get_nbi(dest, src, 100, 0)
    # Some independent computation which overlaps with the get operation
    ...
    # Wait for completion of the get operation
    nvshmem.quiet()
```

Allows us to overlap comm and compute in the same kernel, instead of two kernels + signals.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163540
Approved by: https://github.com/ngimel, https://github.com/fegin
2025-10-01 17:50:24 +00:00
6e6c899347 [Reland][163423] Promote @requires_nvshmem instead of enable_triton (#163549)
#163423 was approved but reverted due to a revert of base.
Relanding without base.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163549
Approved by: https://github.com/wdvr

Co-authored-by: Wouter Devriendt <wouterdevriendt@meta.com>
2025-09-25 23:02:00 +00:00
96275dbf88 [CI] Fix test_triton_wait_until hang (#163886)
I don't know why `nvshmem_barrier_all_kernel`  leads the test to hang. Will investigate.
But since it is an unnecessary call here, I am removing it to unblock other PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163886
Approved by: https://github.com/fegin
2025-09-25 22:22:16 +00:00
3a7db34cf9 Revert "[SymmMem] Promote @requires_nvshmem instead of enable_triton (#163423)"
This reverts commit 5d8a226e23339e7243a2a84afd174f685f145b68.

Reverted https://github.com/pytorch/pytorch/pull/163423 on behalf of https://github.com/wdvr due to temporary reverting to back out #162594 ([comment](https://github.com/pytorch/pytorch/pull/163423#issuecomment-3317011500))
2025-09-22 05:35:41 +00:00
5d8a226e23 [SymmMem] Promote @requires_nvshmem instead of enable_triton (#163423)
### Issue
The previous `enable_triton` UI requires the user-defined Triton kernel have a "nvshmem" in its name.
If users did not do so, the kernel would miss the NVSHMEM init, and silently hit CUDA IMA.

The `@require_nvshmem` decorator eliminates the above name requirement (and the `enable_triton` call).

### Usage:
```
@requires_nvshmem
@triton.jit
def foo(...):
    ...

foo[(1, 1)](...)
```
It also remove the need of passing `extern_lib` to `foo` (handled by the decorator now).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163423
Approved by: https://github.com/ngimel
ghstack dependencies: #163025, #163152, #163194
2025-09-21 10:03:20 +00:00
80f8be9840 [SymmMem] Fix put_signal + wait_until hang (#163194)
The test used a wrong ptr to refer to remote address:
```
            dst_ptr = out_hdl.buffer_ptrs[peer]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            sig_ptr = out_hdl.signal_pad_ptrs[peer]
```
All three indices should be `rank` instead of `peer` because NVSHMEM APIs accept local address as input and perform translation internally. Without correct signal address, the peer would be waiting, thus hang.

Also adjusted the signature of `nvshmem.putmem_signal_block` to accept tensor instead of pointer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163194
Approved by: https://github.com/ngimel
ghstack dependencies: #163025, #163152
2025-09-18 18:18:58 +00:00
57a54a04b6 [SymmMem] Fix NVSHMEM plugin + Triton 3.5 (#163152)
1. The dispatch signatures defined in `core.extern_elementwise` call must match the C signature of the NVSHMEM functions, in particular the dtypes. Otherwise, there would be weird errors, such as IMA or hang. When matched, most of time the NVSHMEM device function will be inlined into the generated PTX. When not matched, it is represented as a function call in the PTX (not sure if it is the function call that goes wrong).

2. When calling the `core.extern` wrappers from the `triton.jit` kernels, the input must be cast to match the signatures defined in 1, e.g. via `nbytes.to(tl.int64)`. Otherwise, Triton will report a key error when searching for such kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163152
Approved by: https://github.com/ngimel
ghstack dependencies: #163025
2025-09-18 00:50:22 +00:00
5a2da090ed [SymmMem] Make sure CUDA runtime is initialized before NVSHMEM init (#161232)
Previously, without calling `torch.empty` before NVSHMEM init, we see error below:
```
src/host/init/init.cu:nvshmemi_check_state_and_init:1117: nvshmem initialization failed, exiting
src/host/util/cs.cpp:21: non-zero status: 16: Device or resource busy, exiting... mutex destroy failed
```
Fixing it by calling a `cudaFree(nullptr)` to make sure CUDA runtime is initialized before NVSHMEM init.
Removing all `torch.empty(1)` calls from tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161232
Approved by: https://github.com/ngimel
ghstack dependencies: #161214
2025-09-02 22:53:28 +00:00
779fc29c04 [C10D] Fix spelling of MultiProcContinuousTest (#160892)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160892
Approved by: https://github.com/fduwjj
2025-08-19 20:17:19 +00:00
3a56237440 [SymmMem] Send tensors with unerased type information to NVSHMEM Triton kernels (#159788)
This PR introduces a small `@triton.jit` wrapper function over our core NVSHMEM extern functions for users to send tensors as inputs to their NVSHMEM Triton kernels (rather than pointers).

The goal is to abstract away tedious details from the developer, like manual byte-size calculations and handling of raw `int64` pointers. This lets developers work directly with typed Triton tensors and element counts, which will also be useful if you want to do for instance some local math on the data.

-----

**TODO:**
This is almost complete. One pending item is tensor-aware implementation of `nvshmem.putmem_signal_block `and `nvshmem.signal_wait_until`

From my investigation, I found the root cause to be that this specific tensor API uses local addresses instead of remote addresses for the peer

```
Pointer-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Remote buffer:  0x2430300c00 (dst) ← Rank 1's memory
    Remote signal:  0x2430301600 (sig) ← Rank 1's signal

  Rank 1 (waiting):
    Local signal:   0x430301600 (waits here)

Tensor-Based Version:

  Rank 0 → Rank 1:
    Local buffer:   0x430300a00  (src)
    Local buffer:   0x430300c00  (dst) ← this is wrong
    Local signal:   0x430300e00  (sig) ← this is wrong

  Rank 1 (waiting):
    Local signal:   0x430300e00 (waits here)

```

Next Steps: Need mechanism to resolve local tensor → remote PE address, equivalent to handle.buffer_ptrs[peer] lookup.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159788
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734, #159755, #159756
2025-08-08 05:20:42 +00:00
bfff2e3592 [SymmMem] Refactor NVSHMEM Reduction API to be more ergonomic with automatic dtype‐based dispatch (#159755)
This change introduces a single, generic Triton‐extern wrapper for NVSHMEM team‐based reductions. We now expose one function, `nvshmem.reduce(team, dest, source, nreduce, operation, dtype_id)`, that covers all supported ops (sum, max, min, prod) and dtypes (int8…int64, uint8…uint64, float16, bfloat16, float32, float64).

It accepts real dtype objects (torch.dtype or tl.dtype) directly in the Triton kernel launch. Internally, we normalize dtype_id (handling tl.dtype, torch.dtype, str, or constexpr) into the canonical NVSHMEM typename and assemble the proper function name, e.g. nvshmem_float_sum_reduce or nvshmem_bfloat16_prod_reduce

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159755
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701, #159734
2025-08-08 05:20:36 +00:00
1c881440f4 [SymmMem] Initialize NVSHMEM module only for kernels that have nvshmem in their name (#159734)
Previously, a global post-compile hook initialized the NVSHMEM module for all Triton kernels, which was inefficient. This change conditionally initializes  `_nvshmemx_cumodule_init(kernel.module)` only for Triton kernels containing "nvshmem" in their name. Also updated the names for all of our nvshmem kernels to align with this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159734
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215, #159701
2025-08-08 05:20:29 +00:00
7c4f7b9340 [SymmMem] Add Triton 3.4 support to NVSHMEM Triton and fix CI tests (make device library discoverable + fix peer calculation bug) (#159701)
This PR introduces support for Triton 3.4 and resolves several CI and test-related issues.

**Triton 3.4 Compatibility**
- The JIT post-compile hook has been updated from the legacy JITFunction.compiled_hook to the new API path at triton.knobs.runtime.jit_post_compile_hook.
- The internal parameter for kernel semantics in extern function definitions has been updated from _semantic to _builder to align with API changes.

**Fix CI Errors**
- The new logic inspects the RPATH of libtorch_nvshmem.so to find the NVSHMEM device library, preventing CI tests from being skipped.
- Added a decorator to run NVSHMEM tests only on H100s (compatible hardware)

**Peer Rank Calculation Fix**
- The peer calculation in test_nvshmem_triton.py was changed from peer = (world_size - 1) - rank to peer = 1 - rank.
Reasoning: The previous logic was only valid for a 2-rank setup. In the 8-rank CI environment, it incorrectly mapped peers (e.g., rank 0 to 7), breaking tests that assume a 0↔1 communication pattern. This was reproduced and validated on an 8-rank dev setup.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159701
Approved by: https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136, #159215
2025-08-08 05:20:22 +00:00
1783d6e966 [SymmMem] Fix flaky wait_until test (#159215)
When playing around with it, I noticed some flakiness in this test across sessions.

After debugging, turns out the heavy sync primitives that I was calling (like `nvshmem_quiet()` or `nvshmem_fence()`) from inside Triton kernels was causing deadlocks. The original test tried to guarantee ordering: `put(data) -> fence/quiet -> put(flag)`. But the GPU thread got stuck in `quiet()` waiting for network confirmation while holding the SM, creating a deadlock.

The fix was realizing `wait_until` already provides all the sync you need. Just do:
- PE A: `nvshmem_wait_until(&ivar, ...)`
- PE B: `nvshmem_put(&ivar_on_PE_A, ...)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159215
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718, #159136
2025-08-08 05:20:16 +00:00
ea7fe0ecf6 [SymmMem] Standardize NVSHMEM Triton wrappers on byte-based APIs + improve code clarity (#159136)
Quick refactor for consistency and clarity.

1. We now standardize all NVSHMEM data-moving collectives (put, get, alltoall, broadcast) to use their byte-based *_mem_block variants. This makes the API behavior more predictable and avoids mixing paradigms.

2. Previously, some functions operated on element counts (nelems), while others expected byte sizes but still used `nelems` as the param name. That inconsistency was easy to miss and could lead to bugs, especially for devs not familiar with the NVSHMEM internals.

To clean this up:
	•	All byte-based APIs now use nbytes or nbytes_per_pe to make the units explicit.
	•	Typed APIs consistently use nelems for element counts.
	•	Docstrings were added or updated to clarify expected units.

Also did some code cleanup — removed unused functions, fixed typos in comments, and did some general housekeeping.

This should make the API more intuitive and reduce friction for developers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159136
Approved by: https://github.com/mandroid6, https://github.com/ngimel
ghstack dependencies: #158515, #158718
2025-08-08 05:20:09 +00:00
b0b229b197 [SymmMem] Use _get_default_group() instead of group.WORLD for group_name access (#158718)
Both approaches functionally return the default process group created by `init_process_group()` but `_get_default_group()` is a dedicated function with [better error handling and type safety](4869f71170/torch/distributed/distributed_c10d.py (L1300-L1310)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158718
Approved by: https://github.com/Skylion007, https://github.com/fduwjj
ghstack dependencies: #158515
2025-08-08 05:20:02 +00:00
b5c937259b [SymmMem] Add NVSHMEM Reduction support (sum, min, max) into Triton (#158515)
Implements sum_reduce, min_reduce, and max_reduce collective operations for NVSHMEM Triton kernels. Enables parallel reduction computations across PE teams for int64 data types.

Tests: `python test/distributed/test_nvshmem_triton.py`

<details>
<summary> Quick debug print for sanity check </summary>

```markdown
============================================================
[Rank 1] Starting min/max reduction test with world_size=2
============================================================
============================================================
[Rank 0] Starting min/max reduction test with world_size=2
============================================================
[Rank 0] Source data for min/max: [10, 20]
[Rank 1] Source data for min/max: [15, 5]
[Rank 1] All values across PEs:
[Rank 0] All values across PEs:
  - Position 0: [10, 15]
  - Position 0: [10, 15]
  - Position 1: [20, 5]
  - Position 1: [20, 5]
[Rank 1] Expected min: [10, 5]
[Rank 0] Expected min: [10, 5]
[Rank 1] Expected max: [15, 20]
[Rank 0] Expected max: [15, 20]
[Rank 0] Executing MIN reduction...
[Rank 1] Executing MIN reduction...
[Rank 0] Executing MAX reduction...
[Rank 1] Executing MAX reduction...
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[Rank 1] Results:
[Rank 0] Results:
[Rank 1] MIN reduction result: [10, 5]
[Rank 1] MAX reduction result: [15, 20]
[Rank 0] MIN reduction result: [10, 5]
[Rank 0] MAX reduction result: [15, 20]
[Rank 1] ============================================================
[Rank 1] Min/Max reduction test PASSED ✓
[Rank 1] ============================================================
[Rank 0] ============================================================
[Rank 0] Min/Max reduction test PASSED ✓
[Rank 0] ============================================================
......
============================================================
============================================================
[Rank 0] Starting sum reduction test with world_size=2
[Rank 1] Starting sum reduction test with world_size=2
============================================================
============================================================
[Rank 0] Configuration:
[Rank 1] Configuration:
  - nreduce: 3 (number of separate reductions)
  - nreduce: 3 (number of separate reductions)
  - dtype: torch.int64
  - dtype: torch.int64
[Rank 1] Source data: [2, 4, 6]
[Rank 1] Contribution explanation:
[Rank 0] Source data: [1, 2, 3]
[Rank 0] Contribution explanation:
  - Element 0: 2 = (rank=1+1) * (index=0+1)
  - Element 0: 1 = (rank=0+1) * (index=0+1)
  - Element 1: 4 = (rank=1+1) * (index=1+1)
  - Element 1: 2 = (rank=0+1) * (index=1+1)
  - Element 2: 6 = (rank=1+1) * (index=2+1)
  - Element 2: 3 = (rank=0+1) * (index=2+1)
[Rank 1] Initial destination: [-1, -1, -1]
[Rank 0] Initial destination: [-1, -1, -1]
[Rank 0] Expected results after reduction: [3, 6, 9]
[Rank 1] Expected results after reduction: [3, 6, 9]
[Rank 0] Executing sum reduction...
[Rank 1] Executing sum reduction...
[Rank 1] Sum reduction completed
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[Rank 0] Sum reduction completed
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[Rank 0] Results after reduction:
[Rank 0] Destination buffer: [3, 6, 9]
[Rank 1] Results after reduction:
[Rank 0] Verification:
  - Reduction 0: PE0: 1 + PE1: 2 = 3
    Result: 3, Match: ✓
  - Reduction 1: PE0: 2 + PE1: 4 = 6
    Result: 6, Match: ✓
[Rank 1] Destination buffer: [3, 6, 9]
  - Reduction 2: PE0: 3 + PE1: 6 = 9
[Rank 1] Verification:
  - Reduction 0: PE0: 1 + PE1: 2 = 3
    Result: 9, Match: ✓
    Result: 3, Match: ✓
  - Reduction 1: PE0: 2 + PE1: 4 = 6
    Result: 6, Match: ✓
  - Reduction 2: PE0: 3 + PE1: 6 = 9
    Result: 9, Match: ✓
[Rank 0] ============================================================
[Rank 0] Sum reduction test PASSED ✓
[Rank 0] All 3 reductions computed correctly across 2 PEs
[Rank 0] ============================================================
[Rank 1] ============================================================
[Rank 1] Sum reduction test PASSED ✓
[Rank 1] All 3 reductions computed correctly across 2 PEs
[Rank 1] ============================================================
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158515
Approved by: https://github.com/mandroid6, https://github.com/ngimel
2025-08-08 05:19:55 +00:00
dd0adc9386 [SymmMem] Add NVSHMEM broadcast support into Triton (#158514)
Adds broadcast collective operation for distributing data from root PE to all other PEs in NVSHMEM Triton kernels.

Tests: `python test/distributed/test_nvshmem_triton.py -k test_triton_broadcast`
<details>
<summary> Quick debug print for sanity check </summary>

```markdown
============================================================
[Rank 0] Starting broadcast test with world_size=2
============================================================
[Rank 0] Configuration:
  - nelems: 4
  - dtype: torch.int64, element_size: 8 bytes
  - nelems_bytes: 32
============================================================
[Rank 1] Starting broadcast test with world_size=2
============================================================
[Rank 1] Configuration:
  - nelems: 4
  - dtype: torch.int64, element_size: 8 bytes
  - nelems_bytes: 32
[Rank 1] Non-root source data: [-1, -1, -1, -1]
[Rank 0] Root source data: [100, 101, 102, 103]
[Rank 1] Initial destination: [-999, -999, -999, -999]
[Rank 0] Initial destination: [-999, -999, -999, -999]
[Rank 0] Executing broadcast operation...
[Rank 1] Executing broadcast operation...
[Rank 0] Broadcast operation completed
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[Rank 1] Broadcast operation completed
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[Rank 1] Results after broadcast:
[Rank 0] Results after broadcast:
[Rank 1] Destination buffer: [100, 101, 102, 103]
[Rank 1] Expected: [100, 101, 102, 103]
[Rank 0] Destination buffer: [100, 101, 102, 103]
[Rank 0] Expected: [100, 101, 102, 103]
[Rank 1] Match: ✓
[Rank 0] Match: ✓
[Rank 1] ============================================================
[Rank 1] Broadcast test PASSED ✓
[Rank 1] Summary: Root PE 0 broadcasted [100, 101, 102, 103] to all PEs
[Rank 1] ============================================================
[Rank 0] ============================================================
[Rank 0] Broadcast test PASSED ✓
[Rank 0] Summary: Root PE 0 broadcasted [100, 101, 102, 103] to all PEs
[Rank 0] ============================================================
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158514
Approved by: https://github.com/fduwjj, https://github.com/mandroid6
ghstack dependencies: #158511, #158512, #158513
2025-07-21 22:23:26 +00:00
ad2dec1997 [SymmMem] Add NVSHMEM alltoall support into Triton (#158513)
Implements collective alltoall operation for NVSHMEM Triton kernels. Enables data exchange where each PE sends unique data to every other PE in the team.

Tests: `python test/distributed/test_nvshmem_triton.py -k test_triton_alltoall`

<details>
<summary>Quick debug print for sanity check</summary>

```markdown
============================================================
[Rank 0] Starting alltoall test with world_size=2
============================================================
[Rank 0] Configuration:
  - nelems_per_pe: 2
  - dtype: torch.int64, element_size: 8 bytes
  - nelems_bytes: 16
/dvs/p4/build/sw/rel/gpgpu/toolkit/r12.8/main_nvshmem/src/modules/transport/ibrc/ibrc.cpp:1653: NULL value get_device_list failed
/dvs/p4/build/sw/rel/gpgpu/toolkit/r12.8/main_nvshmem/src/modules/transport/ibrc/ibrc.cpp:1653: NULL value get_device_list failed
[Rank 0] Preparing source data:
[Rank 1] Preparing source data:
  - Data for PE 0: [0, 0] (indices 0-1)
  - Data for PE 1: [1, 1] (indices 2-3)
[Rank 0] Complete source buffer: [0, 0, 1, 1]
  - Data for PE 0: [100, 100] (indices 0-1)
  - Data for PE 1: [101, 101] (indices 2-3)
[Rank 1] Complete source buffer: [100, 100, 101, 101]
[Rank 1] Initial destination buffer: [-1, -1, -1, -1]
[Rank 0] Initial destination buffer: [-1, -1, -1, -1]
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank0]:[W716 15:30:06.215666766 ProcessGroupNCCL.cpp:5064] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
[rank1]:[W716 15:30:06.215752786 ProcessGroupNCCL.cpp:5064] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
NCCL version 2.27.5+cuda12.4
[Rank 1] Executing alltoall operation...
[Rank 0] Executing alltoall operation...
[Rank 1] alltoall operation completed
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[Rank 0] alltoall operation completed
/data/users/suryasub/pytorch/torch/distributed/distributed_c10d.py:4809: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[Rank 0] Results after alltoall:
[Rank 1] Results after alltoall:[Rank 0] Destination buffer: [0, 0, 100, 100]
[Rank 0] Verifying results:
  - From PE 0 (indices 0-1):
    Expected: [0, 0]
    Actual:   [0, 0]
[Rank 1] Destination buffer: [1, 1, 101, 101]
[Rank 1] Verifying results:
  - From PE 0 (indices 0-1):
    Expected: [1, 1]
    Actual:   [1, 1]
    Match:    ✓
    Match:    ✓
  - From PE 1 (indices 2-3):
    Expected: [100, 100]
  - From PE 1 (indices 2-3):
    Expected: [101, 101]
    Actual:   [100, 100]
    Actual:   [101, 101]
    Match:    ✓
    Match:    ✓
[Rank 0] ============================================================
[Rank 0] Summary: ALL TESTS PASSED ✓
[Rank 0] Data flow explanation:
  - Each rank sends 2 elements to every other rank
[Rank 1] ============================================================
[Rank 1] Summary: ALL TESTS PASSED ✓
  - Rank 0 sent: [0, 0, 1, 1]
[Rank 1] Data flow explanation:
  - Each rank sends 2 elements to every other rank
  - Rank 0 received: [0, 0, 100, 100]
  - My data for PE 0 (0) went to PE 0's buffer
  - I received PE 0's data for me (0)
  - My data for PE 1 (1) went to PE 1's buffer
  - Rank 1 sent: [100, 100, 101, 101]
  - I received PE 1's data for me (100)
[Rank 0] ============================================================
  - Rank 1 received: [1, 1, 101, 101]
  - My data for PE 0 (100) went to PE 0's buffer
  - I received PE 0's data for me (1)
  - My data for PE 1 (101) went to PE 1's buffer
  - I received PE 1's data for me (101)
[Rank 1] ============================================================
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158513
Approved by: https://github.com/fduwjj, https://github.com/mandroid6
ghstack dependencies: #158511, #158512
2025-07-21 19:14:47 +00:00
bbc32d680f [SymmMem] Add NVSHMEM sync_all support into Triton (#158512)
Adds `sync_all()` function for local store visibility synchronization in NVSHMEM Triton kernels. Provides memory ordering for local operations without remote completion guarantees.

Tests: `python test/distributed/test_nvshmem_triton.py -k test_triton_sync`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158512
Approved by: https://github.com/fduwjj
ghstack dependencies: #158511
2025-07-21 10:27:59 +00:00
70b4a8880b [SymmMem] Add NVSHMEM barrier_all, my_pe, n_pes support into Triton (#158511)
Adds device-side barrier synchronization and PE identification functions for NVSHMEM Triton integration. Includes `barrier_all()` for collective synchronization and `my_pe()`/`n_pes()` for PE identification within kernels.

We are launching with cooperative grid launch (for all the PRs in this stack) because the `nvshmemx_collective_launch` function must be used to launch kernels on the GPU when the kernels use NVSHMEM synchronization or collective APIs, and `nvshmemx_collective_launch` essentially boils down to a CUDA cooperative group launch.

Tests: `python test/distributed/test_nvshmem_triton.py -k test_triton_barrier`

Also tested that if you remove the barrier, you get an assertion error/race conditions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158511
Approved by: https://github.com/fduwjj
2025-07-21 02:37:33 +00:00
3341c131b7 [SymmMem] Fix NCCL Hang in NVSHMEM Triton Wait Until Test (#158167)
The `test_triton_wait_until` test was hanging due to an NCCL synchronization issue stemming from mismatched NVSHMEM operations. Specifically, the flag variable was updated using `nvshmemx_signal_op` (a signaling operation), but waited on with `nvshmem_wait_until` (intended for put/get updates). Per NVSHMEM documentation (see documentation reference section below), signal-updated variables require `nvshmem_signal_wait_until` for proper completion guarantees, so the mismatch caused a deadlock and NCCL hang.

**Fix:**
- A simple fix was to replace the flag update with a regular `nvshmem_putmem_block` (via `put_kernel`) to match `nvshmem_wait_until`. I also added a fence (`nvshmem_fence`) between data and flag puts on the sender (Rank 1) for ordered delivery.

- In a follow-up PR I will add a kernel/test to demonstrate usage of `nvshmemx_signal_op`

**Testing:**
- I ran `python test/distributed/test_nvshmem_triton.py` and  `python test/distributed/test_nvshmem_triton.py  -k test_triton_wait_until`

- I also verified with debug prints (Sender completes puts/fence before receiver's wait returns, and assertions confirm correct state). Multiple runs show no hangs or failures.

**Documentation Referenced:**
- [NVSHMEM Point-To-Point Synchronization](https://docs.nvidia.com/nvshmem/api/gen/api/sync.html) explicitly states: *"the sig_addr object at the calling PE is expected only to be updated as a signal, through the signaling operations available in Section NVSHMEM_PUT_SIGNAL and Section NVSHMEM_PUT_SIGNAL_NBI"*
- [NVIDIA's Official Ring Broadcast Example](https://docs.nvidia.com/nvshmem/api/examples.html) demonstrates the correct pairing: `nvshmemx_signal_op` with `nvshmem_signal_wait_until` (not `nvshmem_wait_until`)
- [NVSHMEM Signaling Operations](https://docs.nvidia.com/nvshmem/api/gen/api/signal.html) documents that signal operations work on special "signal data objects" with specific atomicity guarantees distinct from regular RMA operations

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158167
Approved by: https://github.com/Skylion007, https://github.com/fduwjj
2025-07-15 05:57:27 +00:00
19a01382bc Revert "[SymmMem] find_path does not search /usr/local/lib (#157695)"
This reverts commit 3effe0c293219b00a0eae7e139fe2d9aed84bc03.

Reverted https://github.com/pytorch/pytorch/pull/157695 on behalf of https://github.com/kwen2501 due to Changing it to be landable on 2.8 branch ([comment](https://github.com/pytorch/pytorch/pull/157695#issuecomment-3047020152))
2025-07-08 01:12:01 +00:00
3effe0c293 [SymmMem] find_path does not search /usr/local/lib (#157695)
This PR uses `find_library` to replace `find_path`.
It also searches for NVSHMEM host lib and device lib separately.

Tested against system install location: /usr/local/lib and /usr/local/include.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157695
Approved by: https://github.com/Skylion007
ghstack dependencies: #157513
2025-07-07 23:16:45 +00:00
8f9a191db6 [SymmMem] Fix CI name mismatch; remove TORCH_SYMMMEM requirement (#157597)
Thanks @huydhn for spotting two name mismatches in the CI configs.
We were matching against "test_h100_symm_mem" instead of "h100-symm-mem".

Also, replaced `TORCH_SYMMMEM` env setting with programmatic method:
`symm_mem.set_backend(...)`

Further, skips a hanged test in `test_nvshmem_trion.py`. (#TODO @codingwithsurya )

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157597
Approved by: https://github.com/fduwjj, https://github.com/huydhn
2025-07-04 01:43:08 +00:00
195ef1bce8 [SymmMem] Refactor NVSHMEM tests: separate Triton tests into dedicated file (#156685)
## Summary

Moved the Triton-specific NVSHMEM tests in `test_nvshmem.py` into a dedicated `test_nvshmem_triton.py` file. Also put the shared Triton JIT kernels at the top-level of new file for reusability.

## Testing

```bash
TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem_triton.py
```

All 16 original tests pass with no functionality changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156685
Approved by: https://github.com/mandroid6, https://github.com/kwen2501
ghstack dependencies: #156684
2025-06-27 04:38:37 +00:00