[ROCm] Add support for SymmetricMemory (#150580)

This is an attempt to re-land the initial PR https://github.com/pytorch/pytorch/pull/134817 with recent design changes from upstream.

**NOTE:**
ROCm does NOT currently have multicast/multimem hardware support, so those features are disabled in symmetric memory for ROCm. This also means that we currently do not have a way of lowering add + all_reduce + wait_tensor into the one_shot_all_reduce op in inductor, as it depends on multicast buffer support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150580
Approved by: https://github.com/jeffdaily, https://github.com/kwen2501, https://github.com/yoyoyocmu

Co-authored-by: Xiaodong Wang <xdwang@fb.com>
This commit is contained in:
Prachi Gupta
2025-05-02 18:35:14 +00:00
committed by PyTorch MergeBot
parent 376529c78b
commit 1ea2731e26
8 changed files with 234 additions and 79 deletions

View File

@ -57,9 +57,11 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
MI300_ARCH,
parametrize,
retry_on_connect_failures,
run_tests,
runOnRocmArch,
skip_but_pass_in_sandcastle,
skip_but_pass_in_sandcastle_if,
TEST_CUDA,
@ -3322,7 +3324,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
@requires_nccl()
@skip_if_lt_x_gpu(2)
@skip_if_rocm_multiprocess
@runOnRocmArch(MI300_ARCH)
def test_intra_node_comm_all_reduce(self):
from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
from torch.testing._internal.common_cuda import SM80OrLater

View File

@ -21,6 +21,7 @@ from torch.distributed._symmetric_memory import (
restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM90OrLater
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_multicast_support,
@ -28,11 +29,14 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
MI300_ARCH,
parametrize,
requires_cuda,
run_tests,
runOnRocmArch,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
TEST_WITH_ROCM,
TestCase,
)
@ -102,7 +106,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
for row in connectivity.matrix:
self.assertEqual(len(row), torch.cuda.device_count())
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
def test_large_alloc(self) -> None:
t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)
@ -142,7 +146,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
symm_mem_hdl.barrier()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
@parametrize("set_device", [True, False])
def test_empty_strided_p2p(self, set_device: bool) -> None:
@ -161,7 +165,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
self._verify_symmetric_memory(symm_mem_hdl)
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
@parametrize("set_device", [True, False])
def test_empty_strided_p2p_persistent(self, set_device: bool) -> None:
@ -189,7 +193,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
self._verify_symmetric_memory(symm_mem_hdl)
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
def test_get_signal_pad(self) -> None:
self._init_process()
@ -232,6 +236,11 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
# These timeout tests are skipped on ROCm because the timeout calls trap(), which
# is handled differently inside the HIP runtime: it collects a GPU core dump and causes
# the Linux kernel to create a core dump of the host application. The functionality
# is there, meaning the timeout happens correctly. However, there isn't a nice way
# to test it, as the currently executing thread will core dump and exit.
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_barrier_timeout(self) -> None:
@ -253,6 +262,11 @@ class SymmetricMemoryTest(MultiProcessTestCase):
# impossible to terminate the process in this state.
os._exit(0)
# These timeout tests are skipped on ROCm because the timeout calls trap(), which
# is handled differently inside the HIP runtime: it collects a GPU core dump and causes
# the Linux kernel to create a core dump of the host application. The functionality
# is there, meaning the timeout happens correctly. However, there isn't a nice way
# to test it, as the currently executing thread will core dump and exit.
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_put_signal_timeout(self) -> None:
@ -277,6 +291,11 @@ class SymmetricMemoryTest(MultiProcessTestCase):
# impossible to terminate the process in this state.
os._exit(0)
# These timeout tests are skipped on ROCm because the timeout calls trap(), which
# is handled differently inside the HIP runtime: it collects a GPU core dump and causes
# the Linux kernel to create a core dump of the host application. The functionality
# is there, meaning the timeout happens correctly. However, there isn't a nice way
# to test it, as the currently executing thread will core dump and exit.
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_wait_signal_timeout(self) -> None:
@ -298,7 +317,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
# impossible to terminate the process in this state.
os._exit(0)
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@requires_cuda
def test_allow_overlapping_devices(self) -> None:
os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "1"
@ -324,7 +343,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
@parametrize("gather_dim", [0, 1])
def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
@ -356,7 +375,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@skipIfRocm # this requires async_input_mm support
@skipIf(
not SM90OrLater,
"_fused_all_gather_matmul_native currently only supports sm>=90",
@ -416,7 +435,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@skip_if_lt_x_gpu(2)
@requires_multicast_support()
def test_multimem_all_gather_matmul(self) -> None:
@ -457,7 +475,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
@parametrize("gather_dim", [0, 1])
@parametrize(
@ -483,10 +501,9 @@ class SymmetricMemoryTest(MultiProcessTestCase):
raise AssertionError("Invalid scale_mode: {scale_mode}")
torch.manual_seed(42 + rank)
A_shard = torch.rand(*leading_dims, K, device="cuda").to(torch.float8_e4m3fn)
Bs = [
torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3)
]
A_shard = torch.rand(*leading_dims, K, device="cuda").to(e4m3_type)
Bs = [torch.rand(N, K, device="cuda").to(e4m3_type).T for _ in range(3)]
if scale_mode == "tensor-wise":
A_scale = torch.tensor(0.1, device="cuda")
@ -546,7 +563,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
@parametrize("scatter_dim", [0, 1])
def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
@ -575,7 +592,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@skipIfRocm # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
@skip_if_lt_x_gpu(2)
@parametrize("scatter_dim", [0, 1])
@parametrize("rowwise", [True, False])
@ -592,8 +609,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
rank = self.rank
torch.manual_seed(42 + rank)
A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T
A = torch.rand(BATCH, M, K, device="cuda").to(e4m3_type)
B = torch.rand(N, K, device="cuda").to(e4m3_type).T
if rowwise:
A_scale = torch.full((BATCH, M, 1), 0.1, device="cuda")
@ -628,7 +645,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@parametrize("dim", [0, 1, 2])
def test_optimal_layout(self, dim: int) -> None:
t = torch.rand(8, 64, 32, 16)
@ -641,7 +658,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
self.assertTrue(x.movedim(dim, 0).is_contiguous())
self.assertTrue(torch.allclose(x, t))
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
@parametrize("symm_mem_input", [True, False])
def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
@ -668,7 +685,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(2)
@parametrize("reduce_op", ["sum", "avg"])
@parametrize("symm_mem_input", [True, False])
@ -733,7 +750,7 @@ class SubgroupTest(MultiProcessTestCase):
)
torch.manual_seed(42 + self.rank)
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(4)
def test_subgroup(self) -> None:
self._init_process()
@ -771,7 +788,6 @@ class SubgroupTest(MultiProcessTestCase):
self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())
# @skipIfRocm
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmMemCollectiveTest(MultiProcessTestCase):
@ -859,7 +875,7 @@ class SymmMemCollectiveTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(4)
def test_one_shot_all_reduce(self) -> None:
self._init_process()
@ -890,7 +906,7 @@ class SymmMemCollectiveTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(4)
def test_two_shot_all_reduce(self) -> None:
self._init_process()
@ -940,7 +956,7 @@ class SymmMemCollectiveTest(MultiProcessTestCase):
gathered_inps.sum(dim=0), res, rtol=1e-01, atol=1e-01
)
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(4)
def test_reduce_scatter(self) -> None:
self._init_process()
@ -977,7 +993,7 @@ class SymmMemCollectiveTest(MultiProcessTestCase):
dist.destroy_process_group()
@skipIfRocm
@runOnRocmArch(MI300_ARCH)
@skip_if_lt_x_gpu(4)
def test_reduce_scatter_corner_cases(self) -> None:
dtype = torch.bfloat16
@ -1125,12 +1141,12 @@ class LoweringTest(MultiProcessTestCase):
class SymmMemSingleProcTest(TestCase):
@skipIfRocm
@requires_cuda
@skipIf(
_get_torch_cuda_version() < (12, 0),
not TEST_WITH_ROCM and _get_torch_cuda_version() < (12, 0),
"stream_write_value32 currently only supports cuda version>=12.0",
)
@runOnRocmArch(MI300_ARCH)
def test_stream_write_value32(self):
tensor = torch.zeros(4, dtype=torch.uint32, device="cuda")
expect = torch.tril(torch.ones(4, 4, device="cuda")).to(torch.uint32)
@ -1145,8 +1161,8 @@ class SymmMemSingleProcTest(TestCase):
with self.assertRaises(RuntimeError):
_SymmetricMemory.stream_write_value32(tensor, offset=0, val=4294967296)
@skipIfRocm
@requires_cuda
@runOnRocmArch(MI300_ARCH)
def test_memset32(self):
t = _SymmetricMemory.empty_strided_p2p(
(64,),

View File

@ -7,6 +7,9 @@
#endif
#include <ATen/ATen.h>
#if defined(USE_ROCM)
#include <hip/hip_bf16.h>
#endif
#if !defined(USE_ROCM)
#include <cuda_bf16.h>
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
@ -33,6 +36,10 @@ cas(uint32_t* addr, uint32_t compare, uint32_t val) {
::cuda::atomic_ref<uint32_t, ::cuda::thread_scope_system> ref(*addr);
ref.compare_exchange_strong(compare, val, ::cuda::std::memory_order(Sem));
return compare;
#elif defined(USE_ROCM)
__atomic_compare_exchange_n(
addr, &compare, val, false, static_cast<int>(Sem), __ATOMIC_RELAXED);
return compare;
#else
CUDA_KERNEL_ASSERT(false);
return 0;
@ -41,7 +48,10 @@ cas(uint32_t* addr, uint32_t compare, uint32_t val) {
__device__ __forceinline__ void trap() {
#if defined(USE_ROCM)
assert(0);
// abort() calls trap() under the covers. However, on ROCm, the trap is
// handled differently inside the HIP runtime. It collects a GPU core dump and
// causes the Linux kernel to create a core dump of the host application.
abort();
#else
__trap();
#endif
@ -49,8 +59,8 @@ __device__ __forceinline__ void trap() {
__device__ __forceinline__ size_t global_timer_ns() {
#if defined(USE_ROCM)
CUDA_KERNEL_ASSERT(false);
return 0;
static constexpr double MI300_FREQ_GHZ = 2.1;
return __builtin_amdgcn_s_memtime() / MI300_FREQ_GHZ;
#else
size_t val;
asm volatile("mov.u64 %0, %globaltimer;" : "=l"(val) : : "memory");
@ -244,14 +254,10 @@ __device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
#endif
}
#if defined(USE_ROCM)
using __nv_bfloat162 = uint32_t;
#endif
template <typename T>
__device__ __inline__ T add_bf16x2(T a, T b) {
static_assert(sizeof(T) == 4);
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
CUDA_KERNEL_ASSERT(false);
return T{};
#else

View File

@ -14,6 +14,8 @@
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#elif defined(USE_ROCM)
#include <hip/hip_runtime_api.h>
#endif
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
@ -39,17 +41,20 @@ AllocationRef::AllocationRef(
device_idx(device_idx) {}
AllocationRef::~AllocationRef() {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
// Leak the cuda allocations during static deinitialization
if (is_finalizing()) {
return;
}
auto driver_api = c10::cuda::DriverAPI::get();
c10::cuda::CUDAGuard guard(device_idx);
C10_CUDA_CHECK(cudaDeviceSynchronize());
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
// Leak the cuda allocations during static deinitialization
auto driver_api = c10::cuda::DriverAPI::get();
C10_CUDA_DRIVER_CHECK(
driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size));
C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemUnmap(reinterpret_cast<hipDeviceptr_t>(ptr), block_size));
C10_HIP_CHECK(hipMemRelease(handle));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
@ -368,10 +373,12 @@ void* CUDASymmetricMemoryAllocator::alloc(
size_t size,
int device_idx,
const std::optional<std::string>& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
size_t signal_pad_offset = at::round_up(size, 16UL);
size_t block_size = signal_pad_offset + signal_pad_size;
c10::cuda::CUDAGuard guard(device_idx);
device_idx = static_cast<int>(guard.current_device().index());
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
@ -379,8 +386,6 @@ void* CUDASymmetricMemoryAllocator::alloc(
prop.location.id = device_idx;
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
size_t signal_pad_offset = at::round_up(size, 16UL);
size_t block_size = signal_pad_offset + signal_pad_size;
size_t granularity;
auto driver_api = c10::cuda::DriverAPI::get();
@ -392,6 +397,27 @@ void* CUDASymmetricMemoryAllocator::alloc(
C10_CUDA_DRIVER_CHECK(
driver_api->cuMemCreate_(&handle, block_size, &prop, 0));
#elif defined(USE_ROCM)
hipMemAllocationProp prop = {};
prop.type = hipMemAllocationTypePinned;
prop.location.type = hipMemLocationTypeDevice;
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
prop.location.id = device_idx;
prop.requestedHandleType = hipMemHandleTypePosixFileDescriptor;
size_t granularity;
C10_HIP_CHECK(hipMemGetAllocationGranularity(
&granularity, &prop, hipMemAllocationGranularityRecommended));
block_size = at::round_up(block_size, granularity);
HandleType handle;
C10_HIP_CHECK(hipMemCreate(reinterpret_cast<hipMemGenericAllocationHandle_t*>(&handle), block_size, &prop, 0));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
void* ptr = nullptr;
map_block(&ptr, handle, block_size, device_idx);
@ -411,10 +437,6 @@ void* CUDASymmetricMemoryAllocator::alloc(
ptr_to_block_.emplace(ptr, std::move(block));
}
return ptr;
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}
void CUDASymmetricMemoryAllocator::free(void* ptr) {
@ -559,7 +581,7 @@ static void init_multicast_for_block(
c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
void* ptr,
const std::optional<std::string>& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto block = find_block(ptr);
if (block == nullptr) {
return nullptr;
@ -595,9 +617,10 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
auto store = group_info.store;
int rank = group_info.rank;
int world_size = group_info.world_size;
auto driver_api = c10::cuda::DriverAPI::get();
int block_fd;
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto driver_api = c10::cuda::DriverAPI::get();
// using the CUDA Driver API to export a GPU memory block as a
// POSIX file descriptor (FD), so it can be shared across processes via IPC.
C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
@ -605,6 +628,13 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
block->alloc_ref->handle,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
0));
#elif defined (USE_ROCM)
C10_HIP_CHECK(hipMemExportToShareableHandle(
&block_fd, block->alloc_ref->handle, hipMemHandleTypePosixFileDescriptor, 0));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
auto local_req = RendezvousRequest{
.device_idx = block->device_idx,
@ -635,10 +665,20 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
}
// This api imports a GPU memory allocation that was previously exported as a file
// descriptor and it returns a memory handle.
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
&handles[r],
(void*)(uintptr_t)imported_fds[r],
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
#elif defined (USE_ROCM)
C10_HIP_CHECK(hipMemImportFromShareableHandle(
&handles[r],
(void*)(uintptr_t)&(imported_fds[r]),
hipMemHandleTypePosixFileDescriptor));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
map_block(&buffers[r], handles[r], block->block_size, block->device_idx);
signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset);
close(imported_fds[r]);
@ -676,10 +716,6 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
group_info.world_size);
block->symm_mems[group_name_] = symm_mem;
return symm_mem;
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}
bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {

View File

@ -20,7 +20,7 @@
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/cuda/AsyncMM.cuh>
#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#if defined(USE_ROCM) || (defined(CUDART_VERSION) && CUDART_VERSION >= 12030)
#define INT_SWITCH_CASE(name, val, ...) \
case val: { \
@ -124,6 +124,7 @@ void init_elementwise_launch_config(
}
}
#if !defined(USE_ROCM) //No multi-cast support on ROCm yet
template <typename T, int alignment>
static __global__ void multimem_all_reduce_kernel(
T* input_mc_ptr,
@ -395,6 +396,8 @@ at::Tensor multimem_all_gather_out(
return out;
}
#endif //no multi-cast support on ROCm
// One-shot all-reduce is register-intensive because it stages values loaded
// from peers in registers before performing reduction. Setting the thread
// count to 512 to prevent/alleviate register spill.
@ -1054,7 +1057,6 @@ at::Tensor memset32_(
int64_t offset,
int64_t val,
int64_t count) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
TORCH_CHECK(
input.dim() == 1 && input.is_contiguous() &&
input.scalar_type() == c10::ScalarType::UInt32,
@ -1078,7 +1080,6 @@ at::Tensor memset32_(
"symm_mem::memset32_: val must be in the range of "
"[0, 4294967295] (uint32_t).")
auto element_size = c10::elementSize(input.scalar_type());
TORCH_CHECK(
offset + count <= input.numel(),
"symm_mem::memset32_: offset + count (",
@ -1088,14 +1089,20 @@ at::Tensor memset32_(
")");
auto addr = reinterpret_cast<uint32_t*>(input.data_ptr()) + offset;
c10::cuda::CUDAGuard guard(input.device());
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto driver_api = c10::cuda::DriverAPI::get();
C10_CUDA_DRIVER_CHECK(driver_api->cuMemsetD32Async_(
reinterpret_cast<CUdeviceptr>(addr),
val,
count,
at::cuda::getCurrentCUDAStream()));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(addr),
val,
count,
at::cuda::getCurrentCUDAStream()));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
@ -1107,7 +1114,6 @@ at::Tensor stream_write_value32_(
at::Tensor& input,
int64_t offset,
int64_t val) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
TORCH_CHECK(
input.dim() == 1 && input.is_contiguous() &&
input.scalar_type() == c10::ScalarType::UInt32,
@ -1127,7 +1133,6 @@ at::Tensor stream_write_value32_(
"symm_mem::stream_write_value32_: "
"val must be in the range of [0, 4294967295] (uint32_t).")
auto element_size = c10::elementSize(input.scalar_type());
TORCH_CHECK(
offset < input.numel(),
"symm_mem::stream_write_value32_: offset (",
@ -1137,8 +1142,9 @@ at::Tensor stream_write_value32_(
")");
auto addr = reinterpret_cast<uint32_t*>(input.data_ptr()) + offset;
c10::cuda::CUDAGuard guard(input.device());
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
auto driver_api = c10::cuda::DriverAPI::get();
// According to the documentation of CUstreamWriteValue_flags,
// cuStreamWriteValue32 will provide a memory fence before the write, which
@ -1149,6 +1155,12 @@ at::Tensor stream_write_value32_(
reinterpret_cast<CUdeviceptr>(addr),
val,
0));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipStreamWriteValue32(
at::cuda::getCurrentCUDAStream(),
reinterpret_cast<void*>(addr),
val,
0));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
@ -1159,6 +1171,17 @@ at::Tensor stream_write_value32_(
} // namespace
TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
#if defined(USE_ROCM) || defined(CUDART_VERSION)
m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
m.impl("reduce_scatter_out", ::reduce_scatter_out);
m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
#endif
#if defined(CUDART_VERSION)
m.impl("multimem_all_reduce_", ::multimem_all_reduce_);
@ -1173,15 +1196,6 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
m.impl(
"multimem_one_shot_all_reduce_out", ::multimem_one_shot_all_reduce_out);
m.impl("multimem_all_gather_out", ::multimem_all_gather_out);
m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
m.impl("reduce_scatter_out", ::reduce_scatter_out);
m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
#endif
m.impl("stream_write_value32_", ::stream_write_value32_);
m.impl("memset32_", ::memset32_);

View File

@ -6,6 +6,8 @@ constexpr size_t signal_pad_size = 2048;
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#elif defined(USE_ROCM)
using HandleType = hipMemGenericAllocationHandle_t;
#else
using HandleType = void*;
#endif

View File

@ -7,6 +7,9 @@
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#elif defined(USE_ROCM)
#include <c10/hip/HIPException.h>
#include <hip/hip_runtime_api.h>
#endif
#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
@ -66,17 +69,26 @@ void IpcChannel::send_fd(int dst_pid, int fd) {
// File descriptors are process-local kernel objects, so we can't pass them
// via normal socket payloads (like write() or send()). Unix domain
// sockets provide a mechanism to pass actual FDs via sendmsg()/recvmsg().
// Define destination socket address
struct sockaddr_un addr = {.sun_family = AF_UNIX};
auto socket_name = get_socket_name(dst_pid);
std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);
// Prepare data to send
// Data being sent is "fd", the value of fd will be sent as auxiliary data
// (control message)
struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};
// Prepare control message data buffer and zero it out
// NOLINTNEXTLINE(*array*)
char cbuf[CMSG_SPACE(sizeof(int))];
memset(cbuf, 0, sizeof(cbuf));
// Create message header
struct msghdr msg {
// destination socket address and size of it
// message content in msg_iov and number of such structs (1 in our case)
// auxiliary data with the value of fd and size of it
.msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un),
.msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf,
.msg_controllen = sizeof(cbuf)
@ -87,7 +99,9 @@ void IpcChannel::send_fd(int dst_pid, int fd) {
// descriptors.
auto cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_len = CMSG_LEN(sizeof(int));
// Specify socket level message
cmsg->cmsg_level = SOL_SOCKET;
// SCM_RIGHTS is the type used to pass file descriptors
cmsg->cmsg_type = SCM_RIGHTS;
if (fd != -1) {
@ -99,6 +113,7 @@ void IpcChannel::send_fd(int dst_pid, int fd) {
msg.msg_controllen = 0;
}
// Finally, send the message
TORCH_CHECK(
sendmsg(socket_, &msg, 0) > 0,
"Failed to send fd: ",
@ -106,20 +121,32 @@ void IpcChannel::send_fd(int dst_pid, int fd) {
}
int IpcChannel::recv_fd() {
// Prepare buffer for regular message "fd"
// NOLINTNEXTLINE(*array*)
char buf[2];
memset(&buf, 0, sizeof(buf));
struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};
// Prepare buffer for control message and zero it out
// NOLINTNEXTLINE(*array*)
char cbuf[CMSG_SPACE(sizeof(int))];
memset(cbuf, 0, sizeof(cbuf));
// Define socket address to receive on: family AF_UNIX means unix domain
// socket
struct sockaddr_un addr = {.sun_family = AF_UNIX};
std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);
// Prepare message header
struct msghdr msg = {
.msg_name = (void*)&addr,
.msg_namelen = sizeof(struct sockaddr_un),
.msg_iov = &io,
.msg_iovlen = 1,
.msg_control = cbuf,
.msg_controllen = sizeof(cbuf)};
// Receive the message on socket_
TORCH_CHECK(
recvmsg(socket_, &msg, 0) > 0,
"Failed to receive fd: ",
@ -129,6 +156,7 @@ int IpcChannel::recv_fd() {
return -1;
}
// Extract control message and validate its content
auto cmsg = CMSG_FIRSTHDR(&msg);
TORCH_CHECK(cmsg != nullptr);
TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
@ -207,6 +235,27 @@ void map_block(
desc.location.id = static_cast<int>(device_idx);
desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#elif defined(USE_ROCM)
C10_HIP_CHECK(hipMemAddressReserve(ptr, size, 0ULL, 0, 0ULL));
C10_HIP_CHECK(hipMemMap(
*ptr,
size,
0,
reinterpret_cast<hipMemGenericAllocationHandle_t>(handle),
0ULL));
C10_HIP_CHECK(hipMemMap(
*ptr,
size,
0,
reinterpret_cast<hipMemGenericAllocationHandle_t>(handle),
0ULL));
hipMemAccessDesc desc;
desc.location.type = hipMemLocationTypeDevice;
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
desc.location.id = static_cast<int>(device_idx);
desc.flags = hipMemAccessFlagsProtReadWrite;
C10_HIP_CHECK(hipMemSetAccess(*ptr, size, &desc, 1));
#else
TORCH_CHECK(
false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");

View File

@ -3,7 +3,9 @@
#include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
// #include <cuda_runtime.h>
#if defined(USE_ROCM)
#include <rocm_smi/rocm_smi.h>
#endif
namespace c10d::intra_node_comm {
@ -13,15 +15,13 @@ static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
// IntraNodeComm can be used even without NVLink connection. This is only used
// for testing purposes.
static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
static int intraNodeCommIdx = 0;
#endif
/**
* Query the nvlink connection among devices.
*/
static NvlMesh getNvlMesh(const std::vector<int>& rankToDeviceIdx) {
#if !defined(USE_RCOM)
auto connectivity = detect_dma_connectivity(c10::DeviceType::CUDA, "nvlink");
NvlMesh nvlMesh = {};
for (size_t srcRank = 0; srcRank < kMaxDevices; ++srcRank) {
@ -35,6 +35,31 @@ static NvlMesh getNvlMesh(const std::vector<int>& rankToDeviceIdx) {
}
}
return nvlMesh;
#else
NvlMesh nvlMesh = {};
const auto worldSize = rankToDeviceIdx.size();
// For each device, loop over devices connected to it
for (size_t idx = 0; idx < worldSize; ++idx) {
for (size_t link = 0; link < kMaxDevices; ++link) {
if (idx == link)
continue;
bool conn = false;
auto ret = rsmi_is_P2P_accessible(idx, link, &conn);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(ERROR)
<< "IntraNodeComm: getNvlMesh: rsmi_is_P2P_accessible returned error ret="
<< ret;
return {};
}
if (conn) {
nvlMesh[idx][link] += 1;
}
}
}
return nvlMesh;
#endif
}
/**
@ -128,7 +153,6 @@ bool IntraNodeComm::rendezvous() {
if (isInitialized_) {
return true;
}
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
worldSize_ > kMaxDevices) {
return false;
@ -148,6 +172,14 @@ bool IntraNodeComm::rendezvous() {
gethostname(devInfo.hostname, sizeof(devInfo.hostname));
devInfo.deviceIdx = deviceIdx_;
#if defined(USE_ROCM)
auto ret = rsmi_init(0);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(ERROR) << "IntraNodeComm:: rendezvous failed in rsmi_init, ret=" << ret;
return false;
}
#endif
auto peerDevInfos =
storeAllGather(store_, "handshake-0", rank_, worldSize_, devInfo);
@ -191,8 +223,6 @@ bool IntraNodeComm::rendezvous() {
symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_, std::nullopt);
isInitialized_ = true;
return true;
#endif
return false;
}
} // namespace c10d::intra_node_comm