Simplify leaf module hook (#7592)

This PR simplifies the gradient-reduction hooks for leaf modules by using PyTorch's module-level `register_full_backward_hook` API.
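
For context, the change replaces the forward pre-/post-hook bookkeeping on leaf modules with a single module-level backward hook. A minimal sketch of the PyTorch API this relies on (the module, parameter list, and the no-op reduction below are illustrative placeholders, not the DeepSpeed implementation):

```python
import torch


def make_reduce_hook(params):
    # Called once per backward pass through the module, after gradients
    # w.r.t. all of the module's inputs have been computed.
    def reduce_leaf_module_grads(module, grad_input, grad_output):
        for param in params:
            if param.grad is None:
                param.grad = torch.zeros_like(param)
            # DeepSpeed would reduce-scatter / partition param.grad here.

    return reduce_leaf_module_grads


leaf_module = torch.nn.Linear(8, 8)  # stand-in for a user-marked ZeRO-3 leaf module
handle = leaf_module.register_full_backward_hook(make_reduce_hook(list(leaf_module.parameters())))
# handle.remove() detaches the hook, mirroring how DeepSpeed tracks hook handles.
```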

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Author: Masahiro Tanaka
Date: 2025-09-27 10:12:15 -07:00
Committed by: GitHub
Parent: b75654001a
Commit: 6ea345ae27


@@ -17,7 +17,7 @@ from deepspeed.utils import groups, z3_leaf_parameter
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.utils import logger
-from deepspeed.utils.torch import register_grad_hook
+from deepspeed.utils.torch import register_grad_hook, required_torch_version
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
@@ -26,7 +26,7 @@ from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
-from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
+from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer
 from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
@@ -1214,59 +1214,21 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                     # Partition the parameter after creating the hook
                     param.partition()
-        # We delay reduce-scatter for all gradients in the leaf modules until the backward pass of the leaf module is done
+        # We delay reduce for all gradients in the leaf modules until the backward pass of the leaf module is done
         for leaf_module, leaf_parameters in self.leaf_parameters.items():
-            def wrapper_pre_hook(params):
+            def make_hook(params):
-                def forward_pre_hook(module, input):
-                    """Pre-forward hook to set backward hook on input tensors to the leaf module"""
-                    module._leaf_module_inputs_remaining = 0
+                def reduce_leaf_module_grads(module, grad_input, grad_output):
+                    for param in params:
+                        if param.grad is None:
+                            param.grad = torch.zeros_like(param)
+                        self.reduce_ready_partitions_and_remove_grads(param)
-                    @instrument_w_nvtx
-                    def reduce_leaf_module_grads(grad):
-                        module._leaf_module_inputs_remaining -= 1
-                        # Make sure everything is done in the leaf module
-                        if module._leaf_module_inputs_remaining == 0:
-                            for param in params:
-                                if param.grad is None:
-                                    param.grad = torch.zeros_like(param)
-                                self.reduce_ready_partitions_and_remove_grads(param)
+                return reduce_leaf_module_grads
-                    def set_module_bwd_hook(tensor):
-                        if tensor.requires_grad:
-                            module._leaf_module_inputs_remaining += 1
-                            tensor.register_hook(reduce_leaf_module_grads)
-                        return tensor
-                    output = apply_to_tensors_only(set_module_bwd_hook, input)
-                    return output
-                return forward_pre_hook
-            def wrapper_post_hook():
-                def forward_post_hook(module, input, output):
-                    """Pre-forward hook to set backward hook on input tensors to the leaf module"""
-                    module._leaf_output_required_grad_num = 0
-                    def increment_rg_count_bwd_hook(tensor):
-                        if tensor.requires_grad:
-                            module._leaf_output_required_grad_num += 1
-                        return tensor
-                    apply_to_tensors_only(increment_rg_count_bwd_hook, output)
-                    if module._leaf_module_inputs_remaining == 0 and module._leaf_output_required_grad_num > 0:
-                        raise RuntimeError(
-                            "A module cannot be set as a leaf module when it does not have any input tensors that require gradients and has output tensors that require gradients. This is because the gradient reduction hook will not be called in this case."
-                        )
-                return forward_post_hook
-            self._leaf_module_hooks.append(leaf_module.register_forward_pre_hook(wrapper_pre_hook(leaf_parameters)))
-            self._leaf_module_hooks.append(leaf_module.register_forward_hook(wrapper_post_hook()))
+            assert required_torch_version(min_version=1.8), "Leaf module requires PyTorch >= 1.8"
+            self._leaf_module_hooks.append(leaf_module.register_full_backward_hook(make_hook(leaf_parameters)))
         print_rank_0('[End] Create gradient reduction hooks')
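
A quick way to see why the per-input counter (`_leaf_module_inputs_remaining`) could be dropped (hypothetical snippet, not part of this PR): `register_full_backward_hook` fires exactly once per backward pass through the module, after gradients for all of its inputs are available.

```python
import torch

calls = []
mod = torch.nn.Bilinear(3, 3, 2)  # a module that takes two inputs
mod.register_full_backward_hook(lambda m, grad_in, grad_out: calls.append(len(grad_in)))

a = torch.randn(5, 3, requires_grad=True)
b = torch.randn(5, 3, requires_grad=True)
mod(a, b).sum().backward()
print(calls)  # [2] -> the hook ran once, with gradients for both inputs already computed
```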