[PP] Fix edge case with FSDP when stages_per_rank > 3 (#165467)

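There is an edge case with FSDP + PP: when we add UNSHARD + RESHARD actions, we keep at most 3 stages unsharded at a time, so with more than 3 stages per rank a stage may need to be unsharded again after it has been resharded. See 3f83e8915e/torch/distributed/pipelining/schedules.py (L1029-L1031)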

This change is needed so that a stage can be unsharded and resharded multiple times.
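
A minimal sketch of the bookkeeping involved, not the actual runtime code. The `replay` helper, the `remove_on_reshard` flag, and the tuple-based schedule are hypothetical names used only for illustration. The sketch also assumes that a repeated UNSHARD of a stage still tracked as unsharded is treated as an error.

```python
def replay(actions, remove_on_reshard=True):
    """Replay (kind, stage_idx) actions and return the stages left unsharded.

    Hypothetical stand-in for the schedule runtime's tracking set; it only
    models the set bookkeeping, not any real FSDP unshard/reshard work.
    """
    unsharded_stages = set()
    for kind, stage_idx in actions:
        if kind == "UNSHARD":
            # Assumption: unsharding an already-tracked stage is an error.
            if stage_idx in unsharded_stages:
                raise RuntimeError(f"stage {stage_idx} is already tracked as unsharded")
            unsharded_stages.add(stage_idx)
        elif kind == "RESHARD":
            if stage_idx not in unsharded_stages:
                raise RuntimeError(f"stage {stage_idx} was never unsharded")
            if remove_on_reshard:
                # Dropping the stage here is what lets it be unsharded again.
                unsharded_stages.remove(stage_idx)
        else:
            raise ValueError(f"unknown action kind: {kind}")
    return unsharded_stages


# The same stage is unsharded and resharded twice, as can happen when
# stages_per_rank exceeds the number of stages kept unsharded at once.
schedule = [("UNSHARD", 0), ("RESHARD", 0), ("UNSHARD", 0), ("RESHARD", 0)]
```

With `remove_on_reshard=True` the schedule replays cleanly and the tracking set ends up empty. With `remove_on_reshard=False` the second UNSHARD of stage 0 raises, because the stage is still recorded as unsharded.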

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165467
Approved by: https://github.com/wwwjn
Howard Huang
2025-10-15 01:53:00 +00:00
committed by PyTorch MergeBot
parent 132ae8e6dd
commit ca65023b90

@@ -2038,6 +2038,7 @@ BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
             if not isinstance(submodule, FSDPModule):
                 continue
             submodule.reshard()
+        unsharded_stages.remove(stage_idx)
 elif comp_type == FORWARD:
     if stage_uses_fsdp:
         _assert_unsharded(stage_idx)