mirror of https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
DeepCompile ZeRO-3: robust allgather for uneven shards; fix profiling meta key (max_mem) (#7489)

Signed-off-by: Abhishek <dalakotiashu150@gmail.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Abhishek <dalakotiashu150@gmail.com>
Co-authored-by: Masahiro Tanaka <mtanaka@anyscale.com>
@@ -69,9 +69,15 @@ public:
         const at::Tensor& ds_tensor = param.getDSTensor();
 
         if (symm_mem == nullptr) {
+            // Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size)
+            const int world_size = process_group_->getSize();
+            const int64_t shard_elems = ds_tensor.numel();
+
+            // Perform all-gather directly into the pre-allocated padded output buffer
+            // NCCL requires contiguous storage; use .contiguous() explicitly
             ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
                                                 output_buf.data_ptr(),
-                                                ds_tensor.numel(),
+                                                shard_elems,
                                                 get_nccl_data_type(ds_tensor.scalar_type()),
                                                 nccl_comm_,
                                                 ag_stream_);
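Note on the count change above: NCCL's all-gather takes the per-rank send count and writes world_size times that many elements into the receive buffer, so the padded flat output must be sized accordingly. A minimal sketch of that contract, assuming raw float device pointers and a communicator/stream set up elsewhere; the helper name is hypothetical, not part of the diff:

#include <nccl.h>
#include <cuda_runtime.h>
#include <cstddef>

// Sketch only: gathers one uniformly sized shard per rank into a padded
// receive buffer. Names (shard_elems, world_size) mirror the diff; the raw
// pointers and stream are assumed to be prepared by the caller.
ncclResult_t allgather_padded_shard(const float* send_shard,  // shard_elems elements on this rank
                                    float* recv_padded,       // world_size * shard_elems elements
                                    size_t shard_elems,
                                    ncclComm_t comm,
                                    cudaStream_t stream)
{
    // Every rank must pass the same sendcount; NCCL places rank i's shard at
    // offset i * shard_elems in recv_padded. Passing the full (unpadded)
    // parameter size here would read past the end of the local shard.
    return ncclAllGather(send_shard, recv_padded, shard_elems, ncclFloat, comm, stream);
}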
@@ -104,13 +110,30 @@ public:
     at::Tensor allgatherParam(long ds_id,
                               c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
     {
-        if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }
-
         const DSParam& param = param_registry_->getParam(ds_id);
         const at::Tensor& ds_tensor = param.getDSTensor();
-        at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
-                                    ? param_registry_->getGatheredParam(ds_id)
-                                    : torch::empty(param.getShape(), ds_tensor.options());
+        const int world_size = process_group_->getSize();
+        const int64_t true_numel = static_cast<int64_t>(productDim(param.getShape()));
+        const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size;
+        const int64_t padded_numel = static_cast<int64_t>(world_size) * padded_per_rank;
+
+        if (param_registry_->isValid(ds_id)) {
+            // Return a view sliced to the true size with the original shape
+            auto base = param_registry_->getGatheredParam(ds_id);
+            return base.flatten()
+                .index({torch::indexing::Slice(0, true_numel)})
+                .view(param.getShape());
+        }
+
+        at::Tensor output_buf;
+        if (param_registry_->hasGatheredParam(ds_id)) {
+            auto existing = param_registry_->getGatheredParam(ds_id);
+            if (existing.defined() && existing.numel() == padded_numel) { output_buf = existing; }
+        }
+        if (!output_buf.defined()) {
+            at::cuda::CUDAStreamGuard guard(ag_stream_);
+            output_buf = torch::empty({padded_numel}, ds_tensor.options());
+        }
 
         assert(hasKey(ag_comp_done_events_, ds_id));
         ag_comp_done_events_[ds_id]->record();
@@ -119,7 +142,10 @@ public:
         launchAllGather(output_buf, ds_id, symm_mem);
 
         ag_comm_done_events_[ds_id]->record(ag_stream_);
-        return output_buf;
+        // Return a view of the gathered padded buffer matching the true param shape
+        return output_buf.flatten()
+            .index({torch::indexing::Slice(0, true_numel)})
+            .view(param.getShape());
     }
 
     void prefetchParamsFused(std::vector<int64_t> ds_ids,
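The two hunks above make allgatherParam gather into a flat buffer of world_size * padded_per_rank elements and hand callers a zero-copy view trimmed back to the true shape. A minimal libtorch sketch of that round trip with made-up sizes (a 2x5 parameter across 4 ranks); the registry, events, and stream handling from the diff are omitted:

#include <torch/torch.h>
#include <vector>

int main()
{
    const std::vector<int64_t> shape = {2, 5};  // true parameter shape (10 elements)
    const int64_t world_size = 4;

    const int64_t true_numel = 2 * 5;
    const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size;  // 3
    const int64_t padded_numel = world_size * padded_per_rank;                   // 12

    // Gather target: flat, padded, one uniform slot per rank.
    at::Tensor output_buf = torch::empty({padded_numel}, torch::kFloat);

    // ... an all-gather would fill output_buf here ...

    // Trim the padding and restore the original shape without copying.
    at::Tensor param_view = output_buf.flatten()
                                .index({torch::indexing::Slice(0, true_numel)})
                                .view(shape);
    TORCH_CHECK(param_view.sizes().equals(shape));
    return 0;
}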
@@ -133,11 +159,19 @@ public:
         std::unordered_map<long, at::Tensor> output_bufs;
         for (long ds_id : invalid_ds_ids) {
             const DSParam& param = param_registry_->getParam(ds_id);
+            const at::Tensor& ds_tensor = param.getDSTensor();
+            const int world_size = process_group_->getSize();
+            const int64_t shard_elems = ds_tensor.numel();
+            const int64_t padded_numel = static_cast<int64_t>(world_size) * shard_elems;
+
             if (param_registry_->hasGatheredParam(ds_id)) {
-                output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
-            } else {
-                output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
+                auto existing = param_registry_->getGatheredParam(ds_id);
+                if (existing.defined() && existing.numel() == padded_numel) {
+                    output_bufs[ds_id] = existing;
+                    continue;
+                }
             }
+            output_bufs[ds_id] = torch::empty({padded_numel}, ds_tensor.options());
         }
 
         for (long ds_id : invalid_ds_ids) {
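The prefetch path above reuses a cached gathered buffer only when its element count already matches the padded size, and otherwise allocates a fresh flat buffer. A minimal sketch of that guard, using a hypothetical cache keyed by ds_id in place of the diff's registry API:

#include <torch/torch.h>
#include <unordered_map>

// Hypothetical cache of previously gathered (padded, flat) buffers.
std::unordered_map<long, at::Tensor> gathered_cache;

// Reuse the cached buffer only if it already has the padded size; otherwise
// allocate a new flat buffer sized world_size * shard_elems.
at::Tensor get_or_alloc_padded_buf(long ds_id,
                                   int64_t world_size,
                                   int64_t shard_elems,
                                   const at::TensorOptions& options)
{
    const int64_t padded_numel = world_size * shard_elems;
    auto it = gathered_cache.find(ds_id);
    if (it != gathered_cache.end() && it->second.defined() && it->second.numel() == padded_numel) {
        return it->second;
    }
    at::Tensor buf = torch::empty({padded_numel}, options);
    gathered_cache[ds_id] = buf;
    return buf;
}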
@@ -383,6 +417,43 @@ void register_z3_param(long ds_id,
 {
     param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
     if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }
+
+    // Validate that padded shard sizes are uniform across ranks at registration time
+    // DeepSpeed pads parameters to ensure even division, so we check the padded size
+    // which should be uniform across all ranks for correct allgather behavior
+    const int64_t local_count = ds_tensor.numel();
+    const int world_size = process_group->getSize();
+
+    // Calculate padded size (aligned to world_size)
+    // Use ds_shape to compute the full (unpartitioned) parameter size
+    int64_t total_numel = 1;
+    for (const auto dim : ds_shape) { total_numel *= dim; }
+    const int64_t padded_per_rank = (total_numel + world_size - 1) / world_size;
+
+    // For verification: all ranks should have the same padded size
+    auto count_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
+    at::Tensor local_padded_tensor = torch::tensor({padded_per_rank}, count_options);
+    std::vector<at::Tensor> all_padded_counts(world_size);
+    for (int i = 0; i < world_size; ++i) {
+        all_padded_counts[i] = torch::empty_like(local_padded_tensor);
+    }
+
+    // Build lvalue buffers for output and input as required by ProcessGroup::allgather
+    // The first argument must be a single-element vector containing a vector of WORLD_SIZE tensors
+    std::vector<std::vector<at::Tensor>> output_tensors(1);
+    output_tensors[0] = all_padded_counts;
+    std::vector<at::Tensor> input_tensors = {local_padded_tensor};
+    process_group->allgather(output_tensors, input_tensors)->wait();
+
+    // Verify all ranks agree on the padded size
+    for (int i = 0; i < world_size; ++i) {
+        int64_t padded_count = all_padded_counts[i].to(torch::kCPU).item<int64_t>();
+        if (padded_count != padded_per_rank) {
+            throw std::runtime_error(
+                "ZeRO-3 registration error: inconsistent padded shard sizes across ranks. "
+                "This is an internal error - please report this issue.");
+        }
+    }
 }
 
 at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
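The registration-time check above leans on c10d::ProcessGroup::allgather, whose output argument is a vector holding one inner vector of world_size tensors and whose input is a vector with the single local tensor. A minimal sketch of that calling convention, assuming an already-initialized CUDA process group is passed in; the function name is hypothetical:

#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <vector>

// Sketch: every rank contributes one int64 value; returns true if all ranks
// sent the same value. The process group is assumed to be set up elsewhere.
bool all_ranks_agree(c10d::ProcessGroup& pg, int64_t local_value)
{
    const int world_size = pg.getSize();
    auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);

    at::Tensor local = torch::tensor({local_value}, options);
    std::vector<at::Tensor> gathered(world_size);
    for (int i = 0; i < world_size; ++i) { gathered[i] = torch::empty_like(local); }

    // allgather expects lvalue buffers: outputs[0] holds world_size tensors,
    // inputs holds the single local tensor. The gathered handles share storage
    // with outputs[0], so reading them after wait() sees the result.
    std::vector<std::vector<at::Tensor>> outputs(1);
    outputs[0] = gathered;
    std::vector<at::Tensor> inputs = {local};
    pg.allgather(outputs, inputs)->wait();

    for (int i = 0; i < world_size; ++i) {
        if (gathered[i].to(torch::kCPU).item<int64_t>() != local_value) { return false; }
    }
    return true;
}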
@@ -122,7 +122,7 @@ class ProfilingInterpreter(Interpreter):
             n.meta["device_time"] = 0.0
             n.meta["wall_time"] = 0.0
             n.meta["alloc_mem"] = 0
-            n.meta["max_memory"] = 0
+            n.meta["max_mem"] = 0
             n.meta["tensor_size"] = _node_size(n)
             return super().run_node(n)
 
@@ -116,3 +116,36 @@ class TestDeepCompile(DistributedTest):
 
         # Need warmup steps
         compare_loss(self, config_dict, dtype, iteration=10)
+
+    @pytest.mark.parametrize('dtype', [torch.float32])
+    @pytest.mark.parametrize('zero_stage', [3])
+    def test_padded_shard_handling(self, zero_stage, dtype):
+        """Test that parameters with padding (uneven division) work correctly with DeepCompile"""
+        if not required_torch_version(min_version=2.6):
+            pytest.skip("DeepCompile requires PyTorch >= v2.6")
+
+        if get_accelerator().device_name() == "cpu":
+            pytest.skip("CPU does not support this test yet")
+
+        # Use a hidden dimension that requires padding when divided across ranks
+        # With world_size=2, a hidden_dim of 13 creates parameters that need padding
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 0.00015
+                }
+            },
+            "zero_optimization": {
+                "stage": zero_stage,
+            },
+            "compile": {
+                "deepcompile": True
+            }
+        }
+
+        # This should work correctly with our padding-aware implementation
+        # The test verifies that padded parameters are handled properly
+        compare_loss(self, config_dict, dtype, iteration=1, hidden_dim_override=13)
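Why hidden_dim=13 exercises the padded path: assuming a square 13x13 weight (the exact test model layout is an assumption) and the world_size of 2 mentioned in the test comment, 169 elements do not split evenly across ranks, so each rank holds an 85-element padded shard and the gathered buffer carries one padding element that must be sliced off. A compile-time check of that arithmetic:

// Worked example, assuming a 13x13 weight and 2 ranks (test model layout is an assumption).
constexpr long long true_numel = 13 * 13;                                           // 169 elements
constexpr long long world_size = 2;
constexpr long long padded_per_rank = (true_numel + world_size - 1) / world_size;   // 85
constexpr long long padded_numel = world_size * padded_per_rank;                    // 170

static_assert(padded_per_rank == 85, "each rank stores a padded 85-element shard");
static_assert(padded_numel - true_numel == 1, "the gathered buffer carries 1 padding element");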
@@ -16,8 +16,8 @@ from unit.common import enable_determinism
 
 
 @enable_determinism(123)
-def compare_loss(self, config, dtype, iteration=5):
-    hidden_dim = 10
+def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
+    hidden_dim = hidden_dim_override if hidden_dim_override is not None else 10
     RTOL = 5e-1
     ATOL = 1e-2
 