pytorch/torch/csrc/distributed/c10d/SymmetricMemory.cpp
Yifu Wang 8771e3429c Introduce a prototype for SymmetricMemory (#128582)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #128582

This PR introduces a prototype for `SymmetricMemory` (including a CUDA implementation), a communication primitive based on remote memory access. It allows for user-defined communication patterns/kernels and is designed to be torch.compile-friendly. It addresses the major limitations of `IntraNodeComm` and `ProcessGroupCudaP2p` and serves as a replacement for both.

### SymmetricMemory

`SymmetricMemory` represents symmetric allocations across a group of devices. The allocations represented by a `SymmetricMemory` object are accessible by all devices in the group. The class can be used for **op-level custom communication patterns** (via the get_buffer APIs and the synchronization primitives), as well as **custom communication kernels** (via the buffer and signal_pad device pointers).

### Python API Example

```python
from torch._C.distributed_c10d import _SymmetricMemory

# Set a store for rendezvousing symmetric allocations on a group of devices
# identified by group_name. The concept of groups is logical; users can
# utilize predefined groups (e.g., a group of devices identified by a
# ProcessGroup) or create custom ones. Note that SymmetricMemoryAllocator
# backends might employ a more efficient communication channel for the actual
# rendezvous process and only use the store for bootstrapping purposes.
_SymmetricMemory.set_group_info(group_name, rank, world_size, store)

# Identical to empty_strided, but allows symmetric memory access to be
# established for the allocated tensor via _SymmetricMemory.rendezvous().
# This function itself is not a collective operation.
t = _SymmetricMemory.empty_strided_p2p((64, 64), (64, 1), torch.float32, group_name)

# Users can write Python custom ops that leverage the symmetric memory access.
# Below are examples of things users can do (assuming the group's world_size is 2).

# Establishes symmetric memory access on tensors allocated via
# _SymmetricMemory.empty_strided_p2p(). rendezvous() is a one-time process,
# and the mapping between a local memory region and the associated SymmetricMemory
# object is unique. Subsequent calls to rendezvous() with the same tensor will receive
# the cached SymmetricMemory object.
#
# The function has collective semantics and must be invoked simultaneously
# from all rendezvous participants.
symm_mem = _SymmetricMemory.rendezvous(t)

# This represents the allocation on rank 0 and is accessible from all devices.
buf = symm_mem.get_buffer(0, (64, 64), torch.float32)

if symm_mem.rank == 0:
    symm_mem.wait_signal(src_rank=1)
    assert buf.eq(42).all()
else:
    # The remote buffer can be used as a regular tensor
    buf.fill_(42)
    symm_mem.put_signal(dst_rank=0)

symm_mem.barrier()

if symm_mem.rank == 0:
    symm_mem.barrier()
    assert buf.eq(43).all()
else:
    new_val = torch.empty_like(buf)
    new_val.fill_(43)
    # Contiguous copies to/from a remote buffer utilize copy engines
    # which bypass SMs (i.e., no need to load the data into registers)
    buf.copy_(new_val)
    symm_mem.barrier()
```

### Custom CUDA Comm Kernels

Given a tensor, users can access the associated `SymmetricMemory` object, which provides the pointers to remote buffers/signal_pads needed for custom communication kernels.

```cpp
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor);

class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
 public:
  ...
  virtual std::vector<void*> get_buffer_ptrs() = 0;
  virtual std::vector<void*> get_signal_pad_ptrs() = 0;
  virtual void** get_buffer_ptrs_dev() = 0;
  virtual void** get_signal_pad_ptrs_dev() = 0;
  virtual size_t get_buffer_size() = 0;
  virtual size_t get_signal_pad_size() = 0;
  virtual int get_rank() = 0;
  virtual int get_world_size() = 0;
  ...
};
```
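
For illustration only, below is a minimal sketch of a kernel that consumes these pointers. The kernel name (`naive_push_kernel`), the all-gather-style chunk layout, and the host-side variables (`local_src`, `chunk_numel`, `grid`, `block`, `stream`) are assumptions made for this example; only `get_buffer_ptrs_dev()`, `get_rank()`, `get_world_size()`, and `get_symmetric_memory()` come from the API above. Peers still need to synchronize (e.g., via the signal pads) before reading the pushed data.

```cpp
// Sketch of a custom comm kernel: each rank pushes its local chunk into the
// corresponding slot of every peer's symmetric buffer via plain P2P stores.
// buffer_ptrs holds world_size entries (one device pointer per rank), as
// returned by SymmetricMemory::get_buffer_ptrs_dev().
__global__ void naive_push_kernel(
    void** buffer_ptrs,
    const float* local_src,
    int rank,
    int world_size,
    size_t chunk_numel) {
  const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (int peer = 0; peer < world_size; ++peer) {
    // Assumes each symmetric buffer holds world_size chunks of chunk_numel
    // floats; this rank owns the chunk at offset rank * chunk_numel.
    float* dst = static_cast<float*>(buffer_ptrs[peer]) + rank * chunk_numel;
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < chunk_numel;
         i += stride) {
      dst[i] = local_src[i];
    }
  }
}

// Host-side launch (sketch): obtain the device pointer array from a
// rendezvoused tensor and launch on the current CUDA stream.
//
//   auto symm_mem = get_symmetric_memory(t);
//   naive_push_kernel<<<grid, block, 0, stream>>>(
//       symm_mem->get_buffer_ptrs_dev(),
//       local_src,
//       symm_mem->get_rank(),
//       symm_mem->get_world_size(),
//       chunk_numel);
```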

### Limitations of IntraNodeComm and ProcessGroupCudaP2p
Both `IntraNodeComm` and `ProcessGroupCudaP2p` (which uses `IntraNodeComm` internally) manage a single, fixed-size workspace. This approach:
- Leads to awkward UX in which the required workspace needs to be specified upfront.
- Cannot avoid extra copies for some algorithms in eager mode (e.g., custom/multimem all-reduce, reduce-scatter, all-gather).
- Prevents torch.compile from eliminating all copies.

In addition, they only offer out-of-the-box communication kernels and don't expose the pointers required for user-defined, custom CUDA comm kernels.


Pull Request resolved: https://github.com/pytorch/pytorch/pull/128582
Approved by: https://github.com/wanchaol
2024-06-19 03:38:58 +00:00


#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

namespace {

using namespace c10d::symmetric_memory;

class AllocatorMap {
 public:
  static AllocatorMap& get() {
    static AllocatorMap instance;
    return instance;
  }

  void register_allocator(
      c10::DeviceType device_type,
      c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
    map_[device_type] = std::move(allocator);
  }

  c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
      c10::DeviceType device_type) {
    auto it = map_.find(device_type);
    TORCH_CHECK(
        it != map_.end(),
        "SymmetricMemory does not support device type ",
        device_type);
    return it->second;
  }

  ~AllocatorMap() {
    for (auto& it : map_) {
      it.second.release();
    }
  }

 private:
  AllocatorMap() = default;
  AllocatorMap(const AllocatorMap&) = delete;
  AllocatorMap& operator=(const AllocatorMap&) = delete;

  std::unordered_map<
      c10::DeviceType,
      c10::intrusive_ptr<SymmetricMemoryAllocator>>
      map_;
};
static std::unordered_map<std::string, GroupInfo> group_info_map{};

// Data structures for tracking persistent allocations
static std::unordered_map<uint64_t, void*> alloc_id_to_dev_ptr{};
static std::unordered_map<uint64_t, c10::weak_intrusive_ptr<c10::StorageImpl>>
    alloc_id_to_storage{};

static at::Tensor empty_strided_p2p_persistent(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    uint64_t alloc_id) {
  // Make the allocation fail if a previous allocation with the same alloc_id
  // is still active.
  auto storage = alloc_id_to_storage.find(alloc_id);
  if (storage != alloc_id_to_storage.end() && storage->second.use_count() > 0) {
    TORCH_CHECK(
        false,
        "SymmetricMemory::empty_strided_p2p_persistent: ",
        "can not allocate with alloc_id == ",
        alloc_id,
        " because a previous allocation with the same alloc_id "
        "is still active.");
  }

  const size_t numel =
      std::accumulate(size.begin(), size.end(), 1, std::multiplies<int>());
  const size_t element_size = c10::elementSize(dtype);
  const size_t alloc_size = numel * element_size;

  auto allocator = get_allocator(device.type());
  void* dev_ptr = nullptr;
  if (alloc_id_to_dev_ptr.find(alloc_id) != alloc_id_to_dev_ptr.end()) {
    dev_ptr = alloc_id_to_dev_ptr[alloc_id];
    TORCH_CHECK(
        alloc_size == allocator->get_alloc_size(dev_ptr),
        "SymmetricMemory::empty_strided_p2p_persistent: ",
        "requested allocation size (",
        alloc_size,
        ") is different from the size of a previous allocation ",
        "with the same alloc_id ",
        allocator->get_alloc_size(dev_ptr));
  } else {
    dev_ptr = allocator->alloc(alloc_size, device.index(), group_name);
    alloc_id_to_dev_ptr[alloc_id] = dev_ptr;
  }

  auto options = at::TensorOptions().dtype(dtype).device(device);
  auto allocated = at::from_blob(dev_ptr, size, stride, options);

  // Track the allocation's activeness
  alloc_id_to_storage.erase(alloc_id);
  alloc_id_to_storage.emplace(
      alloc_id, allocated.storage().getWeakStorageImpl());
  return allocated;
}

} // namespace
namespace c10d {
namespace symmetric_memory {

void register_allocator(
    c10::DeviceType device_type,
    c10::intrusive_ptr<SymmetricMemoryAllocator> allocator) {
  return AllocatorMap::get().register_allocator(
      device_type, std::move(allocator));
}

c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
    c10::DeviceType device_type) {
  return AllocatorMap::get().get_allocator(device_type);
}

void set_group_info(
    const std::string& group_name,
    int rank,
    int world_size,
    c10::intrusive_ptr<Store> store) {
  TORCH_CHECK(group_info_map.find(group_name) == group_info_map.end());
  GroupInfo group_info;
  group_info.rank = rank;
  group_info.world_size = world_size;
  group_info.store = std::move(store);
  group_info_map.emplace(group_name, std::move(group_info));
}

const GroupInfo& get_group_info(const std::string& group_name) {
  TORCH_CHECK(
      group_info_map.find(group_name) != group_info_map.end(),
      "get_group_info: no group info associated with the group name ",
      group_name);
  return group_info_map[group_name];
}

at::Tensor empty_strided_p2p(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    std::optional<uint64_t> alloc_id) {
  if (alloc_id.has_value()) {
    return empty_strided_p2p_persistent(
        size, stride, dtype, device, group_name, *alloc_id);
  }
  const size_t numel =
      std::accumulate(size.begin(), size.end(), 1, std::multiplies<int>());
  const size_t element_size = c10::elementSize(dtype);
  const size_t alloc_size = numel * element_size;

  auto allocator = get_allocator(device.type());
  void* dev_ptr = allocator->alloc(alloc_size, device.index(), group_name);

  auto options = at::TensorOptions().dtype(dtype).device(device);
  return at::from_blob(
      dev_ptr,
      size,
      stride,
      [allocator = std::move(allocator)](void* ptr) { allocator->free(ptr); },
      options);
}

TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
    const at::Tensor& tensor) {
  auto allocator = get_allocator(tensor.device().type());
  return allocator->rendezvous(tensor.data_ptr());
}

c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor) {
  auto allocator = get_allocator(tensor.device().type());
  TORCH_CHECK(
      allocator->is_rendezvous_completed(tensor.data_ptr()),
      "SymmetricMemory: must invoke rendezvous on a tensor ",
      "before calling get_symmetric_memory on it");
  return allocator->rendezvous(tensor.data_ptr());
}

} // namespace symmetric_memory
} // namespace c10d
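
As a closing illustration (not part of the file above), here is a hedged sketch of the host-side call sequence using only the functions defined in this translation unit: `set_group_info`, `empty_strided_p2p`, `rendezvous`, and `get_symmetric_memory`. The group name, shapes, device index, and the way `store`, `rank`, and `world_size` are obtained are placeholders for this example.

```cpp
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>

// Sketch: typical single-rank flow through the c10d::symmetric_memory API.
// The store, rank, and world_size are assumed to come from the application's
// existing bootstrapping logic.
at::Tensor allocate_and_rendezvous(
    const c10::intrusive_ptr<c10d::Store>& store,
    int rank,
    int world_size) {
  using namespace c10d::symmetric_memory;

  // 1. Register the group so the allocator can bootstrap its rendezvous
  //    through the store.
  set_group_info("my_group", rank, world_size, store);

  // 2. Allocate a tensor backed by symmetric memory (not a collective call).
  //    Passing an alloc_id instead of std::nullopt routes through the
  //    persistent-allocation path above.
  auto t = empty_strided_p2p(
      /*size=*/{64, 64},
      /*stride=*/{64, 1},
      c10::ScalarType::Float,
      c10::Device(c10::DeviceType::CUDA, 0),
      "my_group",
      /*alloc_id=*/std::nullopt);

  // 3. Establish symmetric memory access; this is collective, so every rank
  //    in the group must call it for the same allocation.
  auto symm_mem = rendezvous(t);
  (void)symm_mem;

  // 4. Later (e.g., inside a custom op), the SymmetricMemory object can be
  //    recovered from the tensor once rendezvous has completed.
  auto recovered = get_symmetric_memory(t);
  (void)recovered;

  return t;
}
```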