Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #128582

This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation) - a remote-memory-access-based communication primitive. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for them.

### SymmetricMemory

`SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers).

### Python API Example

```python
import torch
from torch._C.distributed_c10d import _SymmetricMemory

# Set a store for rendezvousing symmetric allocations on a group of devices
# identified by group_name. The concept of groups is logical; users can
# utilize predefined groups (e.g., a group of devices identified by a
# ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
# backend might employ a more efficient communication channel for the actual
# rendezvous process and only use the store for bootstrapping purposes.
_SymmetricMemory.set_group_info(group_name, rank, world_size, store)

# Identical to empty_strided, but allows symmetric memory access to be
# established for the allocated tensor via _SymmetricMemory.rendezvous().
# This function itself is not a collective operation.
t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name)

# Users can write Python custom ops that leverage symmetric memory access.
# Below are examples of things users can do (assuming the group's world_size is 2).

# Establishes symmetric memory access on tensors allocated via
# _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process,
# and the mapping between a local memory region and the associated SymmetricMemory
# object is unique. Subsequent calls to rendezvous() with the same tensor will receive
# the cached SymmetricMemory object.
#
# The function has a collective semantic and must be invoked simultaneously
# from all rendezvous participants.
symm_mem = _SymmetricMemory.rendezvous(t)

# This represents the allocation on rank 0 and is accessible from all devices.
buf = symm_mem.get_buffer(0, (64, 64), torch.float32)

if symm_mem.rank == 0:
    symm_mem.wait_signal(src_rank=1)
    assert buf.eq(42).all()
else:
    # The remote buffer can be used as a regular tensor
    buf.fill_(42)
    symm_mem.put_signal(dst_rank=0)

symm_mem.barrier()

if symm_mem.rank == 0:
    symm_mem.barrier()
    assert buf.eq(43).all()
else:
    new_val = torch.empty_like(buf)
    new_val.fill_(43)
    # Contiguous copies to/from a remote buffer utilize copy engines,
    # which bypass SMs (i.e. no need to load the data into registers)
    buf.copy_(new_val)
    symm_mem.barrier()
```

### Custom CUDA Comm Kernels

Given a tensor, users can access the associated `SymmetricMemory` object, which provides pointers to the remote buffers/signal_pads needed for custom communication kernels (an illustrative kernel sketch follows below this PR description).

```cpp
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor);

class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
 public:
  ...
  virtual std::vector<void*> get_buffer_ptrs() = 0;
  virtual std::vector<void*> get_signal_pad_ptrs() = 0;
  virtual void** get_buffer_ptrs_dev() = 0;
  virtual void** get_signal_pad_ptrs_dev() = 0;
  virtual size_t get_buffer_size() = 0;
  virtual size_t get_signal_pad_size() = 0;
  virtual int get_rank() = 0;
  virtual int get_world_size() = 0;
  ...
};
```

### Limitations of IntraNodeComm and ProcessGroupCudaP2p

`IntraNodeComm` (used by `ProcessGroupCudaP2p`) manages a single fixed-size workspace. This approach:
- Leads to awkward UX in which the required workspace size needs to be specified upfront.
- Cannot avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather).
- Prevents torch.compile from eliminating all copies.

In addition, they only offer out-of-the-box communication kernels and don't expose the pointers required for user-defined custom CUDA comm kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582
Approved by: https://github.com/wanchaol
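To complement the interface above, here is a minimal, hypothetical sketch of a custom communication kernel built on the pointers exposed by `SymmetricMemory`. The kernel and launcher names (`sum_across_ranks`, `launch_sum_across_ranks`) are illustrative only, and inter-rank synchronization via the signal pads is deliberately omitted; a real kernel would need it before reading peer buffers.

```cpp
#include <cuda_runtime.h>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

// Illustrative kernel: each thread sums element i across every rank's
// symmetric buffer and writes the result into a local output buffer.
// `buffer_ptrs` is assumed to be the device-visible pointer array returned
// by SymmetricMemory::get_buffer_ptrs_dev().
__global__ void sum_across_ranks(
    void** buffer_ptrs,
    float* out,
    int world_size,
    size_t numel) {
  size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
  if (i >= numel) {
    return;
  }
  float acc = 0.0f;
  for (int r = 0; r < world_size; ++r) {
    // Peer buffers are directly dereferenceable once rendezvous() has
    // established the symmetric mapping.
    acc += static_cast<const float*>(buffer_ptrs[r])[i];
  }
  out[i] = acc;
}

// Illustrative host-side launcher, given the SymmetricMemory associated with
// a rendezvous'ed tensor (e.g., obtained via get_symmetric_memory()).
void launch_sum_across_ranks(
    const c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory>& symm_mem,
    float* out,
    size_t numel,
    cudaStream_t stream) {
  constexpr int kThreads = 256;
  const int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
  sum_across_ranks<<<blocks, kThreads, 0, stream>>>(
      symm_mem->get_buffer_ptrs_dev(), out, symm_mem->get_world_size(), numel);
}
```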
190 lines | 5.7 KiB | C++
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

namespace {

using namespace c10d::symmetric_memory;

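// Process-wide singleton that maps each device type to its registered
// SymmetricMemoryAllocator backend. The destructor releases ownership of the
// registered allocators instead of destroying them, which sidesteps static
// destruction-order issues at process exit.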
class AllocatorMap {
 public:
  static AllocatorMap& get() {
    static AllocatorMap instance;
    return instance;
  }

  void register_allocator(
      c10::DeviceType device_type,
      c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
    map_[device_type] = std::move(allocator);
  }

  c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
      c10::DeviceType device_type) {
    auto it = map_.find(device_type);
    TORCH_CHECK(
        it != map_.end(),
        "SymmetricMemory does not support device type ",
        device_type);
    return it->second;
  }

  ~AllocatorMap() {
    for (auto& it : map_) {
      it.second.release();
    }
  }

 private:
  AllocatorMap() = default;
  AllocatorMap(const AllocatorMap&) = delete;
  AllocatorMap& operator=(const AllocatorMap&) = delete;

  std::unordered_map<
      c10::DeviceType,
      c10::intrusive_ptr<SymmetricMemoryAllocator>>
      map_;
};

static std::unordered_map<std::string, GroupInfo> group_info_map{};

// Data structures for tracking persistent allocations
static std::unordered_map<uint64_t, void*> alloc_id_to_dev_ptr{};
static std::unordered_map<uint64_t, c10::weak_intrusive_ptr<c10::StorageImpl>>
    alloc_id_to_storage{};

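// Allocate (or reuse) a symmetric buffer keyed by alloc_id. A buffer is only
// handed out again once the previous tensor backed by the same alloc_id has
// been released, and a reused buffer must match the originally requested
// allocation size.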
static at::Tensor empty_strided_p2p_persistent(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    uint64_t alloc_id) {
  // Make the allocation fail if a previous allocation with the same alloc_id
  // is still active.
  auto storage = alloc_id_to_storage.find(alloc_id);
  if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) {
    TORCH_CHECK(
        false,
        "SymmetricMemory::empty_strided_p2p_persistent: ",
        "can not allocate with alloc_id == ",
        alloc_id,
        " because a previous allocation with the same alloc_id "
        "is still active.");
  }

  const size_t numel =
      std::accumulate(size.begin(), size.end(), 1, std::multiplies<int>());
  const size_t element_size = c10::elementSize(dtype);
  const size_t alloc_size = numel * element_size;

  auto allocator = get_allocator(device.type());
  void* dev_ptr = nullptr;
  if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) {
    dev_ptr = alloc_id_to_dev_ptr[alloc_id];
    TORCH_CHECK(
        alloc_size == allocator->get_alloc_size(dev_ptr),
        "SymmetricMemory::empty_strided_p2p_persistent: ",
        "requested allocation size (",
        alloc_size,
        ") is different from the size of a previous allocation ",
        "with the same alloc_id ",
        allocator->get_alloc_size(dev_ptr));
  } else {
    dev_ptr = allocator->alloc(alloc_size, device.index(), group_name);
    alloc_id_to_dev_ptr[alloc_id] = dev_ptr;
  }

  auto options = at::TensorOptions().dtype(dtype).device(device);
  auto allocated = at::from_blob(dev_ptr, size, stride, options);

  // Track the allocation's activeness
  alloc_id_to_storage.erase(alloc_id);
  alloc_id_to_storage.emplace(
      alloc_id, allocated.storage().getWeakStorageImpl());
  return allocated;
}

} // namespace

namespace c10d {
namespace symmetric_memory {

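// Public registration/lookup entry points; both simply forward to the
// process-wide AllocatorMap singleton.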
void register_allocator(
    c10::DeviceType device_type,
    c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
  return AllocatorMap::get().register_allocator(
      device_type, std::move(allocator));
}

c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
    c10::DeviceType device_type) {
  return AllocatorMap::get().get_allocator(device_type);
}

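// Record the rendezvous metadata (rank, world size, store) for a logical
// group name. Each group name may only be registered once.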
void set_group_info(
    const std::string& group_name,
    int rank,
    int world_size,
    c10::intrusive_ptr<Store> store) {
  TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end());
  GroupInfo group_info;
  group_info.rank = rank;
  group_info.world_size = world_size;
  group_info.store = std::move(store);
  group_info_map.emplace(group_name, std::move(group_info));
}

const GroupInfo& get_group_info(const std::string& group_name) {
  TORCH_CHECK(
      group_info_map.find(group_name) != group_info_map.end(),
      "get_group_info: no group info associated with the group name ",
      group_name);
  return group_info_map[group_name];
}

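// Allocate a tensor eligible for symmetric memory access. When alloc_id is
// provided, the allocation is persistent and may be reused across calls;
// otherwise the buffer is freed through the allocator when the tensor dies.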
at::Tensor empty_strided_p2p(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    std::optional<uint64_t> alloc_id) {
  if (alloc_id.has_value()) {
    return empty_strided_p2p_persistent(
        size, stride, dtype, device, group_name, *alloc_id);
  }
  const size_t numel =
      std::accumulate(size.begin(), size.end(), 1, std::multiplies<int>());
  const size_t element_size = c10::elementSize(dtype);
  const size_t alloc_size = numel * element_size;

  auto allocator = get_allocator(device.type());
  void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name);

  auto options = at::TensorOptions().dtype(dtype).device(device);
  return at::from_blob(
      dev_ptr,
      size,
      stride,
      [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); },
      options);
}

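// Establish symmetric memory access on a tensor allocated via
// empty_strided_p2p(). This has collective semantics; all rendezvous
// participants must invoke it.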
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
    const at::Tensor& tensor) {
  auto allocator = get_allocator(tensor.device().type());
  return allocator->rendezvous(tensor.data_ptr());
}

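// Retrieve the SymmetricMemory object already associated with a tensor.
// Unlike rendezvous(), this requires that the rendezvous has completed.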
c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor) {
  auto allocator = get_allocator(tensor.device().type());
  TORCH_CHECK(
      allocator->is_rendezvous_completed(tensor.data_ptr()),
      "SymmetricMemory: must invoke rendezvous on a tensor ",
      "before calling get_symmetric_memory on it");
  return allocator->rendezvous(tensor.data_ptr());
}

} // namespace symmetric_memory
} // namespace c10d