Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[ROCm] Add support for SymmetricMemory (#150580)
This is an attempt to re-land the initial PR https://github.com/pytorch/pytorch/pull/134817 with the recent design changes from upstream.

**NOTE:** ROCm currently has no multicast/multimem hardware support, so those features are disabled in symmetric memory for ROCm. This also means there is currently no way to lower add + all_reduce + wait_tensor into a one_shot_all_reduce op in Inductor, since that lowering depends on multicast buffer support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150580
Approved by: https://github.com/jeffdaily, https://github.com/kwen2501, https://github.com/yoyoyocmu

Co-authored-by: Xiaodong Wang <xdwang@fb.com>
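For context, this is roughly how the Python-side symmetric memory API that this change enables on ROCm (MI300-class GPUs) is driven. The entry points (`symm_mem.empty`, `symm_mem.rendezvous`, `torch.ops.symm_mem.one_shot_all_reduce`, the handle's `barrier`) follow the names used in the tests touched by this diff; exact signatures vary across versions, so treat this as an illustrative sketch rather than part of the change.

```python
# Illustrative sketch only; assumes a multi-process launcher such as torchrun
# and follows the API names used by the tests updated in this PR.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")  # the "nccl" backend maps to RCCL on ROCm builds
rank = dist.get_rank()
torch.cuda.set_device(rank)
group_name = dist.group.WORLD.group_name

# Allocate from the symmetric-memory allocator and rendezvous so every rank
# can see its peers' copies of the buffer.
t = symm_mem.empty(4096, dtype=torch.bfloat16, device="cuda")
hdl = symm_mem.rendezvous(t, group_name)

t.fill_(rank)
hdl.barrier()  # signal-pad based synchronization across ranks

# One-shot all-reduce over the symmetric buffer; this path does not need
# multicast/multimem support, so it is available on ROCm per this PR.
out = torch.ops.symm_mem.one_shot_all_reduce(t, "sum", group_name)

dist.destroy_process_group()
```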
This commit is contained in:
Committed by: PyTorch MergeBot
Parent: 376529c78b
Commit: 1ea2731e26
@@ -57,9 +57,11 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    MI300_ARCH,
    parametrize,
    retry_on_connect_failures,
    run_tests,
    runOnRocmArch,
    skip_but_pass_in_sandcastle,
    skip_but_pass_in_sandcastle_if,
    TEST_CUDA,
@@ -3322,7 +3324,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    @skip_if_rocm_multiprocess
    @runOnRocmArch(MI300_ARCH)
    def test_intra_node_comm_all_reduce(self):
        from torch._C._distributed_c10d import _get_intra_node_comm_usage_counter
        from torch.testing._internal.common_cuda import SM80OrLater
@@ -21,6 +21,7 @@ from torch.distributed._symmetric_memory import (
    restride_A_shard_for_fused_all_gather_matmul,
)
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM90OrLater
from torch.testing._internal.common_device_type import e4m3_type
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_multicast_support,
@@ -28,11 +29,14 @@ from torch.testing._internal.common_distributed import (
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    MI300_ARCH,
    parametrize,
    requires_cuda,
    run_tests,
    runOnRocmArch,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
    TEST_WITH_ROCM,
    TestCase,
)

@@ -102,7 +106,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
        for row in connectivity.matrix:
            self.assertEqual(len(row), torch.cuda.device_count())

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    def test_large_alloc(self) -> None:
        t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
        self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)
@@ -142,7 +146,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        symm_mem_hdl.barrier()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(2)
    @parametrize("set_device", [True, False])
    def test_empty_strided_p2p(self, set_device: bool) -> None:
@@ -161,7 +165,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
        self._verify_symmetric_memory(symm_mem_hdl)
        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(2)
    @parametrize("set_device", [True, False])
    def test_empty_strided_p2p_persistent(self, set_device: bool) -> None:
@@ -189,7 +193,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
        self._verify_symmetric_memory(symm_mem_hdl)
        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(2)
    def test_get_signal_pad(self) -> None:
        self._init_process()
@@ -232,6 +236,11 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    # These timeout tests are skipped on ROCm because timeout calls trap(), which
    # is handled differently inside the hip runtime. It collects a gpu coredump and
    # causes the linux kernel to create a core dump of the host application. The
    # functionality is there, meaning timeout is happening correctly. However, there
    # isn't a nice way to test it as the current executing thread will coredump and exit.
    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_barrier_timeout(self) -> None:
@@ -253,6 +262,11 @@ class SymmetricMemoryTest(MultiProcessTestCase):
        # impossible to terminate the process in this state.
        os._exit(0)

    # These timeout tests are skipped on ROCm because timeout calls trap(), which
    # is handled differently inside the hip runtime. It collects a gpu coredump and
    # causes the linux kernel to create a core dump of the host application. The
    # functionality is there, meaning timeout is happening correctly. However, there
    # isn't a nice way to test it as the current executing thread will coredump and exit.
    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_put_signal_timeout(self) -> None:
@@ -277,6 +291,11 @@ class SymmetricMemoryTest(MultiProcessTestCase):
        # impossible to terminate the process in this state.
        os._exit(0)

    # These timeout tests are skipped on ROCm because timeout calls trap(), which
    # is handled differently inside the hip runtime. It collects a gpu coredump and
    # causes the linux kernel to create a core dump of the host application. The
    # functionality is there, meaning timeout is happening correctly. However, there
    # isn't a nice way to test it as the current executing thread will coredump and exit.
    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_wait_signal_timeout(self) -> None:
@@ -298,7 +317,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
        # impossible to terminate the process in this state.
        os._exit(0)

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @requires_cuda
    def test_allow_overlapping_devices(self) -> None:
        os.environ["TORCH_SYMM_MEM_ALLOW_OVERLAPPING_DEVICES"] = "1"
@@ -324,7 +343,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
@@ -356,7 +375,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @skipIfRocm  # this requires async_input_mm support
    @skipIf(
        not SM90OrLater,
        "_fused_all_gather_matmul_native currently only supports sm>=90",
@@ -416,7 +435,6 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    @requires_multicast_support()
    def test_multimem_all_gather_matmul(self) -> None:
@@ -457,7 +475,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(2)
    @parametrize("gather_dim", [0, 1])
    @parametrize(
@@ -483,10 +501,9 @@ class SymmetricMemoryTest(MultiProcessTestCase):
            raise AssertionError("Invalid scale_mode: {scale_mode}")

        torch.manual_seed(42 + rank)
        A_shard = torch.rand(*leading_dims, K, device="cuda").to(torch.float8_e4m3fn)
        Bs = [
            torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3)
        ]

        A_shard = torch.rand(*leading_dims, K, device="cuda").to(e4m3_type)
        Bs = [torch.rand(N, K, device="cuda").to(e4m3_type).T for _ in range(3)]

        if scale_mode == "tensor-wise":
            A_scale = torch.tensor(0.1, device="cuda")
@@ -546,7 +563,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
@@ -575,7 +592,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @skipIfRocm  # AsyncTP support changed _fused_scaled_matmul_reduce_scatter_fallback API, need more changes
    @skip_if_lt_x_gpu(2)
    @parametrize("scatter_dim", [0, 1])
    @parametrize("rowwise", [True, False])
@@ -592,8 +609,8 @@ class SymmetricMemoryTest(MultiProcessTestCase):
        rank = self.rank

        torch.manual_seed(42 + rank)
        A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn)
        B = torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T
        A = torch.rand(BATCH, M, K, device="cuda").to(e4m3_type)
        B = torch.rand(N, K, device="cuda").to(e4m3_type).T

        if rowwise:
            A_scale = torch.full((BATCH, M, 1), 0.1, device="cuda")
@@ -628,7 +645,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @parametrize("dim", [0, 1, 2])
    def test_optimal_layout(self, dim: int) -> None:
        t = torch.rand(8, 64, 32, 16)
@@ -641,7 +658,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):
        self.assertTrue(x.movedim(dim, 0).is_contiguous())
        self.assertTrue(torch.allclose(x, t))

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(2)
    @parametrize("symm_mem_input", [True, False])
    def test_low_contention_all_gather(self, symm_mem_input: bool) -> None:
@@ -668,7 +685,7 @@ class SymmetricMemoryTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(2)
    @parametrize("reduce_op", ["sum", "avg"])
    @parametrize("symm_mem_input", [True, False])
@@ -733,7 +750,7 @@ class SubgroupTest(MultiProcessTestCase):
        )
        torch.manual_seed(42 + self.rank)

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(4)
    def test_subgroup(self) -> None:
        self._init_process()
@@ -771,7 +788,6 @@ class SubgroupTest(MultiProcessTestCase):
        self.assertTrue(buf.eq(peer_rank + world.size() // 2).all())


# @skipIfRocm
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class SymmMemCollectiveTest(MultiProcessTestCase):
@@ -859,7 +875,7 @@ class SymmMemCollectiveTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(4)
    def test_one_shot_all_reduce(self) -> None:
        self._init_process()
@@ -890,7 +906,7 @@ class SymmMemCollectiveTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(4)
    def test_two_shot_all_reduce(self) -> None:
        self._init_process()
@@ -940,7 +956,7 @@ class SymmMemCollectiveTest(MultiProcessTestCase):
            gathered_inps.sum(dim=0), res, rtol=1e-01, atol=1e-01
        )

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(4)
    def test_reduce_scatter(self) -> None:
        self._init_process()
@@ -977,7 +993,7 @@ class SymmMemCollectiveTest(MultiProcessTestCase):

        dist.destroy_process_group()

    @skipIfRocm
    @runOnRocmArch(MI300_ARCH)
    @skip_if_lt_x_gpu(4)
    def test_reduce_scatter_corner_cases(self) -> None:
        dtype = torch.bfloat16
@@ -1125,12 +1141,12 @@ class LoweringTest(MultiProcessTestCase):


class SymmMemSingleProcTest(TestCase):
    @skipIfRocm
    @requires_cuda
    @skipIf(
        _get_torch_cuda_version() < (12, 0),
        not TEST_WITH_ROCM and _get_torch_cuda_version() < (12, 0),
        "stream_write_value32 currently only supports cuda version>=12.0",
    )
    @runOnRocmArch(MI300_ARCH)
    def test_stream_write_value32(self):
        tensor = torch.zeros(4, dtype=torch.uint32, device="cuda")
        expect = torch.tril(torch.ones(4, 4, device="cuda")).to(torch.uint32)
@@ -1145,8 +1161,8 @@ class SymmMemSingleProcTest(TestCase):
        with self.assertRaises(RuntimeError):
            _SymmetricMemory.stream_write_value32(tensor, offset=0, val=4294967296)

    @skipIfRocm
    @requires_cuda
    @runOnRocmArch(MI300_ARCH)
    def test_memset32(self):
        t = _SymmetricMemory.empty_strided_p2p(
            (64,),
@@ -7,6 +7,9 @@
#endif

#include <ATen/ATen.h>
#if defined(USE_ROCM)
#include <hip/hip_bf16.h>
#endif
#if !defined(USE_ROCM)
#include <cuda_bf16.h>
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
@@ -33,6 +36,10 @@ cas(uint32_t* addr, uint32_t compare, uint32_t val) {
  ::cuda::atomic_ref<uint32_t, ::cuda::thread_scope_system> ref(*addr);
  ref.compare_exchange_strong(compare, val, ::cuda::std::memory_order(Sem));
  return compare;
#elif defined(USE_ROCM)
  __atomic_compare_exchange_n(
      addr, &compare, val, false, static_cast<int>(Sem), __ATOMIC_RELAXED);
  return compare;
#else
  CUDA_KERNEL_ASSERT(false);
  return 0;
@@ -41,7 +48,10 @@ cas(uint32_t* addr, uint32_t compare, uint32_t val) {

__device__ __forceinline__ void trap() {
#if defined(USE_ROCM)
  assert(0);
  // abort() calls trap() under the covers. However, on ROCm, the trap is
  // handled differently inside the hip runtime. It collects a gpu core dump and
  // causes the linux kernel to create a core dump of the host application.
  abort();
#else
  __trap();
#endif
@@ -49,8 +59,8 @@ __device__ __forceinline__ void trap() {

__device__ __forceinline__ size_t global_timer_ns() {
#if defined(USE_ROCM)
  CUDA_KERNEL_ASSERT(false);
  return 0;
  static constexpr double MI300_FREQ_GHZ = 2.1;
  return __builtin_amdgcn_s_memtime() / MI300_FREQ_GHZ;
#else
  size_t val;
  asm volatile("mov.u64 %0, %globaltimer;" : "=l"(val) : : "memory");
@@ -244,14 +254,10 @@ __device__ __inline__ void multimem_st(T* mc_ptr, Vec<Alignment>& vec) {
#endif
}

#if defined(USE_ROCM)
using __nv_bfloat162 = uint32_t;
#endif

template <typename T>
__device__ __inline__ T add_bf16x2(T a, T b) {
  static_assert(sizeof(T) == 4);
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  CUDA_KERNEL_ASSERT(false);
  return T{};
#else
@@ -14,6 +14,8 @@

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#elif defined(USE_ROCM)
#include <hip/hip_runtime_api.h>
#endif

#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
@@ -39,17 +41,20 @@ AllocationRef::AllocationRef(
      device_idx(device_idx) {}

AllocationRef::~AllocationRef() {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  // Leak the cuda allocations during static deinitialization
  if (is_finalizing()) {
    return;
  }
  auto driver_api = c10::cuda::DriverAPI::get();
  c10::cuda::CUDAGuard guard(device_idx);
  C10_CUDA_CHECK(cudaDeviceSynchronize());
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  // Leak the cuda allocations during static deinitialization
  auto driver_api = c10::cuda::DriverAPI::get();
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemUnmap_(reinterpret_cast<CUdeviceptr>(ptr), block_size));
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemRelease_(handle));
#elif defined(USE_ROCM)
  C10_HIP_CHECK(hipMemUnmap(reinterpret_cast<hipDeviceptr_t>(ptr), block_size));
  C10_HIP_CHECK(hipMemRelease(handle));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
@@ -368,10 +373,12 @@ void* CUDASymmetricMemoryAllocator::alloc(
    size_t size,
    int device_idx,
    const std::optional<std::string>& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)

  size_t signal_pad_offset = at::round_up(size, 16UL);
  size_t block_size = signal_pad_offset + signal_pad_size;
  c10::cuda::CUDAGuard guard(device_idx);
  device_idx = static_cast<int>(guard.current_device().index());

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
@@ -379,8 +386,6 @@ void* CUDASymmetricMemoryAllocator::alloc(
  prop.location.id = device_idx;
  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t signal_pad_offset = at::round_up(size, 16UL);
  size_t block_size = signal_pad_offset + signal_pad_size;

  size_t granularity;
  auto driver_api = c10::cuda::DriverAPI::get();
@@ -392,6 +397,27 @@ void* CUDASymmetricMemoryAllocator::alloc(
  C10_CUDA_DRIVER_CHECK(
      driver_api->cuMemCreate_(&handle, block_size, &prop, 0));

#elif defined(USE_ROCM)
  hipMemAllocationProp prop = {};
  prop.type = hipMemAllocationTypePinned;
  prop.location.type = hipMemLocationTypeDevice;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  prop.location.id = device_idx;
  prop.requestedHandleType = hipMemHandleTypePosixFileDescriptor;


  size_t granularity;
  C10_HIP_CHECK(hipMemGetAllocationGranularity(
      &granularity, &prop, hipMemAllocationGranularityRecommended));
  block_size = at::round_up(block_size, granularity);

  HandleType handle;
  C10_HIP_CHECK(hipMemCreate(reinterpret_cast<hipMemGenericAllocationHandle_t*>(&handle), block_size, &prop, 0));

#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
  void* ptr = nullptr;
  map_block(&ptr, handle, block_size, device_idx);

@@ -411,10 +437,6 @@ void* CUDASymmetricMemoryAllocator::alloc(
    ptr_to_block_.emplace(ptr, std::move(block));
  }
  return ptr;
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

void CUDASymmetricMemoryAllocator::free(void* ptr) {
@@ -559,7 +581,7 @@ static void init_multicast_for_block(
c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
    void* ptr,
    const std::optional<std::string>& group_name) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)

  auto block = find_block(ptr);
  if (block == nullptr) {
    return nullptr;
@@ -595,9 +617,10 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
  auto store = group_info.store;
  int rank = group_info.rank;
  int world_size = group_info.world_size;

  auto driver_api = c10::cuda::DriverAPI::get();
  int block_fd;

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();
  // using the CUDA Driver API to export a GPU memory block as a
  // POSIX file descriptor (FD), so it can be shared across processes via IPC.
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
@@ -605,6 +628,13 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
      block->alloc_ref->handle,
      CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR,
      0));
#elif defined(USE_ROCM)
  C10_HIP_CHECK(hipMemExportToShareableHandle(
      &block_fd, block->alloc_ref->handle, hipMemHandleTypePosixFileDescriptor, 0));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif

  auto local_req = RendezvousRequest{
      .device_idx = block->device_idx,
@@ -635,10 +665,20 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
    }
    // This API imports a GPU memory allocation that was previously exported as
    // a file descriptor and returns a memory handle.
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
    C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
        &handles[r],
        (void*)(uintptr_t)imported_fds[r],
        CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
#elif defined(USE_ROCM)
    C10_HIP_CHECK(hipMemImportFromShareableHandle(
        &handles[r],
        (void*)(uintptr_t)&(imported_fds[r]),
        hipMemHandleTypePosixFileDescriptor));
#else
    TORCH_CHECK(
        false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
    map_block(&buffers[r], handles[r], block->block_size, block->device_idx);
    signal_pads[r] = (void*)((uintptr_t)buffers[r] + block->signal_pad_offset);
    close(imported_fds[r]);
@@ -676,10 +716,6 @@ c10::intrusive_ptr<SymmetricMemory> CUDASymmetricMemoryAllocator::rendezvous(
      group_info.world_size);
  block->symm_mems[group_name_] = symm_mem;
  return symm_mem;
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
#endif
}

bool CUDASymmetricMemoryAllocator::has_multicast_support(int device_idx) {
@@ -20,7 +20,7 @@
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>
#include <torch/csrc/distributed/c10d/cuda/AsyncMM.cuh>

#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
#if defined(USE_ROCM) || (defined(CUDART_VERSION) && CUDART_VERSION >= 12030)

#define INT_SWITCH_CASE(name, val, ...) \
  case val: { \
@@ -124,6 +124,7 @@ void init_elementwise_launch_config(
  }
}

#if !defined(USE_ROCM) // No multi-cast support on ROCm yet
template <typename T, int alignment>
static __global__ void multimem_all_reduce_kernel(
    T* input_mc_ptr,
@@ -395,6 +396,8 @@ at::Tensor multimem_all_gather_out(
  return out;
}

#endif // no multi-cast support on ROCm

// One-shot all-reduce is register-intensive because it stages values loaded
// from peers in registers before performing reduction. Setting the thread
// count to 512 to prevent/alleviate register spill.
@@ -1054,7 +1057,6 @@ at::Tensor memset32_(
    int64_t offset,
    int64_t val,
    int64_t count) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  TORCH_CHECK(
      input.dim() == 1 && input.is_contiguous() &&
          input.scalar_type() == c10::ScalarType::UInt32,
@@ -1078,7 +1080,6 @@ at::Tensor memset32_(
      "symm_mem::memset32_: val must be in the range of "
      "[0, 4294967295] (uint32_t).")

  auto element_size = c10::elementSize(input.scalar_type());
  TORCH_CHECK(
      offset + count <= input.numel(),
      "symm_mem::memset32_: offset + count (",
@@ -1088,14 +1089,20 @@ at::Tensor memset32_(
      ")");

  auto addr = reinterpret_cast<uint32_t*>(input.data_ptr()) + offset;

  c10::cuda::CUDAGuard guard(input.device());

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemsetD32Async_(
      reinterpret_cast<CUdeviceptr>(addr),
      val,
      count,
      at::cuda::getCurrentCUDAStream()));
#elif defined(USE_ROCM)
  C10_HIP_CHECK(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(addr),
      val,
      count,
      at::cuda::getCurrentCUDAStream()));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
@@ -1107,7 +1114,6 @@ at::Tensor stream_write_value32_(
    at::Tensor& input,
    int64_t offset,
    int64_t val) {
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  TORCH_CHECK(
      input.dim() == 1 && input.is_contiguous() &&
          input.scalar_type() == c10::ScalarType::UInt32,
@@ -1127,7 +1133,6 @@ at::Tensor stream_write_value32_(
      "symm_mem::stream_write_value32_: "
      "val must be in the range of [0, 4294967295] (uint32_t).")

  auto element_size = c10::elementSize(input.scalar_type());
  TORCH_CHECK(
      offset < input.numel(),
      "symm_mem::stream_write_value32_: offset (",
@@ -1137,8 +1142,9 @@ at::Tensor stream_write_value32_(
      ")");

  auto addr = reinterpret_cast<uint32_t*>(input.data_ptr()) + offset;

  c10::cuda::CUDAGuard guard(input.device());

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  auto driver_api = c10::cuda::DriverAPI::get();
  // According to the documentation of CUstreamWriteValue_flags,
  // cuStreamWriteValue32 will provide a memory fence before the write, which
@@ -1149,6 +1155,12 @@ at::Tensor stream_write_value32_(
      reinterpret_cast<CUdeviceptr>(addr),
      val,
      0));
#elif defined(USE_ROCM)
  C10_HIP_CHECK(hipStreamWriteValue32(
      at::cuda::getCurrentCUDAStream(),
      reinterpret_cast<void*>(addr),
      val,
      0));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
@@ -1159,6 +1171,17 @@ at::Tensor stream_write_value32_(
} // namespace

TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
#if defined(USE_ROCM) || defined(CUDART_VERSION)
  m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
  m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
  m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
  m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
  m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
  m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
  m.impl("reduce_scatter_out", ::reduce_scatter_out);

  m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
#endif
#if defined(CUDART_VERSION)
  m.impl("multimem_all_reduce_", ::multimem_all_reduce_);

@@ -1173,15 +1196,6 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
  m.impl(
      "multimem_one_shot_all_reduce_out", ::multimem_one_shot_all_reduce_out);
  m.impl("multimem_all_gather_out", ::multimem_all_gather_out);
  m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
  m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
  m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
  m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
  m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
  m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out);
  m.impl("reduce_scatter_out", ::reduce_scatter_out);

  m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
#endif
  m.impl("stream_write_value32_", ::stream_write_value32_);
  m.impl("memset32_", ::memset32_);
@@ -6,6 +6,8 @@ constexpr size_t signal_pad_size = 2048;

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#elif defined(USE_ROCM)
using HandleType = hipMemGenericAllocationHandle_t;
#else
using HandleType = void*;
#endif
@@ -7,6 +7,9 @@

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#elif defined(USE_ROCM)
#include <c10/hip/HIPException.h>
#include <hip/hip_runtime_api.h>
#endif

#include <torch/csrc/distributed/c10d/CUDASymmetricMemoryUtils.hpp>
@@ -66,17 +69,26 @@ void IpcChannel::send_fd(int dst_pid, int fd) {
  // File descriptors are process-local kernel objects, so we can't pass them
  // via normal socket payloads (like write() or send()). Unix domain sockets
  // provide a mechanism to pass actual FDs via sendmsg()/recvmsg().
  // Define destination socket address
  struct sockaddr_un addr = {.sun_family = AF_UNIX};
  auto socket_name = get_socket_name(dst_pid);
  std::copy(socket_name.begin(), socket_name.end(), addr.sun_path);

  // Prepare data to send
  // Data being sent is "fd", the value of fd will be sent as auxiliary data
  // (control message)
  struct iovec io = {.iov_base = (void*)("fd"), .iov_len = 2};

  // Prepare control message data buffer and zero it out
  // NOLINTNEXTLINE(*array*)
  char cbuf[CMSG_SPACE(sizeof(int))];
  memset(cbuf, 0, sizeof(cbuf));

  // Create message header
  struct msghdr msg {
    // destination socket address and size of it
    // message content in msg_iov and number of such structs (1 in our case)
    // auxiliary data with the value of fd and size of it
    .msg_name = (void*)&addr, .msg_namelen = sizeof(struct sockaddr_un),
    .msg_iov = &io, .msg_iovlen = 1, .msg_control = cbuf,
    .msg_controllen = sizeof(cbuf)
@@ -87,7 +99,9 @@ void IpcChannel::send_fd(int dst_pid, int fd) {
  // descriptors.
  auto cmsg = CMSG_FIRSTHDR(&msg);
  cmsg->cmsg_len = CMSG_LEN(sizeof(int));
  // Specify socket level message
  cmsg->cmsg_level = SOL_SOCKET;
  // SCM_RIGHTS is the type used to pass file descriptors
  cmsg->cmsg_type = SCM_RIGHTS;

  if (fd != -1) {
@@ -99,6 +113,7 @@ void IpcChannel::send_fd(int dst_pid, int fd) {
    msg.msg_controllen = 0;
  }

  // Finally send the message
  TORCH_CHECK(
      sendmsg(socket_, &msg, 0) > 0,
      "Failed to send fd: ",
@@ -106,20 +121,32 @@ void IpcChannel::send_fd(int dst_pid, int fd) {
}

int IpcChannel::recv_fd() {
  // Prepare buffer for regular message "fd"
  // NOLINTNEXTLINE(*array*)
  char buf[2];
  memset(&buf, 0, sizeof(buf));
  struct iovec io = {.iov_base = (void*)buf, .iov_len = sizeof(buf)};

  // Prepare buffer for control message and zero it out
  // NOLINTNEXTLINE(*array*)
  char cbuf[CMSG_SPACE(sizeof(int))];
  memset(cbuf, 0, sizeof(cbuf));

  // Define socket address to receive on: family AF_UNIX means unix domain
  // socket
  struct sockaddr_un addr = {.sun_family = AF_UNIX};
  std::copy(socket_name_.begin(), socket_name_.end(), addr.sun_path);

  // Prepare message header
  struct msghdr msg = {
      .msg_name = (void*)&addr,
      .msg_namelen = sizeof(struct sockaddr_un),
      .msg_iov = &io,
      .msg_iovlen = 1,
      .msg_control = cbuf,
      .msg_controllen = sizeof(cbuf)};

  // Receive message on socket_
  TORCH_CHECK(
      recvmsg(socket_, &msg, 0) > 0,
      "Failed to receive fd: ",
@@ -129,6 +156,7 @@ int IpcChannel::recv_fd() {
    return -1;
  }

  // Extract control message and validate its content
  auto cmsg = CMSG_FIRSTHDR(&msg);
  TORCH_CHECK(cmsg != nullptr);
  TORCH_CHECK(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
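The comments above describe the standard Unix domain socket mechanism for moving a file descriptor between processes. As an illustration of the same mechanism outside the C++ `IpcChannel` class (not part of this change), Python's `socket` module exposes it directly through `SCM_RIGHTS` ancillary data:

```python
# Minimal sketch of FD passing over a Unix domain socket pair using
# SCM_RIGHTS ancillary data; names and paths here are illustrative only.
import array
import os
import socket

parent, child = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)

# Sender: ship a regular 2-byte payload ("fd") plus the descriptor itself as
# a control message, mirroring sendmsg() in the C++ code above.
fd_to_send = os.open("/tmp", os.O_RDONLY)  # any valid descriptor
parent.sendmsg(
    [b"fd"],
    [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd_to_send]))],
)

# Receiver: pull the payload and extract the duplicated descriptor from the
# ancillary data, mirroring recvmsg() / CMSG_FIRSTHDR in the C++ code above.
msg, ancdata, flags, addr = child.recvmsg(
    2, socket.CMSG_SPACE(array.array("i").itemsize)
)
level, ctype, data = ancdata[0]
assert level == socket.SOL_SOCKET and ctype == socket.SCM_RIGHTS
received_fd = array.array("i", data)[0]

os.close(fd_to_send)
os.close(received_fd)
```

The receiving process gets a fresh descriptor number that refers to the same open file description, which is the property the rendezvous path relies on when exchanging exported GPU memory handles across ranks.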
@@ -207,6 +235,27 @@ void map_block(
  desc.location.id = static_cast<int>(device_idx);
  desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  C10_CUDA_DRIVER_CHECK(driver_api->cuMemSetAccess_(*dev_ptr, size, &desc, 1));
#elif defined(USE_ROCM)
  C10_HIP_CHECK(hipMemAddressReserve(ptr, size, 0ULL, 0, 0ULL));
  C10_HIP_CHECK(hipMemMap(
      *ptr,
      size,
      0,
      reinterpret_cast<hipMemGenericAllocationHandle_t>(handle),
      0ULL));
  C10_HIP_CHECK(hipMemMap(
      *ptr,
      size,
      0,
      reinterpret_cast<hipMemGenericAllocationHandle_t>(handle),
      0ULL));

  hipMemAccessDesc desc;
  desc.location.type = hipMemLocationTypeDevice;
  // NOLINTNEXTLINE(bugprone-signed-char-misuse)
  desc.location.id = static_cast<int>(device_idx);
  desc.flags = hipMemAccessFlagsProtReadWrite;
  C10_HIP_CHECK(hipMemSetAccess(*ptr, size, &desc, 1));
#else
  TORCH_CHECK(
      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
@@ -3,7 +3,9 @@
#include <torch/csrc/distributed/c10d/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

// #include <cuda_runtime.h>
#if defined(USE_ROCM)
#include <rocm_smi/rocm_smi.h>
#endif

namespace c10d::intra_node_comm {

@@ -13,15 +15,13 @@ static std::vector<std::string> ENABLE_INTRA_NODE_COMM = {
// IntraNodeComm can be used even without NVLink connection. This is only used
// for testing purposes.
static std::vector<std::string> TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"};

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
static int intraNodeCommIdx = 0;
#endif

/**
 * Query the nvlink connection among devices.
 */
static NvlMesh getNvlMesh(const std::vector<int>& rankToDeviceIdx) {
#if !defined(USE_ROCM)
  auto connectivity = detect_dma_connectivity(c10::DeviceType::CUDA, "nvlink");
  NvlMesh nvlMesh = {};
  for (size_t srcRank = 0; srcRank < kMaxDevices; ++srcRank) {
@@ -35,6 +35,31 @@ static NvlMesh getNvlMesh(const std::vector<int>& rankToDeviceIdx) {
    }
  }
  return nvlMesh;
#else
  NvlMesh nvlMesh = {};
  const auto worldSize = rankToDeviceIdx.size();
  // For each device, loop over devices connected to it
  for (size_t idx = 0; idx < worldSize; ++idx) {
    for (size_t link = 0; link < kMaxDevices; ++link) {
      if (idx == link)
        continue;

      bool conn = false;
      auto ret = rsmi_is_P2P_accessible(idx, link, &conn);
      if (ret != RSMI_STATUS_SUCCESS) {
        LOG(ERROR)
            << "IntraNodeComm: getNvlMesh: rsmi_is_P2P_accessible returned error ret="
            << ret;
        return {};
      }

      if (conn) {
        nvlMesh[idx][link] += 1;
      }
    }
  }
  return nvlMesh;
#endif
}

/**
@@ -128,7 +153,6 @@ bool IntraNodeComm::rendezvous() {
  if (isInitialized_) {
    return true;
  }
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
  if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
      worldSize_ > kMaxDevices) {
    return false;
@@ -148,6 +172,14 @@ bool IntraNodeComm::rendezvous() {
  gethostname(devInfo.hostname, sizeof(devInfo.hostname));
  devInfo.deviceIdx = deviceIdx_;

#if defined(USE_ROCM)
  auto ret = rsmi_init(0);
  if (ret != RSMI_STATUS_SUCCESS) {
    LOG(ERROR) << "IntraNodeComm:: rendezvous failed in rsmi_init, ret=" << ret;
    return false;
  }
#endif

  auto peerDevInfos =
      storeAllGather(store_, "handshake-0", rank_, worldSize_, devInfo);

@@ -191,8 +223,6 @@ bool IntraNodeComm::rendezvous() {
  symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_, std::nullopt);
  isInitialized_ = true;
  return true;
#endif
  return false;
}

} // namespace c10d::intra_node_comm
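The ROCm branch above derives the device mesh from rocm-smi P2P queries rather than NVLink topology detection. As a rough, hypothetical user-level analogue of such a reachability matrix (an illustration of the idea only, not the rocm-smi code path used here), a similar matrix can be built with `torch.cuda.can_device_access_peer`:

```python
# Hypothetical illustration: build an N x N peer-access matrix similar in
# spirit to the nvlMesh filled by getNvlMesh() above (1 = peer reachable).
import torch

n = torch.cuda.device_count()
mesh = [[0] * n for _ in range(n)]
for src in range(n):
    for dst in range(n):
        if src != dst and torch.cuda.can_device_access_peer(src, dst):
            mesh[src][dst] = 1

for row in mesh:
    print(row)
```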