This PR introduces *DeepCompile*, a new feature that efficiently integrates compiler optimizations with other DeepSpeed features. DeepCompile uses TorchDynamo to capture the computation graph and modifies it to seamlessly incorporate DeepSpeed's optimizations. Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements such as proactive prefetching and selective unsharding to improve performance. (More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
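For context, a minimal sketch of how DeepCompile is expected to be enabled from user code is shown below. The "compile" config section and the engine.compile() call are assumptions based on this PR's description rather than a documented interface, and the exact keys may differ:

import torch
import deepspeed

# Hypothetical DeepSpeed config enabling DeepCompile on top of ZeRO-3.
# The "compile" section is an assumption for illustration only.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},
    "compile": {"deepcompile": True},
}

model = torch.nn.Linear(4096, 4096)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Hypothetical call: TorchDynamo captures the graph and DeepCompile's passes
# (ZeRO gather/release, prefetching, selective unsharding) rewrite it.
engine.compile()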
95 lines
3.7 KiB
Python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses

from .passes import zero3_compile, prefetch, selective_gather, offload_parameters

from .backend import make_backend, launch_compile_passes, init_schedule
from .patch_fake_tensor import patch_fake_tensor
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor

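# Default step threshold for the pass schedule built in init_z3 below: the basic
# ZeRO-3 gather/release pass is registered for step 0, while the prefetching and
# selective-unsharding passes are registered from this step onward.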
WARMUP = 5


def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):

    optimizer = engine.optimizer
    if optimizer is not None and hasattr(optimizer, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer'):
        optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer = None
        get_accelerator().empty_cache()

    dc = get_deepcompile_handle()
    dc.init(engine.data_parallel_group,
            engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
            is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce,
            compile_config.sync_before_allgather, compile_config.sync_after_allgather)

    # Unset hooks
    for m in engine.module.modules():
        m._parameters = m._original_parameters
    optimizer.parameter_offload._remove_module_hooks()

    for hook in optimizer._grad_acc_hooks:
        hook.remove()
    optimizer._grad_acc_hooks.clear()

    # Unpatch linear
    if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
        torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk

    if compile_config.symmetric_memory:
        group_name = engine.data_parallel_group.group_name
        dist.enable_symm_mem_for_group(group_name)

    for p in engine.module.parameters():
        grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id]

        # Disable persistent param
        p.ds_persist = False
        dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)

    def set_grad_buffer():
        for i, sub_group in enumerate(optimizer.fp16_groups):
            optimizer.averaged_gradients[i] = [
                optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
                if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group
            ]

    add_pre_backward_hook(set_grad_buffer)

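    # Build the default pass schedule when the caller does not provide one.
    # Each entry is a (step, [passes]) pair handed to init_schedule: the ZeRO-3
    # gather/release pass (optionally combined with parameter offloading) is
    # registered for step 0, and prefetching plus selective unsharding are
    # added at step WARMUP.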
    if schedule is None:
        schedule = []
        if compile_config.offload_parameters:
            schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd]))
        else:
            schedule.append((0, [zero3_compile.add_z3_gather_release]))
        schedule.append(
            (WARMUP,
             [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))

    init_schedule(schedule)

    # Offloading optimizer states needs additional setup
    from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states
    for _, passes in schedule:
        if move_opt_states in passes or move_opt_states_sync in passes:
            init_offload_opt_states(optimizer, dc)

    engine.launch_compile_passes = launch_compile_passes

    patch_fake_tensor()
    free_activation = compile_config.free_activation and not is_backend_inductor(backend)

    torch._inductor.config.size_asserts = False

    return make_backend(backend,
                        compile_kwargs=compile_kwargs,
                        free_activation=free_activation,
                        debug_log=compile_config.debug_log)
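# Usage sketch (an assumption for illustration, not part of this module): the
# backend returned by make_backend is a TorchDynamo backend, so the engine's
# compile path is expected to hand it to torch.compile, roughly as in:
#
#   backend = init_z3(engine, "inductor", compile_config, compile_kwargs)
#   compiled = torch.compile(engine.module, backend=backend)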