diff --git a/test/distributed/test_composability.py b/test/distributed/test_composability.py
index aa6d89501fbb..3508a43cb548 100644
--- a/test/distributed/test_composability.py
+++ b/test/distributed/test_composability.py
@@ -146,6 +146,7 @@ class ComposabilityTest(MultiProcContinuousTest):
         total_layers,
         apply_dp,
         loss_fn,
+        scale_grads=True,
     ):
         if issubclass(ScheduleClass, PipelineScheduleSingle):
             pipeline_stage, offset = self._build_pp_stage(
@@ -163,6 +164,7 @@ class ComposabilityTest(MultiProcContinuousTest):
                 pipeline_stage,
                 n_microbatches=num_microbatches,
                 loss_fn=loss_fn,
+                scale_grads=scale_grads,
             )
         else:
             n_virtual = 2
@@ -185,6 +187,7 @@ class ComposabilityTest(MultiProcContinuousTest):
                 stages,
                 n_microbatches=num_microbatches,
                 loss_fn=loss_fn,
+                scale_grads=scale_grads,
             )
         return pipeline_schedule, partial_models, offsets
 
@@ -523,8 +526,8 @@ class ComposabilityTest(MultiProcContinuousTest):
         runtime.pipeline_order_with_comms = unshard_schedule
         runtime.step(dummy_input)
 
-        # Verify parameters are now unsharded
-        check_fsdp_unsharded_state(stage.submod, expected_unsharded=True)
+        # Verify parameters are still sharded
+        check_fsdp_unsharded_state(stage.submod, expected_unsharded=False)
 
 
 instantiate_parametrized_tests(ComposabilityTest)
diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py
index c9520f660681..b99afdf73187 100644
--- a/torch/distributed/pipelining/schedules.py
+++ b/torch/distributed/pipelining/schedules.py
@@ -625,6 +625,10 @@ or equal to the number of stages ({self._num_stages})."
         # Run microbatches
         self._step_microbatches(args_split, kwargs_split, targets_split, losses)
 
+        # Stage post processing
+        grad_scale_factor = self._n_microbatches if self.scale_grads else 1
+        self._stage._post_backward(grad_scale_factor)
+
         # Return merged results per original format
         if self._stage.is_last:
             return self._merge_outputs(self._stage.output_chunks)
@@ -773,10 +777,6 @@ class ScheduleGPipe(PipelineScheduleSingle):
 
             logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)
 
-        self._stage.scale_grads(
-            grad_scale_factor=self._n_microbatches if self.scale_grads else 1
-        )
-
         # Wait for all backward sends to finish
         for work in bwd_sends_to_wait:
             _wait_batch_p2p(work)
@@ -951,10 +951,6 @@ class Schedule1F1B(PipelineScheduleSingle):
             send_work = _batch_p2p(bwd_sends, desc="bwd_send")
             bwd_mb_index += 1
 
-        self._stage.scale_grads(
-            grad_scale_factor=self._n_microbatches if self.scale_grads else 1
-        )
-
         # Wait for the last backward send to finish
         _wait_batch_p2p(send_work)
 
@@ -1555,6 +1551,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
         # Run microbatches
         self._step_microbatches(args_split, kwargs_split, targets_split, losses)
 
+        # Stage post processing
+        # TODO: remove this section and include as part of the schedule IR?
+        for stage in self._stages:
+            grad_scale_factor = self._n_microbatches if self.scale_grads else 1
+            stage._post_backward(grad_scale_factor)
+
         # Return merged results per original format
         for stage in self._stages:
             if stage.is_last:
@@ -2086,15 +2088,12 @@ BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
                     loss = self._maybe_get_loss(stage, mb_index)
                     backward_counter[stage_idx] += 1
                     last_backward = backward_counter[stage_idx] == self._n_microbatches
-                    grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                     stage.backward_one_chunk(
                         mb_index,
                         loss=loss,
                         full_backward=True,
                         last_backward=last_backward,
                     )
-                    if last_backward:
-                        stage.scale_grads(grad_scale_factor)
                     # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                     # see [Note: V-schedule special case]
                     if is_prev_stage_on_this_rank:
@@ -2131,13 +2130,10 @@ BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
                         _assert_unsharded(stage_idx)
                     backward_counter[stage_idx] += 1
                     last_backward = backward_counter[stage_idx] == self._n_microbatches
-                    grad_scale_factor = self._n_microbatches if self.scale_grads else 1
                     stage.backward_weight_one_chunk(
                         mb_index,
                         last_backward=last_backward,
                     )
-                    if last_backward:
-                        stage.scale_grads(grad_scale_factor)
                 else:
                     raise ValueError(f"{action=} is unknown or unsupported")
 
diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py
index 121c6ec90c75..fe6fbf159b41 100644
--- a/torch/distributed/pipelining/stage.py
+++ b/torch/distributed/pipelining/stage.py
@@ -651,28 +651,6 @@ class _PipelineStageBase(ABC):
             self.submod.set_reshard_after_backward(False)
             self.submod.set_requires_gradient_sync(False)
             result = perform_backward(backward_type)()
-            if last_backward:
-                # Manually call post backward for FSDP
-                def run_post_backward(fsdp_module: FSDPModule) -> None:
-                    fsdp_module.set_is_last_backward(True)
-                    fsdp_module.set_reshard_after_backward(True)
-                    fsdp_module.set_requires_gradient_sync(True)
-
-                    if isinstance(fsdp_module, ReplicateModule):
-                        distributed_state = replicate.state(fsdp_module)  # type: ignore[arg-type]
-                    else:
-                        distributed_state = fully_shard.state(fsdp_module)  # type: ignore[attr-defined]
-
-                    for state in distributed_state._state_ctx.all_states:
-                        if state._fsdp_param_group:
-                            state._fsdp_param_group.post_backward()
-
-                    # it would be much better if pipelining backward invoked .backward so autograd hooks
-                    # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being,
-                    # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream.
-                    distributed_state._root_post_backward_final_callback()
-
-                run_post_backward(self.submod)
 
         else:
             # Non-DP submodule, regular backward
@@ -998,6 +976,31 @@ class _PipelineStageBase(ABC):
 
         return ops
 
+    def _post_backward(self, grad_scale_factor: int):
+        # Manually call post backward for FSDP
+        if isinstance(self.submod, FSDPModule):
+            fsdp_module = self.submod
+            fsdp_module.set_is_last_backward(True)
+            fsdp_module.set_reshard_after_backward(True)
+            fsdp_module.set_requires_gradient_sync(True)
+
+            if isinstance(fsdp_module, ReplicateModule):
+                distributed_state = replicate.state(fsdp_module)  # type: ignore[arg-type]
+            else:
+                distributed_state = fully_shard.state(fsdp_module)  # type: ignore[attr-defined]
+
+            for state in distributed_state._state_ctx.all_states:
+                if state._fsdp_param_group:
+                    state._fsdp_param_group.post_backward()
+
+            # it would be much better if pipelining backward invoked .backward so autograd hooks
+            # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being,
+            # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream.
+            distributed_state._root_post_backward_final_callback()
+        # Call gradient scaling at the end of the backward pass
+        # NOTE: this must happen after FSDP post_backward if FSDP is enabled
+        self.scale_grads(grad_scale_factor)
+
 
 class _PipelineStage(_PipelineStageBase):
     def __init__(
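
Beyond the composability tests above, a quick way to sanity-check the new flow is a single-rank smoke run of ScheduleGPipe. The sketch below is not part of the PR: it is a minimal example assuming a recent PyTorch build where ScheduleGPipe accepts scale_grads and PipelineStage can infer shapes lazily; the toy model, loss function, and rendezvous port are illustrative only. With scale_grads=True (the default), accumulated gradients are divided by the number of microbatches once per step(), now via the stage's _post_backward() rather than inside each schedule's microbatch loop.

# Hypothetical single-rank smoke test (not from the PR). It exercises the new
# post-backward path: grad scaling happens once in schedule.step() via
# stage._post_backward() instead of per-schedule scale_grads() calls.
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe


def toy_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Illustrative loss; any (output, target) -> scalar callable works.
    return torch.nn.functional.mse_loss(output, target)


def main() -> None:
    # Single-process "pipeline" purely for exercising the schedule code path.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    device = torch.device("cpu")
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).to(device)

    # One stage that is both the first and the last stage of the pipeline.
    stage = PipelineStage(model, stage_index=0, num_stages=1, device=device)

    n_microbatches = 4
    schedule = ScheduleGPipe(
        stage,
        n_microbatches=n_microbatches,
        loss_fn=toy_loss,
        scale_grads=True,  # default: divide grads by n_microbatches after the last backward
    )

    x = torch.randn(n_microbatches * 2, 8, device=device)
    y = torch.randn(n_microbatches * 2, 8, device=device)

    losses: list[torch.Tensor] = []
    # step() splits the batch into microbatches, runs forward/backward for each,
    # and then calls stage._post_backward(grad_scale_factor) exactly once.
    schedule.step(x, target=y, losses=losses)
    print("per-microbatch losses:", [loss.item() for loss in losses])

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Passing scale_grads=False instead leaves gradients as the raw sum over microbatches, which is the same switch the updated composability tests thread through to the schedule constructors.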