[PP] Add ZeroBubble schedule (#133467)

Zero bubble can be expressed through `ScheduleFlexibleInterleaved1F1B` by setting `enable_zero_bubble=True`. But instead of having to include this flag in schedule initialization we should create a separate ZeroBubbleSchedule and also transition `Interleaved1F1B` to derive from `ScheduleFlexibleInterleaved1F1B`. Then we dont need to expose `ScheduleFlexibleInterleaved1F1B` since the naming is not obvious

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133467
Approved by: https://github.com/wconstab
ghstack dependencies: #132691
This commit is contained in:
Howard Huang
2024-08-21 13:25:33 -07:00
committed by PyTorch MergeBot
parent cedfac20c7
commit 108a75b454
4 changed files with 43 additions and 1 deletions

View File

@ -489,6 +489,8 @@ Pipeline Schedules
.. autoclass:: ScheduleLoopedBFS
.. autoclass:: ScheduleInterleavedZeroBubble
.. autoclass:: PipelineScheduleSingle
:members:

View File

@ -19,6 +19,7 @@ from torch.distributed.pipelining import (
ScheduleFlexibleInterleaved1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
@ -348,7 +349,10 @@ class ScheduleTest(MultiProcContinousTest):
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS])
@parametrize(
"ScheduleClass",
[ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble],
)
@parametrize("use_new_runtime", [False, True])
def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
stages_per_rank = 2
@ -408,6 +412,7 @@ class ScheduleTest(MultiProcContinousTest):
num_microbatches,
loss_fn=loss_fn,
stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
use_full_backward=old_schedule.use_full_backward,
)
tmp_schedule._load_actions(old_schedule.pipeline_order)
# test that csv round-trip works for compute_comms schedule
@ -416,6 +421,7 @@ class ScheduleTest(MultiProcContinousTest):
num_microbatches,
loss_fn=loss_fn,
stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
use_full_backward=old_schedule.use_full_backward,
)
with tempfile.NamedTemporaryFile() as f:
tmp_schedule._dump_csv(f.name)
@ -426,6 +432,7 @@ class ScheduleTest(MultiProcContinousTest):
num_microbatches,
loss_fn=loss_fn,
stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
use_full_backward=old_schedule.use_full_backward,
)
one_more_schedule._load_actions(
schedule.pipeline_order_with_comms, format="compute_comms"

View File

@ -6,6 +6,7 @@ from .schedules import (
ScheduleFlexibleInterleaved1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleInterleavedZeroBubble,
ScheduleLoopedBFS,
)
from .stage import build_stage, PipelineStage
@ -23,4 +24,5 @@ __all__ = [
"ScheduleGPipe",
"ScheduleInterleaved1F1B",
"ScheduleLoopedBFS",
"ScheduleInterleavedZeroBubble",
]

View File

@ -42,6 +42,7 @@ __all__ = [
"ScheduleGPipe",
"ScheduleInterleaved1F1B",
"ScheduleLoopedBFS",
"ScheduleInterleavedZeroBubble",
]
logger = logging.getLogger(__name__)
@ -2110,6 +2111,35 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
return result
class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
"""
The Interleaved Zero Bubble schedule.
See https://arxiv.org/pdf/2401.10241 for details.
Will perform one forward and one backward on inputs for the microbatches in steady
state and supports multiple stages per rank. Uses the backward for weights to fill in
the pipeline bubble.
"""
def __init__(
self,
stages: List[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
args_chunk_spec=args_chunk_spec,
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
enable_zero_bubble=True,
)
def get_schedule_class(schedule_name: str):
"""
Maps a schedule name to its corresponding class object.
@ -2123,6 +2153,7 @@ def get_schedule_class(schedule_name: str):
"GPipe": ScheduleGPipe,
"FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
"LoopedBFS": ScheduleLoopedBFS,
"InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
"PipelineScheduleSingle": PipelineScheduleSingle,
"PipelineScheduleMulti": PipelineScheduleMulti,
}