Simplify leaf module hook (#7592)
This PR simplifies the gradient reduction hooks for leaf modules by using PyTorch's register_full_backward_hook API instead of hand-rolled forward pre/post hooks that attached per-tensor backward hooks to the module's inputs.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
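For context, register_full_backward_hook (available since PyTorch 1.8, which is why the diff below asserts required_torch_version(min_version=1.8)) runs once per module after the backward pass through that module has finished, so a single module-level hook can replace the previous combination of a forward pre-hook, per-input-tensor hooks, and an input counter. Below is a minimal standalone sketch of that behavior; the toy module and names are illustrative, not DeepSpeed code.

import torch
import torch.nn as nn

# Toy stand-in for a ZeRO-3 "leaf" module; any nn.Module behaves the same way.
leaf = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

def on_leaf_backward(module, grad_input, grad_output):
    # Runs once, after the backward pass through this module has completed,
    # which is the point at which the PR triggers gradient reduction.
    print("full backward hook fired for", type(module).__name__)

handle = leaf.register_full_backward_hook(on_leaf_backward)  # needs PyTorch >= 1.8

x = torch.randn(2, 8, requires_grad=True)
leaf(x).sum().backward()  # the hook fires exactly once here

handle.remove()  # handles can be removed later, as the diff keeps them in self._leaf_module_hooks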
@@ -17,7 +17,7 @@ from deepspeed.utils import groups, z3_leaf_parameter
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.utils import logger
-from deepspeed.utils.torch import register_grad_hook
+from deepspeed.utils.torch import register_grad_hook, required_torch_version
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
@@ -26,7 +26,7 @@ from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
-from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
+from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer
 from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
@@ -1214,59 +1214,21 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                 # Partition the parameter after creating the hook
                 param.partition()
 
-        # We delay reduce-scatter for all gradients in the leaf modules until the backward pass of the leaf module is done
+        # We delay reduce for all gradients in the leaf modules until the backward pass of the leaf module is done
         for leaf_module, leaf_parameters in self.leaf_parameters.items():
 
-            def wrapper_pre_hook(params):
-
-                def forward_pre_hook(module, input):
-                    """Pre-forward hook to set backward hook on input tensors to the leaf module"""
-                    module._leaf_module_inputs_remaining = 0
-
-                    @instrument_w_nvtx
-                    def reduce_leaf_module_grads(grad):
-                        module._leaf_module_inputs_remaining -= 1
-                        # Make sure everything is done in the leaf module
-                        if module._leaf_module_inputs_remaining == 0:
-                            for param in params:
-                                if param.grad is None:
-                                    param.grad = torch.zeros_like(param)
-                                self.reduce_ready_partitions_and_remove_grads(param)
-
-                    def set_module_bwd_hook(tensor):
-                        if tensor.requires_grad:
-                            module._leaf_module_inputs_remaining += 1
-                            tensor.register_hook(reduce_leaf_module_grads)
-                        return tensor
-
-                    output = apply_to_tensors_only(set_module_bwd_hook, input)
-
-                    return output
-
-                return forward_pre_hook
-
-            def wrapper_post_hook():
-
-                def forward_post_hook(module, input, output):
-                    """Pre-forward hook to set backward hook on input tensors to the leaf module"""
-                    module._leaf_output_required_grad_num = 0
-
-                    def increment_rg_count_bwd_hook(tensor):
-                        if tensor.requires_grad:
-                            module._leaf_output_required_grad_num += 1
-                        return tensor
-
-                    apply_to_tensors_only(increment_rg_count_bwd_hook, output)
-
-                    if module._leaf_module_inputs_remaining == 0 and module._leaf_output_required_grad_num > 0:
-                        raise RuntimeError(
-                            "A module cannot be set as a leaf module when it does not have any input tensors that require gradients and has output tensors that require gradients. This is because the gradient reduction hook will not be called in this case."
-                        )
-
-                return forward_post_hook
-
-            self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper_pre_hook(leaf_parameters)))
-            self._leaf_module_hooks.append(leaf_module.register_forward_hook(wrapper_post_hook()))
+            def make_hook(params):
+
+                def reduce_leaf_module_grads(module, grad_input, grad_output):
+                    for param in params:
+                        if param.grad is None:
+                            param.grad = torch.zeros_like(param)
+                        self.reduce_ready_partitions_and_remove_grads(param)
+
+                return reduce_leaf_module_grads
+
+            assert required_torch_version(min_version=1.8), "Leaf module requires PyTorch >= 1.8"
+            self._leaf_module_hooks.append(leaf_module.register_full_backward_hook(make_hook(leaf_parameters)))
 
         print_rank_0('[End] Create gradient reduction hooks')
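One detail of the new registration loop worth noting: make_hook is a closure factory, so each leaf module's hook captures its own leaf_parameters; a hook defined inline that read the loop variable would late-bind and end up reducing only the last module's parameters. Below is a standalone sketch of that pattern with illustrative names (not DeepSpeed code).

import torch
import torch.nn as nn

modules = [nn.Linear(4, 4) for _ in range(3)]
handles = []

def make_hook(tag):
    # Factory: each returned hook closes over its own `tag`, mirroring
    # make_hook(leaf_parameters) above, which closes over one module's params.
    def hook(module, grad_input, grad_output):
        print(f"backward finished for module {tag}")
    return hook

for i, m in enumerate(modules):
    handles.append(m.register_full_backward_hook(make_hook(i)))

x = torch.randn(2, 4, requires_grad=True)
for m in modules:
    x = m(x)
x.sum().backward()  # one line per module, each hook reporting its own tag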