mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
PR #6993 replaces the flat IPG buffers with a dict of type-indexed buckets and renames the member from `_ipg_bucket_flat_buffer` to `ipg_buckets`. Update the bucket-clearing logic in `init_z3` accordingly.
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
96 lines · 3.5 KiB · Python
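To make the commit message concrete, here is a minimal sketch of the new cleanup path as a standalone helper. `clear_ipg_buckets` is a hypothetical name, not part of DeepSpeed; its body mirrors the bucket-clearing block inside `init_z3` in the file below.

from deepspeed.accelerator import get_accelerator


def clear_ipg_buckets(optimizer) -> None:
    """Hypothetical helper: release the ZeRO-3 IPG buckets before DeepCompile init.

    After PR #6993 the stage-3 optimizer keeps its in-flight gradient data in
    `ipg_buckets`, a dict of type-indexed buckets, instead of the old flat
    `_ipg_bucket_flat_buffer` tensor, so the cleanup clears the whole dict.
    """
    if hasattr(optimizer, "ipg_buckets"):
        # Drop every per-dtype bucket, then release the cached device memory it held.
        optimizer.ipg_buckets.clear()
        get_accelerator().empty_cache()
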
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload

from .passes import zero3_compile, prefetch, selective_gather, offload_parameters
from .backend import make_backend, launch_compile_passes, init_schedule
from .patch_fake_tensor import patch_fake_tensor
from .util import get_deepcompile_handle, add_pre_backward_hook
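# Step at which the prefetch and selective-gather passes are added to the default
# pass schedule built in init_z3 below.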
WARMUP = 5


def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):
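    # A bare DeepSpeedZeRoOffload in engine.optimizer means there is no real ZeRO-3
    # optimizer; the optimizer-specific teardown and wiring below are skipped in that case.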
    optimizer = engine.optimizer
    use_opt = not isinstance(optimizer, DeepSpeedZeRoOffload)
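    # Drop the optimizer's type-indexed IPG buckets (PR #6993) and release the cached device memory.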
    if use_opt and hasattr(optimizer, "ipg_buckets"):
        optimizer.ipg_buckets.clear()
        get_accelerator().empty_cache()
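    # Initialize the DeepCompile handle with the data-parallel group, the compile config,
    # and the ZeRO reduce-bucket size.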
    dc = get_deepcompile_handle()
    dc.init(engine.data_parallel_group, compile_config, engine.zero_reduce_bucket_size())

    # Unset hooks
    for m in engine.module.modules():
        m._parameters = m._original_parameters

    if use_opt:
        optimizer.parameter_offload._remove_module_hooks()

        for hook in optimizer._grad_acc_hooks:
            hook.remove()
        optimizer._grad_acc_hooks.clear()

    # Unpatch linear
    if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
        torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk

    if compile_config.symmetric_memory:
        group_name = engine.data_parallel_group.group_name
        dist.enable_symm_mem_for_group(group_name)
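    # Register every ZeRO-3 partitioned parameter with DeepCompile, passing its partitioned
    # gradient buffer when a real optimizer is present (an empty tensor otherwise).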
    for p in engine.module.parameters():
        grad_buffer = torch.Tensor()
        if use_opt:
            grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id]

        # Disable persistent param
        p.ds_persist = False
        dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)
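    # Default pass schedule: gather/release (plus parameter offloading if enabled) from step 0,
    # then prefetching and selective gathering starting at step WARMUP.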
    if schedule is None:
        schedule = []
        if (compile_config.offload_parameters):
            schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd]))
        else:
            schedule.append((0, [zero3_compile.add_z3_gather_release]))
        schedule.append(
            (WARMUP,
             [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))

    init_schedule(schedule)
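    # Before each backward pass, repopulate averaged_gradients with each parameter's partitioned
    # gradient buffer (zeros for parameters that do not require grad).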
    if use_opt:

        def set_grad_buffer():
            for i, sub_group in enumerate(optimizer.fp16_groups):
                optimizer.averaged_gradients[i] = [
                    optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
                    if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group
                ]

        add_pre_backward_hook(set_grad_buffer)

        # Offloading optimizer states needs additional setup
        from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states
        for _, passes in schedule:
            if move_opt_states in passes or move_opt_states_sync in passes:
                init_offload_opt_states(optimizer, dc)

    engine.launch_compile_passes = launch_compile_passes
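    # Patch FakeTensor handling and disable torch._inductor size assertions before
    # handing back the compile backend.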
    patch_fake_tensor()
    torch._inductor.config.size_asserts = False

    return make_backend(backend, compile_config, compile_kwargs=compile_kwargs)