DeepCompile: make the size threshold for freeing activations configurable (#7582)

In DeepCompile's free-activation mode, only activations larger than a
threshold are eagerly freed. That threshold is currently hardcoded, so it
may not be suitable in all cases.

This PR first generalizes the dc.init() interface to take the whole
compile_config object, and then converts the threshold into a config
item.
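
With this change, the threshold is set next to the other compile options in the
DeepSpeed config. A minimal sketch (only the compile section is shown; the
threshold value here is just an example):

    ds_config = {
        "compile": {
            "deepcompile": True,
            "free_activation": True,
            # Only activations larger than this threshold are eagerly freed.
            "free_activation_threshold": 1 * 1024 * 1024,
        },
    }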

This corresponds to issue 3 of #7577.

---------

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
Author: Junjie Mao
Date: 2025-09-24 09:37:46 +08:00
Committed by: GitHub
Parent: bc9ed477e9
Commit: 17d80ce440
6 changed files with 57 additions and 32 deletions


@@ -21,6 +21,8 @@ bool clone_custom_op_output;
 bool profile = false;
 bool pre_div_reduce = true;
+int64_t free_activation_threshold;
 bool sync_before_reduce; // for debugging
 bool sync_after_reduce; // for debugging
 bool sync_before_allgather; // for debugging
@@ -108,11 +110,9 @@ at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
 void free_tensors(std::vector<at::Tensor> tensors)
 {
-    int64_t THRESHOLD = 10 * 1024 * 1024;
     if (!profile) {
         for (auto& tensor : tensors) {
-            if (tensor.is_cuda() && tensor.numel() > THRESHOLD) {
+            if (tensor.is_cuda() && tensor.numel() > free_activation_threshold) {
                 tensor.record_stream(at::cuda::getCurrentCUDAStream());
                 tensor.set_data(torch::empty({0}, tensor.options()));
             }
@@ -122,15 +122,16 @@ void free_tensors(std::vector<at::Tensor> tensors)
 void free_tensors_meta(std::vector<at::Tensor> tensors) {}
+template <typename T>
+static T get_config(pybind11::object& config, const char* name)
+{
+    return pybind11::getattr(config, name).cast<T>();
+}
 void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
+          pybind11::object& config,
           int64_t initial_reduce_bucket_size,
-          bool enable_double_buffer,
-          bool _use_symm_mem,
-          bool _clone_custom_op_output,
-          bool _sync_before_reduce,
-          bool _sync_after_reduce,
-          bool _sync_before_allgather,
-          bool _sync_after_allgather)
+          bool _clone_custom_op_output)
 {
     process_group = pg;
@@ -153,15 +154,16 @@ void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
     ncclCommInitRank(&nccl_comm, process_group->getSize(), ncclID, process_group->getRank());
     param_registry = std::make_shared<DSParamRegistry>();
-    reduce_buckets = std::make_shared<DoubleBufferedReduceBucket>(initial_reduce_bucket_size,
-                                                                  enable_double_buffer);
-    use_symm_mem = _use_symm_mem;
+    reduce_buckets = std::make_shared<DoubleBufferedReduceBucket>(
+        initial_reduce_bucket_size, get_config<bool>(config, "double_buffer"));
+    use_symm_mem = get_config<bool>(config, "symmetric_memory");
     clone_custom_op_output = _clone_custom_op_output;
+    free_activation_threshold = get_config<int64_t>(config, "free_activation_threshold");
-    sync_before_reduce = _sync_before_reduce;
-    sync_after_reduce = _sync_after_reduce;
-    sync_before_allgather = _sync_before_allgather;
-    sync_after_allgather = _sync_after_allgather;
+    sync_before_reduce = get_config<bool>(config, "sync_before_reduce");
+    sync_after_reduce = get_config<bool>(config, "sync_after_reduce");
+    sync_before_allgather = get_config<bool>(config, "sync_before_allgather");
+    sync_after_allgather = get_config<bool>(config, "sync_after_allgather");
 }
 void start_forward()
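
For intuition, the freeing logic above can be read as the following rough Python
analogue (illustration only, not part of this change; the profile guard is omitted):

    import torch

    def free_tensors(tensors, free_activation_threshold):
        # Rough sketch of the C++ free_tensors: large CUDA activations are
        # released eagerly instead of waiting for normal garbage collection.
        for t in tensors:
            if t.is_cuda and t.numel() > free_activation_threshold:
                # Mark the tensor as in use on the current stream so the caching
                # allocator does not reuse its memory before queued work finishes.
                t.record_stream(torch.cuda.current_stream())
                # Drop the storage by pointing the tensor at an empty buffer.
                t.data = torch.empty(0, dtype=t.dtype, device=t.device)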


@@ -590,14 +590,9 @@ void free_tensors(std::vector<at::Tensor> tensors);
 void free_tensors_meta(std::vector<at::Tensor> tensors);
 void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
+          pybind11::object& config,
           int64_t initial_reduce_bucket_size,
-          bool enable_double_buffer,
-          bool _use_symm_mem,
-          bool _clone_custom_op_output,
-          bool _sync_before_reduce,
-          bool _sync_after_reduce,
-          bool _sync_before_allgather,
-          bool _sync_after_allgather);
+          bool _clone_custom_op_output);
 void reset();
 void cleanup();


@@ -15,6 +15,9 @@ class CompileConfig(DeepSpeedConfigModel):
     free_activation: bool = False
     """ Turn on/off the free activation mode """
+    free_activation_threshold: int = 10 * 1024 * 1024
+    """ In free activation mode, activations no less than this threshold (in bytes) are eagerly freed """
     offload_activation: bool = False
     """ Turn on/off the activation offloading """


@@ -24,10 +24,7 @@ def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_
     optimizer._grad_acc_hooks.clear()
     dc = get_deepcompile_handle()
-    dc.init(engine.data_parallel_group,
-            engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
-            is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce, False,
-            False)
+    dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size(), is_backend_inductor(backend))
     grad_buffer = {}


@@ -28,10 +28,7 @@ def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):
     get_accelerator().empty_cache()
     dc = get_deepcompile_handle()
-    dc.init(engine.data_parallel_group,
-            engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
-            is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce,
-            compile_config.sync_before_allgather, compile_config.sync_after_allgather)
+    dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size(), is_backend_inductor(backend))
     # Unset hooks
     for m in engine.module.modules():


@@ -149,3 +149,34 @@ class TestDeepCompile(DistributedTest):
         # This should work correctly with our padding-aware implementation
         # The test verifies that padded parameters are handled properly
         compare_loss(self, config_dict, dtype, iteration=1, hidden_dim_override=13)
+
+    @pytest.mark.parametrize('dtype', [torch.float32])
+    @pytest.mark.parametrize('zero_stage', [1, 3])
+    def test_free_activation_mode(self, zero_stage, dtype):
+        """Test that eagerly freeing activations works correctly and that the threshold is configurable"""
+        if not required_torch_version(min_version=2.6):
+            pytest.skip("DeepCompile requires PyTorch >= v2.6")
+
+        if get_accelerator().device_name() == "cpu":
+            pytest.skip("CPU does not support this test yet")
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 0.00015
+                }
+            },
+            "zero_optimization": {
+                "stage": zero_stage,
+            },
+            "compile": {
+                "deepcompile": True,
+                "free_activation": True,
+                "free_activation_threshold": 0,
+            }
+        }
+
+        compare_loss(self, config_dict, dtype)