mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[PP] Fix edge case with FSDP when stages_per_rank > 3 (#165467)
There is an edge case with FSDP + PP when we add UNSHARD + RESHARD, we at max have 3 stages unsharded, 3f83e8915e/torch/distributed/pipelining/schedules.py (L1029-L1031)
This change is need to be able to unshard and reshard a stage multiple times.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165467
Approved by: https://github.com/wwwjn
This commit is contained in:
committed by
PyTorch MergeBot
parent
132ae8e6dd
commit
ca65023b90
@ -2038,6 +2038,7 @@ BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
|
||||
if not isinstance(submodule, FSDPModule):
|
||||
continue
|
||||
submodule.reshard()
|
||||
unsharded_stages.remove(stage_idx)
|
||||
elif comp_type == FORWARD:
|
||||
if stage_uses_fsdp:
|
||||
_assert_unsharded(stage_idx)
|
||||
|
Reference in New Issue
Block a user