mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
### Summary - Added multicast support to SymmetricMemory. If the CUDA runtime and CUDA driver have multicast support, SymmetricMemory associates all peer buffers with a multicast object and exposes the multicast virtual address. - Implemented `multimem_all_reduce_` and `multimem_one_shot_all_reduce` based on the multicast support. The two variants show different performance characteristics for different message sizes. We plan to use Inductor for collective algorithm selection (and the required symmetric memory buffer allocation). ### Benchmark 8xH100 (non-standard version with HBM2e at 650W). NVSwitch V3 with NVLS support.   Differential Revision: [D61682507](https://our.internmc.facebook.com/intern/diff/D61682507) Pull Request resolved: https://github.com/pytorch/pytorch/pull/133424 Approved by: https://github.com/yf225, https://github.com/weifengpy
116 lines
3.0 KiB
C++
116 lines
3.0 KiB
C++
#pragma once
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <torch/csrc/distributed/c10d/Store.hpp>
|
|
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
|
|
|
|
namespace c10d {
|
|
namespace symmetric_memory {
|
|
|
|
// Handle type used for allocations in this header.
//
// When building without USE_ROCM and with PYTORCH_C10_DRIVER_API_SUPPORTED
// defined, this is the CUDA driver's CUmemGenericAllocationHandle. In every
// other configuration it falls back to an untyped void* so that the
// declarations below still compile.
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
using HandleType = CUmemGenericAllocationHandle;
#else
using HandleType = void*;
#endif
|
|
|
|
class CUDASymmetricMemory : public SymmetricMemory {
|
|
public:
|
|
CUDASymmetricMemory(
|
|
std::vector<HandleType> handles,
|
|
size_t block_size,
|
|
std::vector<void*> buffers,
|
|
std::vector<void*> signal_pads,
|
|
HandleType mc_handle,
|
|
void* mc_addr,
|
|
size_t buffer_size,
|
|
int local_device_idx,
|
|
int rank,
|
|
int world_size);
|
|
|
|
~CUDASymmetricMemory() override;
|
|
|
|
std::vector<void*> get_buffer_ptrs() override;
|
|
std::vector<void*> get_signal_pad_ptrs() override;
|
|
void** get_buffer_ptrs_dev() override;
|
|
void** get_signal_pad_ptrs_dev() override;
|
|
size_t get_buffer_size() override;
|
|
size_t get_signal_pad_size() override;
|
|
|
|
bool has_multicast_support() override;
|
|
void* get_multicast_ptr() override;
|
|
|
|
at::Tensor get_buffer(
|
|
int rank,
|
|
c10::IntArrayRef sizes,
|
|
c10::ScalarType dtype,
|
|
int64_t storage_offset) override;
|
|
|
|
void barrier(int channel) override;
|
|
void put_signal(int dst_rank, int channel) override;
|
|
void wait_signal(int src_rank, int channel) override;
|
|
|
|
int get_rank() override;
|
|
int get_world_size() override;
|
|
|
|
private:
|
|
std::vector<HandleType> handles_;
|
|
size_t block_size_;
|
|
std::vector<void*> buffers_;
|
|
std::vector<void*> signal_pads_;
|
|
HandleType mc_handle_;
|
|
void* mc_addr_;
|
|
size_t buffer_size_;
|
|
int local_device_idx_;
|
|
int rank_;
|
|
int world_size_;
|
|
void** buffers_dev_;
|
|
void** signal_pads_dev_;
|
|
std::optional<std::function<void(void)>> finalizer_;
|
|
};
|
|
|
|
struct Block : public c10::intrusive_ptr_target {
|
|
HandleType handle;
|
|
int device_idx;
|
|
size_t block_size;
|
|
size_t buffer_size;
|
|
size_t signal_pad_offset;
|
|
std::string group_name;
|
|
c10::intrusive_ptr<CUDASymmetricMemory> symm_mem = nullptr;
|
|
|
|
Block(
|
|
HandleType handle,
|
|
int device_idx,
|
|
size_t block_size,
|
|
size_t buffer_size,
|
|
size_t signal_pad_offset,
|
|
const std::string& group_name)
|
|
: handle(handle),
|
|
device_idx(device_idx),
|
|
block_size(block_size),
|
|
buffer_size(buffer_size),
|
|
signal_pad_offset(signal_pad_offset),
|
|
group_name(group_name),
|
|
symm_mem(nullptr) {}
|
|
};
|
|
|
|
class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator {
|
|
public:
|
|
void* alloc(size_t size, int device_idx, const std::string& group_name)
|
|
override;
|
|
|
|
void free(void* ptr) override;
|
|
size_t get_alloc_size(void* ptr) override;
|
|
c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) override;
|
|
bool is_rendezvous_completed(void* ptr) override;
|
|
bool has_multicast_support() override;
|
|
|
|
private:
|
|
c10::intrusive_ptr<Block> find_block(void* ptr);
|
|
|
|
std::shared_mutex mutex_;
|
|
std::unordered_map<void*, c10::intrusive_ptr<Block>> ptr_to_block_;
|
|
};
|
|
|
|
} // namespace symmetric_memory
|
|
} // namespace c10d
|