DeepCompile ZeRO-3: robust allgather for uneven shards; fix profiling meta key (max_mem) (#7489)

---------

Signed-off-by: Abhishek <dalakotiashu150@gmail.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Abhishek <dalakotiashu150@gmail.com>
Co-authored-by: Masahiro Tanaka <mtanaka@anyscale.com>
Jupiter-Guy authored 2025-09-22 16:45:00 -07:00, committed by GitHub
parent 80033a8293
commit 325c6c5e9c
4 changed files with 117 additions and 13 deletions
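
The core of the uneven-shard fix is the ceil-division sizing that appears throughout the C++ diff below: each rank's shard is padded to a common per-rank length, the all-gather fills a world_size * padded_per_rank buffer, and callers receive a view sliced back to the parameter's true element count. A minimal standalone sketch of that arithmetic (illustrative only, not taken verbatim from the patch):

#include <cassert>
#include <cstdint>

int main()
{
    // Assumed example sizes: a 13-element parameter gathered across 2 ranks
    // (the new test below uses hidden_dim=13 with world_size=2 for the same reason).
    const int world_size = 2;
    const int64_t true_numel = 13;
    const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size;           // ceil(13 / 2) = 7
    const int64_t padded_numel = static_cast<int64_t>(world_size) * padded_per_rank;       // 2 * 7 = 14
    assert(padded_numel - true_numel == 1);  // one trailing padding element, dropped by the returned view
    return 0;
}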


@@ -69,9 +69,15 @@ public:
const at::Tensor& ds_tensor = param.getDSTensor();
if (symm_mem == nullptr) {
// Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size)
const int world_size = process_group_->getSize();
const int64_t shard_elems = ds_tensor.numel();
// Perform all-gather directly into the pre-allocated padded output buffer
// NCCL requires contiguous storage; use .contiguous() explicitly
ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
output_buf.data_ptr(),
ds_tensor.numel(),
shard_elems,
get_nccl_data_type(ds_tensor.scalar_type()),
nccl_comm_,
ag_stream_);
@@ -104,13 +110,30 @@ public:
at::Tensor allgatherParam(long ds_id,
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }
const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
? param_registry_->getGatheredParam(ds_id)
: torch::empty(param.getShape(), ds_tensor.options());
const int world_size = process_group_->getSize();
const int64_t true_numel = static_cast<int64_t>(productDim(param.getShape()));
const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size;
const int64_t padded_numel = static_cast<int64_t>(world_size) * padded_per_rank;
if (param_registry_->isValid(ds_id)) {
// Return a view sliced to the true size with the original shape
auto base = param_registry_->getGatheredParam(ds_id);
return base.flatten()
.index({torch::indexing::Slice(0, true_numel)})
.view(param.getShape());
}
at::Tensor output_buf;
if (param_registry_->hasGatheredParam(ds_id)) {
auto existing = param_registry_->getGatheredParam(ds_id);
if (existing.defined() && existing.numel() == padded_numel) { output_buf = existing; }
}
if (!output_buf.defined()) {
at::cuda::CUDAStreamGuard guard(ag_stream_);
output_buf = torch::empty({padded_numel}, ds_tensor.options());
}
assert(hasKey(ag_comp_done_events_, ds_id));
ag_comp_done_events_[ds_id]->record();
@@ -119,7 +142,10 @@ public:
launchAllGather(output_buf, ds_id, symm_mem);
ag_comm_done_events_[ds_id]->record(ag_stream_);
return output_buf;
// Return a view of the gathered padded buffer matching the true param shape
return output_buf.flatten()
.index({torch::indexing::Slice(0, true_numel)})
.view(param.getShape());
}
void prefetchParamsFused(std::vector<int64_t> ds_ids,
@@ -133,11 +159,19 @@ public:
std::unordered_map<long, at::Tensor> output_bufs;
for (long ds_id : invalid_ds_ids) {
const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
const int world_size = process_group_->getSize();
const int64_t shard_elems = ds_tensor.numel();
const int64_t padded_numel = static_cast<int64_t>(world_size) * shard_elems;
if (param_registry_->hasGatheredParam(ds_id)) {
output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
} else {
output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
auto existing = param_registry_->getGatheredParam(ds_id);
if (existing.defined() && existing.numel() == padded_numel) {
output_bufs[ds_id] = existing;
continue;
}
}
output_bufs[ds_id] = torch::empty({padded_numel}, ds_tensor.options());
}
for (long ds_id : invalid_ds_ids) {
@@ -383,6 +417,43 @@ void register_z3_param(long ds_id,
{
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }
// Validate that padded shard sizes are uniform across ranks at registration time
// DeepSpeed pads parameters to ensure even division, so we check the padded size
// which should be uniform across all ranks for correct allgather behavior
const int64_t local_count = ds_tensor.numel();
const int world_size = process_group->getSize();
// Calculate padded size (aligned to world_size)
// Use ds_shape to compute the full (unpartitioned) parameter size
int64_t total_numel = 1;
for (const auto dim : ds_shape) { total_numel *= dim; }
const int64_t padded_per_rank = (total_numel + world_size - 1) / world_size;
// For verification: all ranks should have the same padded size
auto count_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
at::Tensor local_padded_tensor = torch::tensor({padded_per_rank}, count_options);
std::vector<at::Tensor> all_padded_counts(world_size);
for (int i = 0; i < world_size; ++i) {
all_padded_counts[i] = torch::empty_like(local_padded_tensor);
}
// Build lvalue buffers for output and input as required by ProcessGroup::allgather
// The first argument must be a single-element vector containing a vector of WORLD_SIZE tensors
std::vector<std::vector<at::Tensor>> output_tensors(1);
output_tensors[0] = all_padded_counts;
std::vector<at::Tensor> input_tensors = {local_padded_tensor};
process_group->allgather(output_tensors, input_tensors)->wait();
// Verify all ranks agree on the padded size
for (int i = 0; i < world_size; ++i) {
int64_t padded_count = all_padded_counts[i].to(torch::kCPU).item<int64_t>();
if (padded_count != padded_per_rank) {
throw std::runtime_error(
"ZeRO-3 registration error: inconsistent padded shard sizes across ranks. "
"This is an internal error - please report this issue.");
}
}
}
at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
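
The repeated flatten/slice/view pattern in this file can be read in isolation as the helper below; view_as_true_param is a hypothetical name used only for this sketch (it requires libtorch and is not part of the patch):

#include <torch/torch.h>

// Sketch: turn a flat, padded gathered buffer back into a tensor with the
// parameter's true shape, mirroring the return path of allgatherParam above.
at::Tensor view_as_true_param(const at::Tensor& padded_flat, at::IntArrayRef shape)
{
    int64_t true_numel = 1;
    for (const auto dim : shape) { true_numel *= dim; }
    return padded_flat.flatten()
        .index({torch::indexing::Slice(0, true_numel)})
        .view(shape);
}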


@@ -122,7 +122,7 @@ class ProfilingInterpreter(Interpreter):
n.meta["device_time"] = 0.0
n.meta["wall_time"] = 0.0
n.meta["alloc_mem"] = 0
n.meta["max_memory"] = 0
n.meta["max_mem"] = 0
n.meta["tensor_size"] = _node_size(n)
return super().run_node(n)


@@ -116,3 +116,36 @@ class TestDeepCompile(DistributedTest):
# Need warmup steps
compare_loss(self, config_dict, dtype, iteration=10)
@pytest.mark.parametrize('dtype', [torch.float32])
@pytest.mark.parametrize('zero_stage', [3])
def test_padded_shard_handling(self, zero_stage, dtype):
"""Test that parameters with padding (uneven division) work correctly with DeepCompile"""
if not required_torch_version(min_version=2.6):
pytest.skip("DeepCompile requires PyTorch >= v2.6")
if get_accelerator().device_name() == "cpu":
pytest.skip("CPU does not support this test yet")
# Use a hidden dimension that requires padding when divided across ranks
# With world_size=2, a hidden_dim of 13 creates parameters that need padding
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"zero_optimization": {
"stage": zero_stage,
},
"compile": {
"deepcompile": True
}
}
# This should work correctly with our padding-aware implementation
# The test verifies that padded parameters are handled properly
compare_loss(self, config_dict, dtype, iteration=1, hidden_dim_override=13)
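
For concreteness: with world_size=2, any parameter with an odd element count (for example a 13-element vector, or a 13x13 weight with 169 elements, if the toy model happens to contain such shapes) pads to ceil(n/2) elements per rank, so the gathered buffer carries exactly one trailing padding element that the slicing path above must drop; that is the situation this test is intended to exercise.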


@@ -16,8 +16,8 @@ from unit.common import enable_determinism
@enable_determinism(123)
def compare_loss(self, config, dtype, iteration=5):
hidden_dim = 10
def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
hidden_dim = hidden_dim_override if hidden_dim_override is not None else 10
RTOL = 5e-1
ATOL = 1e-2