diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst
index 841ab9a4b4de..fc1706a661dd 100644
--- a/docs/source/distributed.pipelining.rst
+++ b/docs/source/distributed.pipelining.rst
@@ -489,6 +489,8 @@ Pipeline Schedules
 
 .. autoclass:: ScheduleLoopedBFS
 
+.. autoclass:: ScheduleInterleavedZeroBubble
+
 .. autoclass:: PipelineScheduleSingle
   :members:
 
diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py
index 22d43069a906..9bf0d49f036b 100644
--- a/test/distributed/pipelining/test_schedule_multiproc.py
+++ b/test/distributed/pipelining/test_schedule_multiproc.py
@@ -19,6 +19,7 @@ from torch.distributed.pipelining import (
     ScheduleFlexibleInterleaved1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
+    ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
 from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
@@ -348,7 +349,10 @@ class ScheduleTest(MultiProcContinousTest):
 
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
-    @parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS])
+    @parametrize(
+        "ScheduleClass",
+        [ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble],
+    )
     @parametrize("use_new_runtime", [False, True])
     def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
         stages_per_rank = 2
@@ -408,6 +412,7 @@ class ScheduleTest(MultiProcContinousTest):
                 num_microbatches,
                 loss_fn=loss_fn,
                 stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+                use_full_backward=old_schedule.use_full_backward,
             )
             tmp_schedule._load_actions(old_schedule.pipeline_order)
             # test that csv round-trip works for compute_comms schedule
@@ -416,6 +421,7 @@ class ScheduleTest(MultiProcContinousTest):
                 num_microbatches,
                 loss_fn=loss_fn,
                 stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+                use_full_backward=old_schedule.use_full_backward,
             )
             with tempfile.NamedTemporaryFile() as f:
                 tmp_schedule._dump_csv(f.name)
@@ -426,6 +432,7 @@ class ScheduleTest(MultiProcContinousTest):
                 num_microbatches,
                 loss_fn=loss_fn,
                 stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+                use_full_backward=old_schedule.use_full_backward,
             )
             one_more_schedule._load_actions(
                 schedule.pipeline_order_with_comms, format="compute_comms"
diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py
index 3beb5e2d5909..476bf6a18a08 100644
--- a/torch/distributed/pipelining/__init__.py
+++ b/torch/distributed/pipelining/__init__.py
@@ -6,6 +6,7 @@ from .schedules import (
     ScheduleFlexibleInterleaved1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
+    ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
 from .stage import build_stage, PipelineStage
@@ -23,4 +24,5 @@ __all__ = [
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
+    "ScheduleInterleavedZeroBubble",
 ]
diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py
index d5da66edf8a9..cd02e0e9042c 100644
--- a/torch/distributed/pipelining/schedules.py
+++ b/torch/distributed/pipelining/schedules.py
@@ -42,6 +42,7 @@ __all__ = [
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
+    "ScheduleInterleavedZeroBubble",
 ]
 
 logger = logging.getLogger(__name__)
@@ -2110,6 +2111,35 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
         return result
 
 
+class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
+    """
+    The Interleaved Zero Bubble schedule.
+    See https://arxiv.org/pdf/2401.10241 for details.
+    Performs one forward and one backward on inputs for the microbatches in steady
+    state and supports multiple stages per rank. Uses the backward for weights to
+    fill in the pipeline bubble.
+    """
+
+    def __init__(
+        self,
+        stages: List[_PipelineStageBase],
+        n_microbatches: int,
+        loss_fn: Optional[Callable] = None,
+        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
+        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
+        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    ):
+        super().__init__(
+            stages=stages,
+            n_microbatches=n_microbatches,
+            loss_fn=loss_fn,
+            args_chunk_spec=args_chunk_spec,
+            kwargs_chunk_spec=kwargs_chunk_spec,
+            output_merge_spec=output_merge_spec,
+            enable_zero_bubble=True,
+        )
+
+
 def get_schedule_class(schedule_name: str):
     """
     Maps a schedule name to its corresponding class object.
@@ -2123,6 +2153,7 @@ def get_schedule_class(schedule_name: str):
         "GPipe": ScheduleGPipe,
         "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
         "LoopedBFS": ScheduleLoopedBFS,
+        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
         "PipelineScheduleSingle": PipelineScheduleSingle,
         "PipelineScheduleMulti": PipelineScheduleMulti,
     }
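
Reviewer note: below is a minimal usage sketch of the new schedule; it is not part of the diff. The helper name `build_zero_bubble_schedule`, the `MSELoss` loss, and the microbatch count are illustrative assumptions. Only the `ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=...)` constructor and the `"InterleavedZeroBubble"` registry key come from the changes above.

```python
# Sketch only: assumes torch.distributed is already initialized and `stages`
# is the list of PipelineStage objects owned by this rank (e.g. two stages
# per rank, as exercised by test_grad_with_manual_interleaved).
import torch
from torch.distributed.pipelining import ScheduleInterleavedZeroBubble
from torch.distributed.pipelining.schedules import get_schedule_class


def build_zero_bubble_schedule(stages, n_microbatches):
    # Hypothetical helper: wraps this rank's stages in the new schedule.
    # Internally this is ScheduleFlexibleInterleaved1F1B with
    # enable_zero_bubble=True, as shown in the diff above.
    return ScheduleInterleavedZeroBubble(
        stages,
        n_microbatches,
        loss_fn=torch.nn.MSELoss(),  # assumed loss for illustration
    )


# The new registry entry in get_schedule_class resolves the same class by name.
assert get_schedule_class("InterleavedZeroBubble") is ScheduleInterleavedZeroBubble
```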