[PP] move FSDP reduce scatters to end of step (#165106)

Move the FSDP reduce scatters to the end of the PP step. The reduce scatter's compute-stream sync blocks the other stages from executing their backward passes, which leads to pipeline bubbles. There should be a way to issue these reduce scatters earlier, but this is a quick fix for now.

<img width="1056" height="463" alt="image" src="https://github.com/user-attachments/assets/b945dd55-8ab1-4acc-b862-c6e2e476b834" />
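For context, the sketch below shows the general shape of the workaround (not the exact code added by this PR): FSDP gradient sync and resharding are suppressed while the schedule runs its per-microbatch backwards, and the deferred post-backward is fired once at the end of the step. `run_schedule_backwards` is an illustrative placeholder, and the FSDP-internal calls mirror the new `_post_backward` helper in the stage diff below.

```python
# Hedged sketch of the deferral pattern, not the code added by this PR.
# `run_schedule_backwards` is a placeholder for the schedule's backward actions.
from torch.distributed.fsdp import FSDPModule, fully_shard


def step_with_deferred_reduce_scatters(stage_module, microbatches):
    if isinstance(stage_module, FSDPModule):
        # During the step: no grad sync, no reshard, so other stages' backwards
        # are not stalled behind the reduce-scatter stream sync.
        stage_module.set_is_last_backward(False)
        stage_module.set_requires_gradient_sync(False)
        stage_module.set_reshard_after_backward(False)

    run_schedule_backwards(stage_module, microbatches)  # placeholder

    if isinstance(stage_module, FSDPModule):
        # End of step: re-enable sync/reshard and manually fire FSDP's
        # post-backward, since the per-microbatch backwards ran with sync disabled.
        stage_module.set_is_last_backward(True)
        stage_module.set_reshard_after_backward(True)
        stage_module.set_requires_gradient_sync(True)
        fsdp_state = fully_shard.state(stage_module)  # private API, as used in the diff
        for state in fsdp_state._state_ctx.all_states:
            if state._fsdp_param_group:
                state._fsdp_param_group.post_backward()  # issues the reduce scatters
        # Sync the grad-reduction ops back to the default stream.
        fsdp_state._root_post_backward_final_callback()
```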

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165106
Approved by: https://github.com/weifengpy
ghstack dependencies: #164976
Author: Howard Huang
Date: 2025-10-11 07:32:20 -07:00
Committed by: PyTorch MergeBot
parent 3a110c9bb2
commit 2beead7523
3 changed files with 40 additions and 38 deletions


@@ -146,6 +146,7 @@ class ComposabilityTest(MultiProcContinuousTest):
total_layers,
apply_dp,
loss_fn,
scale_grads=True,
):
if issubclass(ScheduleClass, PipelineScheduleSingle):
pipeline_stage, offset = self._build_pp_stage(
@@ -163,6 +164,7 @@ class ComposabilityTest(MultiProcContinuousTest):
pipeline_stage,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
scale_grads=scale_grads,
)
else:
n_virtual = 2
@@ -185,6 +187,7 @@ class ComposabilityTest(MultiProcContinuousTest):
stages,
n_microbatches=num_microbatches,
loss_fn=loss_fn,
scale_grads=scale_grads,
)
return pipeline_schedule, partial_models, offsets
@@ -523,8 +526,8 @@ class ComposabilityTest(MultiProcContinuousTest):
runtime.pipeline_order_with_comms = unshard_schedule
runtime.step(dummy_input)
# Verify parameters are now unsharded
check_fsdp_unsharded_state(stage.submod, expected_unsharded=True)
# Verify parameters are still sharded
check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)
instantiate_parametrized_tests(ComposabilityTest)
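`check_fsdp_unsharded_state` is a helper defined elsewhere in this test file. A hypothetical equivalent could look roughly like the following, assuming FSDP2's convention that sharded parameters are registered as `DTensor`s and are swapped for plain local tensors while a module is unsharded; this is not the test's actual implementation.

```python
# Hypothetical stand-in for the test helper referenced above, not the actual one.
# Assumes FSDP2 registers sharded parameters as DTensors and swaps in plain
# tensors while a module is unsharded.
import torch.nn as nn
from torch.distributed.tensor import DTensor


def check_fsdp_unsharded_state(module: nn.Module, expected_unsharded: bool) -> None:
    for name, param in module.named_parameters():
        is_unsharded = not isinstance(param, DTensor)
        assert is_unsharded == expected_unsharded, (
            f"{name}: expected unsharded={expected_unsharded}, got {is_unsharded}"
        )
```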


@@ -625,6 +625,10 @@ or equal to the number of stages ({self._num_stages})."
# Run microbatches
self._step_microbatches(args_split, kwargs_split, targets_split, losses)
# Stage post processing
grad_scale_factor = self._n_microbatches if self.scale_grads else 1
self._stage._post_backward(grad_scale_factor)
# Return merged results per original format
if self._stage.is_last:
return self._merge_outputs(self._stage.output_chunks)
@@ -773,10 +777,6 @@ class ScheduleGPipe(PipelineScheduleSingle):
logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)
self._stage.scale_grads(
grad_scale_factor=self._n_microbatches if self.scale_grads else 1
)
# Wait for all backward sends to finish
for work in bwd_sends_to_wait:
_wait_batch_p2p(work)
@@ -951,10 +951,6 @@ class Schedule1F1B(PipelineScheduleSingle):
send_work = _batch_p2p(bwd_sends, desc="bwd_send")
bwd_mb_index += 1
self._stage.scale_grads(
grad_scale_factor=self._n_microbatches if self.scale_grads else 1
)
# Wait for the last backward send to finish
_wait_batch_p2p(send_work)
@@ -1555,6 +1551,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
# Run microbatches
self._step_microbatches(args_split, kwargs_split, targets_split, losses)
# Stage post processing
# TODO: remove this section and include as part of the schedule IR?
for stage in self._stages:
grad_scale_factor = self._n_microbatches if self.scale_grads else 1
stage._post_backward(grad_scale_factor)
# Return merged results per original format
for stage in self._stages:
if stage.is_last:
@@ -2086,15 +2088,12 @@ BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
loss = self._maybe_get_loss(stage, mb_index)
backward_counter[stage_idx] += 1
last_backward = backward_counter[stage_idx] == self._n_microbatches
grad_scale_factor = self._n_microbatches if self.scale_grads else 1
stage.backward_one_chunk(
mb_index,
loss=loss,
full_backward=True,
last_backward=last_backward,
)
if last_backward:
stage.scale_grads(grad_scale_factor)
# SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
# see [Note: V-schedule special case]
if is_prev_stage_on_this_rank:
@@ -2131,13 +2130,10 @@ BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
_assert_unsharded(stage_idx)
backward_counter[stage_idx] += 1
last_backward = backward_counter[stage_idx] == self._n_microbatches
grad_scale_factor = self._n_microbatches if self.scale_grads else 1
stage.backward_weight_one_chunk(
mb_index,
last_backward=last_backward,
)
if last_backward:
stage.scale_grads(grad_scale_factor)
else:
raise ValueError(f"{action=} is unknown or unsupported")
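For reference on the `grad_scale_factor` that the schedules now hand to `_post_backward`: when each microbatch loss is a mean and gradients accumulate across `n_microbatches`, dividing the accumulated grads by `n_microbatches` at the end of the step recovers the full-batch mean gradient. A small self-contained illustration in plain PyTorch (no pipelining, no FSDP):

```python
# Small illustration of grad_scale_factor = n_microbatches: accumulating
# per-microbatch mean-loss gradients and dividing by the number of microbatches
# matches the gradient of the full-batch mean loss.
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
n_microbatches = 4

# Full-batch reference gradient.
model.zero_grad()
torch.nn.functional.mse_loss(model(x), y).backward()
ref = model.weight.grad.clone()

# Accumulate over microbatches, then scale by 1 / n_microbatches at the end
# (conceptually what stage.scale_grads(grad_scale_factor) does after the step).
model.zero_grad()
for xb, yb in zip(x.chunk(n_microbatches), y.chunk(n_microbatches)):
    torch.nn.functional.mse_loss(model(xb), yb).backward()
model.weight.grad.div_(n_microbatches)

torch.testing.assert_close(model.weight.grad, ref)
```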


@@ -651,28 +651,6 @@ class _PipelineStageBase(ABC):
self.submod.set_reshard_after_backward(False)
self.submod.set_requires_gradient_sync(False)
result = perform_backward(backward_type)()
if last_backward:
# Manually call post backward for FSDP
def run_post_backward(fsdp_module: FSDPModule) -> None:
fsdp_module.set_is_last_backward(True)
fsdp_module.set_reshard_after_backward(True)
fsdp_module.set_requires_gradient_sync(True)
if isinstance(fsdp_module, ReplicateModule):
distributed_state = replicate.state(fsdp_module) # type: ignore[arg-type]
else:
distributed_state = fully_shard.state(fsdp_module) # type: ignore[attr-defined]
for state in distributed_state._state_ctx.all_states:
if state._fsdp_param_group:
state._fsdp_param_group.post_backward()
# it would be much better if pipelining backward invoked .backward so autograd hooks
# worked and modules like DDP/FSDP behaved as expected. Working around this for the time being,
# we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream.
distributed_state._root_post_backward_final_callback()
run_post_backward(self.submod)
else:
# Non-DP submodule, regular backward
@@ -998,6 +976,31 @@ class _PipelineStageBase(ABC):
return ops
def _post_backward(self, grad_scale_factor: int):
# Manually call post backward for FSDP
if isinstance(self.submod, FSDPModule):
fsdp_module = self.submod
fsdp_module.set_is_last_backward(True)
fsdp_module.set_reshard_after_backward(True)
fsdp_module.set_requires_gradient_sync(True)
if isinstance(fsdp_module, ReplicateModule):
distributed_state = replicate.state(fsdp_module) # type: ignore[arg-type]
else:
distributed_state = fully_shard.state(fsdp_module) # type: ignore[attr-defined]
for state in distributed_state._state_ctx.all_states:
if state._fsdp_param_group:
state._fsdp_param_group.post_backward()
# it would be much better if pipelining backward invoked .backward so autograd hooks
# worked and modules like DDP/FSDP behaved as expected. Working around this for the time being,
# we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream.
distributed_state._root_post_backward_final_callback()
# Call gradient scaling at the end of the backward pass
# NOTE: this must happen after FSDP post_backward if FSDP is enabled
self.scale_grads(grad_scale_factor)
class _PipelineStage(_PipelineStageBase):
def __init__(