mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[symm_mem] Add nccl as a backend for symmetric memory (#155740)
Running unit test: TORCH_SYMMMEM=NCCL TORCH_DISTRIBUTED_DEBUG=INFO TORCH_CPP_LOG_LEVEL=INFO pytest test/distributed/test_nccl.py -k test_nccl_symmem_alloc Pull Request resolved: https://github.com/pytorch/pytorch/pull/155740 Approved by: https://github.com/kwen2501
This commit is contained in:
@ -724,6 +724,7 @@ libtorch_cuda_distributed_extra_sources = [
|
|||||||
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
|
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
|
"torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp",
|
"torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp",
|
||||||
|
"torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp",
|
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp",
|
||||||
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
|
"torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
|
||||||
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
"torch/csrc/distributed/rpc/tensorpipe_cuda.cpp",
|
||||||
|
|||||||
@ -579,6 +579,7 @@ if(USE_CUDA)
|
|||||||
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu
|
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu
|
||||||
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu
|
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu
|
||||||
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
|
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp
|
||||||
|
${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
|
||||||
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
|
PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|||||||
@ -7,11 +7,13 @@ import torch
|
|||||||
import torch.cuda
|
import torch.cuda
|
||||||
import torch.cuda.nccl as nccl
|
import torch.cuda.nccl as nccl
|
||||||
import torch.distributed as c10d
|
import torch.distributed as c10d
|
||||||
|
import torch.distributed._symmetric_memory as symm_mem
|
||||||
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
|
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
|
||||||
from torch.testing._internal.common_device_type import (
|
from torch.testing._internal.common_device_type import (
|
||||||
dtypes,
|
dtypes,
|
||||||
instantiate_device_type_tests,
|
instantiate_device_type_tests,
|
||||||
)
|
)
|
||||||
|
from torch.testing._internal.common_distributed import MultiProcContinousTest
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
IS_WINDOWS,
|
IS_WINDOWS,
|
||||||
load_tests,
|
load_tests,
|
||||||
@ -239,6 +241,41 @@ class TestNCCL(TestCase):
|
|||||||
self.assertEqual(outputs[i], expected[i])
|
self.assertEqual(outputs[i], expected[i])
|
||||||
|
|
||||||
|
|
||||||
|
device_type = "cuda"
|
||||||
|
device_module = torch.get_device_module(device_type)
|
||||||
|
|
||||||
|
|
||||||
|
class NCCLSymmetricMemoryTest(MultiProcContinousTest):
|
||||||
|
def _init_device(self) -> None:
|
||||||
|
# TODO: relieve this (seems to hang if without)
|
||||||
|
device_module.set_device(self.device)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return torch.device(device_type, self.rank)
|
||||||
|
|
||||||
|
# To run this test, one needs to TORCH_SYMMMEM=NCCL when running the test.
|
||||||
|
@skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
|
||||||
|
@skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
|
||||||
|
def test_nccl_symmem_alloc(self):
|
||||||
|
self._init_device()
|
||||||
|
c10d.all_reduce(torch.ones(1, device=self.device))
|
||||||
|
group_name = c10d.group.WORLD.group_name
|
||||||
|
symm_mem.enable_symm_mem_for_group(group_name)
|
||||||
|
|
||||||
|
dtype = torch.float
|
||||||
|
numel = 1024
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
inp = symm_mem.empty(numel, dtype=dtype, device=self.device)
|
||||||
|
symm_mem.rendezvous(inp, group=group_name)
|
||||||
|
|
||||||
|
foo()
|
||||||
|
|
||||||
|
out = symm_mem.empty(numel, dtype=dtype, device=self.device)
|
||||||
|
symm_mem.rendezvous(out, group=group_name)
|
||||||
|
|
||||||
|
|
||||||
instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")
|
instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
355
torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
Normal file
355
torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
#ifdef USE_C10D_NCCL
|
||||||
|
#include <nccl.h>
|
||||||
|
#include <torch/csrc/cuda/nccl.h>
|
||||||
|
|
||||||
|
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 1)
|
||||||
|
#define NCCL_HAS_SYMMEM_SUPPORT
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef NCCL_HAS_SYMMEM_SUPPORT
|
||||||
|
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
|
||||||
|
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
|
||||||
|
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
|
||||||
|
#include <torch/csrc/distributed/c10d/cuda/utils.hpp>
|
||||||
|
#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h>
|
||||||
|
#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp>
|
||||||
|
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
|
||||||
|
|
||||||
|
#include <ATen/ceil_div.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDACachingAllocator.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <c10/util/error.h>
|
||||||
|
|
||||||
|
namespace c10d {
|
||||||
|
namespace symmetric_memory {
|
||||||
|
|
||||||
|
/* Start of NCCLAllocation implementation */
|
||||||
|
|
||||||
|
static StoreExchange storeExchange = StoreExchange("NCCLAllocation");
|
||||||
|
|
||||||
|
struct NCCLAllocation {
|
||||||
|
void* ptr;
|
||||||
|
size_t buffer_size;
|
||||||
|
int device_idx;
|
||||||
|
|
||||||
|
NCCLAllocation(void* ptr, size_t buffer_size, int device_idx)
|
||||||
|
: ptr(ptr), buffer_size(buffer_size), device_idx(device_idx) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class NCCLSymmetricMemory : public SymmetricMemory {
|
||||||
|
public:
|
||||||
|
NCCLSymmetricMemory(
|
||||||
|
std::shared_ptr<NCCLAllocation> allocation,
|
||||||
|
const std::string& group_name,
|
||||||
|
ncclWindow_t handle,
|
||||||
|
ncclWindow_t signal_handle)
|
||||||
|
: allocation_(allocation),
|
||||||
|
buffer_size_(allocation->buffer_size),
|
||||||
|
device_idx_(allocation->device_idx),
|
||||||
|
group_name_(group_name),
|
||||||
|
handle_(handle),
|
||||||
|
signal_handle_(signal_handle) {
|
||||||
|
c10::cuda::CUDAGuard guard(device_idx_);
|
||||||
|
|
||||||
|
// We need some API like nvshmem_extension::nvshmem_ptr()
|
||||||
|
// put API to get the reference of remote memory.
|
||||||
|
// WIP
|
||||||
|
}
|
||||||
|
|
||||||
|
~NCCLSymmetricMemory() override = default;
|
||||||
|
|
||||||
|
std::vector<void*> get_buffer_ptrs() override {
|
||||||
|
return buffers_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<void*> get_signal_pad_ptrs() override {
|
||||||
|
return signal_pads_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void** get_buffer_ptrs_dev() override {
|
||||||
|
return buffers_dev_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void** get_signal_pad_ptrs_dev() override {
|
||||||
|
return signal_pads_dev_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_buffer_size() override {
|
||||||
|
return buffer_size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t get_signal_pad_size() override {
|
||||||
|
return signal_pad_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool has_multicast_support() override {
|
||||||
|
// TODO
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void* get_multicast_ptr() override {
|
||||||
|
// TODO
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: This is up for change.
|
||||||
|
at::Tensor get_buffer(
|
||||||
|
int rank,
|
||||||
|
c10::IntArrayRef sizes,
|
||||||
|
c10::ScalarType dtype,
|
||||||
|
int64_t storage_offset) {
|
||||||
|
// TODO: deduplicate
|
||||||
|
const size_t numel = std::accumulate(
|
||||||
|
sizes.begin(),
|
||||||
|
sizes.end(),
|
||||||
|
static_cast<size_t>(1),
|
||||||
|
std::multiplies<size_t>());
|
||||||
|
const auto element_size = c10::elementSize(dtype);
|
||||||
|
const auto req_size = (numel + storage_offset) * element_size;
|
||||||
|
TORCH_CHECK(
|
||||||
|
req_size <= buffer_size_,
|
||||||
|
"NCCLSymmetricMemory::get_buffer: the requested size (",
|
||||||
|
req_size,
|
||||||
|
" bytes) exceeds the allocated size (",
|
||||||
|
buffer_size_,
|
||||||
|
" bytes)");
|
||||||
|
auto data_ptr = reinterpret_cast<uint8_t*>(buffers_[rank]) +
|
||||||
|
storage_offset * element_size;
|
||||||
|
auto device = c10::Device(c10::DeviceType::CUDA, device_idx_);
|
||||||
|
auto options = at::TensorOptions().dtype(dtype).device(device);
|
||||||
|
return at::for_blob(data_ptr, sizes)
|
||||||
|
.options(options)
|
||||||
|
.target_device(device)
|
||||||
|
.make_tensor();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: This is up for change.
|
||||||
|
at::Tensor get_signal_pad(
|
||||||
|
int rank,
|
||||||
|
c10::IntArrayRef sizes,
|
||||||
|
std::optional<c10::ScalarType> dtype,
|
||||||
|
int64_t storage_offset) override {
|
||||||
|
// TODO: deduplicate
|
||||||
|
// If the dtype is unspecified, default it to UInt32, as it
|
||||||
|
// is the most common type for signaling purposes.
|
||||||
|
if (!dtype.has_value()) {
|
||||||
|
dtype = c10::ScalarType::UInt32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the shape is unspecified, treat the signal pad as a 1d tensor.
|
||||||
|
const auto element_size = c10::elementSize(*dtype);
|
||||||
|
std::vector<int64_t> shape;
|
||||||
|
if (!sizes.empty()) {
|
||||||
|
shape = sizes.vec();
|
||||||
|
} else {
|
||||||
|
shape.push_back(signal_pad_size / element_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t numel = std::accumulate(
|
||||||
|
shape.begin(),
|
||||||
|
shape.end(),
|
||||||
|
static_cast<size_t>(1),
|
||||||
|
std::multiplies<size_t>());
|
||||||
|
const auto req_size = (numel + storage_offset) * element_size;
|
||||||
|
TORCH_CHECK(
|
||||||
|
req_size <= signal_pad_size,
|
||||||
|
"NCCLSymmetricMemory::get_signal_pad: the requested size (",
|
||||||
|
req_size,
|
||||||
|
" bytes) exceeds the allocated size (",
|
||||||
|
signal_pad_size,
|
||||||
|
" bytes)");
|
||||||
|
auto data_ptr = reinterpret_cast<uint8_t*>(signal_pads_[rank]) +
|
||||||
|
storage_offset * element_size;
|
||||||
|
auto device = c10::Device(c10::DeviceType::CUDA, device_idx_);
|
||||||
|
auto options = at::TensorOptions().dtype(*dtype).device(device);
|
||||||
|
return at::for_blob(data_ptr, shape)
|
||||||
|
.options(options)
|
||||||
|
.target_device(device)
|
||||||
|
.make_tensor();
|
||||||
|
}
|
||||||
|
|
||||||
|
void barrier(int channel, size_t timeout_ms) override {
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
void put_signal(int dst_rank, int channel, size_t timeout_ms) override {
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
void wait_signal(int src_rank, int channel, size_t timeout_ms) override {
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_rank() override {
|
||||||
|
return rank_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_world_size() override {
|
||||||
|
return world_size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::vector<int>& get_rank_to_global_rank() override {
|
||||||
|
return rank_to_global_rank_;
|
||||||
|
};
|
||||||
|
|
||||||
|
int* get_rank_to_global_rank_dev() override {
|
||||||
|
return rank_to_global_rank_dev_;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<NCCLAllocation> allocation_;
|
||||||
|
size_t buffer_size_;
|
||||||
|
// TODO: We need to finalize what booking variables we need for nccl backend.
|
||||||
|
std::vector<void*> buffers_;
|
||||||
|
std::vector<void*> signal_pads_;
|
||||||
|
int device_idx_;
|
||||||
|
int rank_;
|
||||||
|
int world_size_;
|
||||||
|
void** buffers_dev_;
|
||||||
|
void** signal_pads_dev_;
|
||||||
|
std::string group_name_;
|
||||||
|
ncclWindow_t handle_;
|
||||||
|
ncclWindow_t signal_handle_;
|
||||||
|
|
||||||
|
std::vector<int> rank_to_global_rank_;
|
||||||
|
int* rank_to_global_rank_dev_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
|
||||||
|
public:
|
||||||
|
void* alloc(
|
||||||
|
size_t size,
|
||||||
|
int device_idx,
|
||||||
|
const std::optional<std::string>& group_name) override {
|
||||||
|
TORCH_CHECK(
|
||||||
|
group_name == std::nullopt,
|
||||||
|
"NCCLSymmetricMemoryAllocator::alloc "
|
||||||
|
"must not be called with a group_name");
|
||||||
|
|
||||||
|
auto group_info = get_group_info("0");
|
||||||
|
auto store = group_info.store;
|
||||||
|
int rank = group_info.rank;
|
||||||
|
int world_size = group_info.world_size;
|
||||||
|
c10::cuda::CUDAGuard guard(device_idx);
|
||||||
|
// TODO: we might need to use a roundup or mempool for mem allocation.
|
||||||
|
void* ptr;
|
||||||
|
C10D_NCCL_CHECK(ncclMemAlloc(&ptr, size), "ncclMemAlloc");
|
||||||
|
auto allocation =
|
||||||
|
std::make_shared<NCCLAllocation>(ptr, size, device_idx);
|
||||||
|
// TODO: thread safety
|
||||||
|
allocations_.emplace(ptr, allocation);
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void free(void* ptr) override {
|
||||||
|
// TODO: thread safety
|
||||||
|
ptr_to_symm_mem_.erase(ptr);
|
||||||
|
allocations_.erase(ptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t get_alloc_size(void* ptr) override {
|
||||||
|
auto it = ptr_to_symm_mem_.find(ptr);
|
||||||
|
if (it == ptr_to_symm_mem_.end()) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
false, ptr, " is not allocated with NCCLSymmetricMemoryAllocator");
|
||||||
|
}
|
||||||
|
return it->second->get_buffer_size();
|
||||||
|
};
|
||||||
|
|
||||||
|
c10::intrusive_ptr<SymmetricMemory> rendezvous(
|
||||||
|
void* ptr,
|
||||||
|
const std::optional<std::string>& group_name) override {
|
||||||
|
TORCH_CHECK(group_name.has_value(), "group_name must be provided");
|
||||||
|
{
|
||||||
|
auto it = symm_mems_.find(std::make_tuple(ptr, *group_name));
|
||||||
|
if (it != symm_mems_.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto it = allocations_.find(ptr);
|
||||||
|
TORCH_CHECK(it != allocations_.end(), "memory needs to be first allocated before calling rendezvous.");
|
||||||
|
|
||||||
|
|
||||||
|
auto group = resolve_process_group(group_name.value());
|
||||||
|
auto alloc = it->second;
|
||||||
|
c10::cuda::CUDAGuard guard(alloc->device_idx);
|
||||||
|
ncclWindow_t handle;
|
||||||
|
ncclWindow_t signal_handle;
|
||||||
|
|
||||||
|
auto group_info = get_group_info(group_name.value());
|
||||||
|
auto global_rank = get_group_info("0").rank;
|
||||||
|
auto buffer_size_map =
|
||||||
|
storeExchange.all_gather(group_info.store, group_info.rank, group_info.world_size, it->second->buffer_size);
|
||||||
|
|
||||||
|
LOG(INFO) << "[rank " << group_info.rank << "]"
|
||||||
|
<< "buffer_size_map: " << buffer_size_map;
|
||||||
|
// NCCL window registration api requires all ranks to have the same buffer size
|
||||||
|
// we have this check to make sure all ranks have the same buffer size.
|
||||||
|
for (auto r = 0; r < group_info.world_size; ++r) {
|
||||||
|
TORCH_CHECK(alloc->buffer_size == buffer_size_map[r], "buffer size mismatch");
|
||||||
|
}
|
||||||
|
auto* ncclPg = dynamic_cast<c10d::ProcessGroupNCCL*>(
|
||||||
|
group->getBackend(c10::DeviceType::CUDA).get());
|
||||||
|
TORCH_CHECK(ncclPg != nullptr, "backend must be a NCCL process group");
|
||||||
|
ncclComm_t comm = reinterpret_cast<ncclComm_t>(ncclPg->getCommPtr());
|
||||||
|
C10D_NCCL_CHECK(
|
||||||
|
ncclCommWindowRegister(comm, ptr, alloc->buffer_size, (ncclWindow_t*)&handle, NCCL_WIN_COLL_SYMMETRIC),
|
||||||
|
c10::str(
|
||||||
|
"Failed to window register segment with ptr ",
|
||||||
|
ptr,
|
||||||
|
", size ",
|
||||||
|
alloc->buffer_size,
|
||||||
|
" on ncclComm_ ",
|
||||||
|
comm));
|
||||||
|
|
||||||
|
void* signal_pad_ptr;
|
||||||
|
TORCH_CHECK(ncclMemAlloc(&signal_pad_ptr, signal_pad_size) == ncclSuccess, "ncclMemAlloc failed");
|
||||||
|
C10D_NCCL_CHECK(
|
||||||
|
ncclCommWindowRegister(comm, signal_pad_ptr, signal_pad_size, (ncclWindow_t*)&signal_handle, NCCL_WIN_COLL_SYMMETRIC),
|
||||||
|
c10::str(
|
||||||
|
"Failed to window register segment with ptr ",
|
||||||
|
signal_pad_ptr,
|
||||||
|
", size ",
|
||||||
|
signal_pad_size,
|
||||||
|
" on ncclComm_ ",
|
||||||
|
comm));
|
||||||
|
|
||||||
|
auto symm_mem =
|
||||||
|
c10::make_intrusive<NCCLSymmetricMemory>(alloc, *group_name, std::move(handle), std::move(signal_handle));
|
||||||
|
|
||||||
|
symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem;
|
||||||
|
return symm_mem;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool has_multicast_support(int device_idx) override {
|
||||||
|
// TODO
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_map<void*, c10::intrusive_ptr<SymmetricMemory>>
|
||||||
|
ptr_to_symm_mem_;
|
||||||
|
|
||||||
|
std::unordered_map<void*, std::shared_ptr<NCCLAllocation>> allocations_;
|
||||||
|
std::map<std::tuple<void*, std::string>, c10::intrusive_ptr<SymmetricMemory>>
|
||||||
|
symm_mems_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RegisterNCCLSymmetricMemoryAllocator {
|
||||||
|
RegisterNCCLSymmetricMemoryAllocator() {
|
||||||
|
// Query backend used for CUDA tensor
|
||||||
|
if (getSymmMemBackendCUDA() == "NCCL") {
|
||||||
|
register_allocator(
|
||||||
|
c10::DeviceType::CUDA,
|
||||||
|
c10::make_intrusive<NCCLSymmetricMemoryAllocator>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static RegisterNCCLSymmetricMemoryAllocator register_allocator_;
|
||||||
|
|
||||||
|
} // namespace symmetric_memory
|
||||||
|
} // namespace c10d
|
||||||
|
#endif // NCCL_HAS_SYMMEM_SUPPORT
|
||||||
|
#endif // USE_C10D_NCCL
|
||||||
Reference in New Issue
Block a user