DeepSpeed/deepspeed/compile/init_z3.py
Masahiro Tanaka 227a60c0c4 DeepCompile for enhanced compiler integration (#7154)
This PR introduces *DeepCompile*, a new feature that efficiently
integrates compiler optimizations with other DeepSpeed features.
DeepCompile uses PyTorch's Dynamo to capture the computation graph and
modifies it to incorporate DeepSpeed's optimizations seamlessly.

Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements
such as proactive prefetching and selective unsharding to improve
performance.
(More details will be added later.)
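
As a rough illustration, a training script would enable DeepCompile roughly
as sketched below. The config key and defaults shown here are assumptions
for illustration only, not part of this change; see the DeepCompile docs for
the authoritative settings:

    import deepspeed
    import torch

    model = torch.nn.Linear(1024, 1024)

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        "zero_optimization": {"stage": 3},
        # Assumed switch for DeepCompile; the key name may differ.
        "compile": {"deepcompile": True},
    }

    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config=ds_config)
    engine.compile()  # wraps the module via the torch.compile backend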

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2025-04-16 04:33:53 +00:00

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses

from .passes import zero3_compile, prefetch, selective_gather, offload_parameters
from .backend import make_backend, launch_compile_passes, init_schedule
from .patch_fake_tensor import patch_fake_tensor
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor

WARMUP = 5


def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):

    optimizer = engine.optimizer
    if optimizer is not None and hasattr(optimizer, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer'):
        optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer = None
        get_accelerator().empty_cache()
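
    # Initialize the native DeepCompile handle with the data-parallel group,
    # the ZeRO reduce-bucket size, and the configured buffering/sync options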
    dc = get_deepcompile_handle()
    dc.init(engine.data_parallel_group,
            engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
            is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce,
            compile_config.sync_before_allgather, compile_config.sync_after_allgather)

    # Unset hooks
    for m in engine.module.modules():
        m._parameters = m._original_parameters
    optimizer.parameter_offload._remove_module_hooks()
    for hook in optimizer._grad_acc_hooks:
        hook.remove()
    optimizer._grad_acc_hooks.clear()

    # Unpatch linear
    if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
        torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk

    if compile_config.symmetric_memory:
        group_name = engine.data_parallel_group.group_name
        dist.enable_symm_mem_for_group(group_name)
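
    # Register every ZeRO-3 parameter and its preallocated gradient-partition
    # buffer with DeepCompile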
    for p in engine.module.parameters():
        grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id]
        # Disable persistent param
        p.ds_persist = False
        dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)
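
    # Before each backward pass, point the optimizer's averaged_gradients at
    # the preallocated gradient partitions (zeros for params without gradients)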
    def set_grad_buffer():
        for i, sub_group in enumerate(optimizer.fp16_groups):
            optimizer.averaged_gradients[i] = [
                optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
                if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group
            ]

    add_pre_backward_hook(set_grad_buffer)
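
    # Default pass schedule: ZeRO-3 gather/release (plus optional parameter
    # offloading) from step 0, then prefetching and selective unsharding after
    # the warmup steps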
    if schedule is None:
        schedule = []

        if compile_config.offload_parameters:
            schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd]))
        else:
            schedule.append((0, [zero3_compile.add_z3_gather_release]))

        schedule.append(
            (WARMUP,
             [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))

    init_schedule(schedule)

    # Offloading optimizer states needs additional setup
    from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states
    for _, passes in schedule:
        if move_opt_states in passes or move_opt_states_sync in passes:
            init_offload_opt_states(optimizer, dc)

    engine.launch_compile_passes = launch_compile_passes
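
    # Patch fake-tensor handling used during graph tracing, then build the
    # torch.compile backend that applies the scheduled DeepCompile passes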
    patch_fake_tensor()

    free_activation = compile_config.free_activation and not is_backend_inductor(backend)

    torch._inductor.config.size_asserts = False

    return make_backend(backend,
                        compile_kwargs=compile_kwargs,
                        free_activation=free_activation,
                        debug_log=compile_config.debug_log)