Fix gradient buffer access for DeepCompile Z1/2 (#7548)

Initialization of DeepCompile with ZeRO stage 1/2 (Z1/2) currently fails due to the change introduced in #7509.
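
A minimal sketch of the kind of setup that exercises this initialization path (the toy model, optimizer settings, and batch size are placeholders; the `compile`/`deepcompile` keys follow the documented DeepCompile configuration):

```python
import torch
import deepspeed

# Placeholder model; any torch.nn.Module goes through the same init path.
model = torch.nn.Linear(1024, 1024)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 1},   # stage 2 is affected as well
    "compile": {"deepcompile": True},    # enable DeepCompile
}

# Run under the deepspeed launcher / an initialized distributed environment.
# ZeRO-1/2 wraps the client optimizer in DeepSpeedZeroOptimizer, whose
# gradient buffers DeepCompile's init_z1 reads during initialization.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
engine.compile()
```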

This PR resolves the issue by:
- Temporarily populating `optimizer.all_grad_tensors` and passing the group index (`param_group_idx`) to `optimizer.get_flat_partition` during DeepCompile's Z1/2 initialization
- Skipping `allreduce_gradients` in the engine entirely when DeepCompile is enabled

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Masahiro Tanaka
2025-09-10 11:12:02 -07:00
committed by GitHub
parent 0012ff6ea8
commit 0e859aa0d3
2 changed files with 18 additions and 0 deletions


@@ -31,13 +31,23 @@ def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_
    grad_buffer = {}

    # Save original all_grad_tensors state as we temporarily modify it
    original_all_grad_tensors = optimizer.all_grad_tensors.copy() if hasattr(optimizer, 'all_grad_tensors') else {}

    for i, group in enumerate(optimizer.bit16_groups):
        # Temporarily populate all_grad_tensors for get_flat_partition call
        # This is needed because get_flat_partition accesses all_grad_tensors[param_group_idx][i]
        # but it's empty during initialization
        if i not in optimizer.all_grad_tensors or optimizer.all_grad_tensors[i] is None:
            optimizer.all_grad_tensors[i] = optimizer.get_all_grad_tensors(optimizer.params_in_partition[i],
                                                                           optimizer.gradient_accumulation_dtype)
        grad_buffer[i] = optimizer.get_flat_partition(optimizer.params_in_partition[i],
                                                      optimizer.first_offset[i],
                                                      optimizer.partition_size[i],
                                                      dtype=optimizer.gradient_accumulation_dtype,
                                                      device=get_accelerator().current_device_name(),
                                                      param_group_idx=i,
                                                      return_tensor_list=True)
        grad_buffer[i] = [p.clone().detach() for p in grad_buffer[i]]  # Maybe not necessary
@@ -59,6 +69,9 @@ def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None, use_
            # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf=None")
            dc.register_param(p.param_id, p.shape, p, torch.empty([0], dtype=p.dtype, device=p.device), 0)

    # Restore original all_grad_tensors state
    optimizer.all_grad_tensors = original_all_grad_tensors

    def set_grad_buffer():
        optimizer.averaged_gradients = copy.copy(grad_buffer)
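
The change above follows a save/populate/restore pattern: `get_flat_partition` now reads `all_grad_tensors[param_group_idx]`, which is still empty during initialization, so init_z1 temporarily seeds it and puts the original state back once the buffers have been built. A generic sketch of that pattern (the names below are illustrative, not DeepSpeed API):

```python
# Illustrative-only sketch of the temporary-state pattern used above.
def call_with_temporary_state(holder, seed_state, fn):
    """Seed missing entries in holder.state, run fn, then restore the original state."""
    saved = dict(holder.state)                   # snapshot the current state
    try:
        for key, value in seed_state.items():
            holder.state.setdefault(key, value)  # only fill entries that are missing
        return fn(holder)                        # fn may index into holder.state freely
    finally:
        holder.state = saved                     # put the original state back unconditionally
```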


@@ -2199,6 +2199,11 @@ class DeepSpeedEngine(Module):
    @instrument_w_nvtx
    def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
        # Skip gradient reduction when DeepCompile is enabled
        # DeepCompile handles its own gradient reduction through compiled graph operations
        if self.is_deepcompile_enabled():
            return

        # Pass (PP) gas boundary flag to optimizer (required for zero)
        self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
        # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
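
From the training loop's point of view, the early return simply bypasses the eager reduction path when DeepCompile is active, since the reduction already runs inside the compiled backward graph. A schematic sketch of where the call sits (simplified; not the exact engine call chain):

```python
# Schematic training step; gradient accumulation and error handling are omitted.
def training_step(engine, batch):
    loss = engine(batch)      # assume the wrapped model returns a scalar loss
    # DeepSpeedEngine.backward() calls allreduce_gradients() internally;
    # with DeepCompile enabled that call now returns immediately, so gradients
    # reduced by the compiled graph are not reduced a second time.
    engine.backward(loss)
    engine.step()
    return loss
```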