[PP] move FSDP reduce scatters to end of step (#165106)
Move the FSDP reduce-scatters to the end of the PP step. The reduce-scatter's sync with the compute stream blocks the other stages from executing their backward passes, leading to bubbles. There should be a way to issue these reduce-scatters earlier, but this is done for now as a quick fix.

Image: https://github.com/user-attachments/assets/b945dd55-8ab1-4acc-b862-c6e2e476b834

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165106
Approved by: https://github.com/weifengpy
ghstack dependencies: #164976
Committed by: PyTorch MergeBot
Parent: 3a110c9bb2
Commit: 2beead7523
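For context on the FSDP2 knobs this change drives, here is a minimal, hedged sketch (not code from this PR) of the general pattern of deferring reduce-scatters until the last backward of a step. It assumes a recent PyTorch where FSDPModule is exported from torch.distributed.fsdp, an already-initialized process group, and a model wrapped with fully_shard; `model`, `microbatches`, and `loss_fn` are placeholders.

import torch
from torch.distributed.fsdp import FSDPModule

def train_step(model: torch.nn.Module, microbatches, loss_fn):
    # Assumes `model` was wrapped with torch.distributed.fsdp.fully_shard and
    # that the process group is already initialized (placeholder sketch).
    if isinstance(model, FSDPModule):
        # Hold back gradient reduce-scatter and resharding while the
        # per-microbatch backwards run.
        model.set_requires_gradient_sync(False)
        model.set_reshard_after_backward(False)
        model.set_is_last_backward(False)
    for i, (inputs, target) in enumerate(microbatches):
        if isinstance(model, FSDPModule) and i == len(microbatches) - 1:
            # Re-enable on the last microbatch so its backward launches the
            # batched reduce-scatters and reshards parameters afterwards.
            model.set_requires_gradient_sync(True)
            model.set_reshard_after_backward(True)
            model.set_is_last_backward(True)
        loss_fn(model(inputs), target).backward()

Pipelining cannot lean on these autograd hooks in the same way (see the stage.py comment retained in the diff below about pipelining backward not invoking .backward), which is why this PR flushes the reduce-scatters manually via a new _post_backward helper called at the end of step().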
@@ -146,6 +146,7 @@ class ComposabilityTest(MultiProcContinuousTest):
         total_layers,
         apply_dp,
         loss_fn,
+        scale_grads=True,
     ):
         if issubclass(ScheduleClass, PipelineScheduleSingle):
             pipeline_stage, offset = self._build_pp_stage(
@@ -163,6 +164,7 @@ class ComposabilityTest(MultiProcContinuousTest):
                 pipeline_stage,
                 n_microbatches=num_microbatches,
                 loss_fn=loss_fn,
+                scale_grads=scale_grads,
             )
         else:
             n_virtual = 2
@@ -185,6 +187,7 @@ class ComposabilityTest(MultiProcContinuousTest):
                 stages,
                 n_microbatches=num_microbatches,
                 loss_fn=loss_fn,
+                scale_grads=scale_grads,
             )
         return pipeline_schedule, partial_models, offsets

@@ -523,8 +526,8 @@ class ComposabilityTest(MultiProcContinuousTest):
             runtime.pipeline_order_with_comms = unshard_schedule
             runtime.step(dummy_input)

-            # Verify parameters are now unsharded
-            check_fsdp_unsharded_state(stage.submod, expected_unsharded=True)
+            # Verify parameters are still sharded
+            check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)


 instantiate_parametrized_tests(ComposabilityTest)

@@ -625,6 +625,10 @@ or equal to the number of stages ({self._num_stages})."
         # Run microbatches
         self._step_microbatches(args_split, kwargs_split, targets_split, losses)

+        # Stage post processing
+        grad_scale_factor = self._n_microbatches if self.scale_grads else 1
+        self._stage._post_backward(grad_scale_factor)
+
         # Return merged results per original format
         if self._stage.is_last:
             return self._merge_outputs(self._stage.output_chunks)
@@ -773,10 +777,6 @@ class ScheduleGPipe(PipelineScheduleSingle):

             logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)

-        self._stage.scale_grads(
-            grad_scale_factor=self._n_microbatches if self.scale_grads else 1
-        )
-
         # Wait for all backward sends to finish
         for work in bwd_sends_to_wait:
             _wait_batch_p2p(work)
@@ -951,10 +951,6 @@ class Schedule1F1B(PipelineScheduleSingle):
             send_work = _batch_p2p(bwd_sends, desc="bwd_send")
             bwd_mb_index += 1

-        self._stage.scale_grads(
-            grad_scale_factor=self._n_microbatches if self.scale_grads else 1
-        )
-
         # Wait for the last backward send to finish
         _wait_batch_p2p(send_work)

@@ -1555,6 +1551,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
         # Run microbatches
         self._step_microbatches(args_split, kwargs_split, targets_split, losses)

+        # Stage post processing
+        # TODO: remove this section and include as part of the schedule IR?
+        for stage in self._stages:
+            grad_scale_factor = self._n_microbatches if self.scale_grads else 1
+            stage._post_backward(grad_scale_factor)
+
         # Return merged results per original format
         for stage in self._stages:
             if stage.is_last:
@@ -2086,15 +2088,12 @@ BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
                         loss = self._maybe_get_loss(stage, mb_index)
                         backward_counter[stage_idx] += 1
                         last_backward = backward_counter[stage_idx] == self._n_microbatches
-                        grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                         stage.backward_one_chunk(
                             mb_index,
                             loss=loss,
                             full_backward=True,
                             last_backward=last_backward,
                         )
-                        if last_backward:
-                            stage.scale_grads(grad_scale_factor)
                         # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                         # see [Note: V-schedule special case]
                         if is_prev_stage_on_this_rank:
@@ -2131,13 +2130,10 @@ BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
                        _assert_unsharded(stage_idx)
                        backward_counter[stage_idx] += 1
                        last_backward = backward_counter[stage_idx] == self._n_microbatches
-                       grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                        stage.backward_weight_one_chunk(
                            mb_index,
                            last_backward=last_backward,
                        )
-                       if last_backward:
-                           stage.scale_grads(grad_scale_factor)
                    else:
                        raise ValueError(f"{action=} is unknown or unsupported")

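A side note on the grad_scale_factor threaded through the hunks above: with mean-reduced per-microbatch losses and equal microbatch sizes, the accumulated gradients come out n_microbatches times larger than a single full-batch backward, which is what scale_grads compensates for by dividing. Below is a small, self-contained check of that arithmetic in plain PyTorch (no pipelining or FSDP; all names and shapes are illustrative, not from the PR):

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x = torch.randn(8, 4)
n_microbatches = 4

# Accumulate gradients microbatch by microbatch, each with a mean-reduced loss.
for xb in x.chunk(n_microbatches):
    (xb @ w).mean().backward()
accumulated = w.grad.clone()

# Reference: one backward over the full batch.
w.grad = None
(x @ w).mean().backward()

# Dividing the accumulated grads by n_microbatches recovers the full-batch
# gradient, which is the effect of scale_grads(grad_scale_factor=n_microbatches).
assert torch.allclose(accumulated / n_microbatches, w.grad)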
@@ -651,28 +651,6 @@ class _PipelineStageBase(ABC):
            self.submod.set_reshard_after_backward(False)
            self.submod.set_requires_gradient_sync(False)
            result = perform_backward(backward_type)()
-            if last_backward:
-                # Manually call post backward for FSDP
-                def run_post_backward(fsdp_module: FSDPModule) -> None:
-                    fsdp_module.set_is_last_backward(True)
-                    fsdp_module.set_reshard_after_backward(True)
-                    fsdp_module.set_requires_gradient_sync(True)
-
-                    if isinstance(fsdp_module, ReplicateModule):
-                        distributed_state = replicate.state(fsdp_module)  # type: ignore[arg-type]
-                    else:
-                        distributed_state = fully_shard.state(fsdp_module)  # type: ignore[attr-defined]
-
-                    for state in distributed_state._state_ctx.all_states:
-                        if state._fsdp_param_group:
-                            state._fsdp_param_group.post_backward()
-
-                    # it would be much better if pipelining backward invoked .backward so autograd hooks
-                    # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being,
-                    # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream.
-                    distributed_state._root_post_backward_final_callback()
-
-                run_post_backward(self.submod)

        else:
            # Non-DP submodule, regular backward
@@ -998,6 +976,31 @@ class _PipelineStageBase(ABC):

        return ops

+    def _post_backward(self, grad_scale_factor: int):
+        # Manually call post backward for FSDP
+        if isinstance(self.submod, FSDPModule):
+            fsdp_module = self.submod
+            fsdp_module.set_is_last_backward(True)
+            fsdp_module.set_reshard_after_backward(True)
+            fsdp_module.set_requires_gradient_sync(True)
+
+            if isinstance(fsdp_module, ReplicateModule):
+                distributed_state = replicate.state(fsdp_module)  # type: ignore[arg-type]
+            else:
+                distributed_state = fully_shard.state(fsdp_module)  # type: ignore[attr-defined]
+
+            for state in distributed_state._state_ctx.all_states:
+                if state._fsdp_param_group:
+                    state._fsdp_param_group.post_backward()
+
+            # it would be much better if pipelining backward invoked .backward so autograd hooks
+            # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being,
+            # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream.
+            distributed_state._root_post_backward_final_callback()
+
+        # Call gradient scaling at the end of the backward pass
+        # NOTE: this must happen after FSDP post_backward if FSDP is enabled
+        self.scale_grads(grad_scale_factor)
+

 class _PipelineStage(_PipelineStageBase):
     def __init__(
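Finally, for orientation, a hedged sketch of the FSDP + pipeline-schedule composition that the _post_backward path above services. Stage construction details, shapes, and names (stage_module, stage_index, num_stages, device, loss_fn, inputs, labels) are placeholders, not taken from this PR, and an initialized process group is assumed:

import torch
from torch.distributed.fsdp import fully_shard
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

fully_shard(stage_module)  # shard this rank's model chunk with FSDP2
stage = PipelineStage(stage_module, stage_index, num_stages, device)
schedule = ScheduleGPipe(stage, n_microbatches=8, loss_fn=loss_fn, scale_grads=True)

# One step: all microbatch forwards/backwards run first; with this change the
# FSDP reduce-scatters (stage._post_backward) and gradient scaling happen in
# the stage post-processing at the end of step(), not inside the last
# microbatch's backward.
schedule.step(inputs, target=labels)  # inputs on the first stage, target on the last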