mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[PP] Add ZeroBubble schedule (#133467)
Zero bubble can be expressed through `ScheduleFlexibleInterleaved1F1B` by setting `enable_zero_bubble=True`. But instead of having to pass this flag at schedule initialization, we create a dedicated `ScheduleInterleavedZeroBubble` class and also transition `ScheduleInterleaved1F1B` to derive from `ScheduleFlexibleInterleaved1F1B`. Then we don't need to expose `ScheduleFlexibleInterleaved1F1B`, since its naming is not obvious.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133467
Approved by: https://github.com/wconstab
ghstack dependencies: #132691
committed by PyTorch MergeBot
parent cedfac20c7
commit 108a75b454
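Before the diff, a minimal usage sketch of the change, assuming `stages`, `n_microbatches`, and `loss_fn` have been set up elsewhere (they are not part of this commit). The two constructions below are equivalent after this change:

    from torch.distributed.pipelining import (
        ScheduleFlexibleInterleaved1F1B,
        ScheduleInterleavedZeroBubble,
    )

    # Assumed to exist: stages (list of pipeline stages), n_microbatches, loss_fn.

    # Before this commit: zero bubble is opt-in via a constructor flag.
    schedule_flag = ScheduleFlexibleInterleaved1F1B(
        stages, n_microbatches, loss_fn=loss_fn, enable_zero_bubble=True
    )

    # After this commit: a dedicated class that sets enable_zero_bubble=True internally.
    schedule_zb = ScheduleInterleavedZeroBubble(stages, n_microbatches, loss_fn=loss_fn)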
docs/source/distributed.pipelining.rst

@@ -489,6 +489,8 @@ Pipeline Schedules

 .. autoclass:: ScheduleLoopedBFS

+.. autoclass:: ScheduleInterleavedZeroBubble
+
 .. autoclass:: PipelineScheduleSingle
   :members:
test/distributed/pipelining/test_schedule_multiproc.py

@@ -19,6 +19,7 @@ from torch.distributed.pipelining import (
     ScheduleFlexibleInterleaved1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
+    ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
 from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
@@ -348,7 +349,10 @@ class ScheduleTest(MultiProcContinousTest):

     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
-    @parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS])
+    @parametrize(
+        "ScheduleClass",
+        [ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble],
+    )
     @parametrize("use_new_runtime", [False, True])
     def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
         stages_per_rank = 2
@@ -408,6 +412,7 @@ class ScheduleTest(MultiProcContinousTest):
             num_microbatches,
             loss_fn=loss_fn,
             stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+            use_full_backward=old_schedule.use_full_backward,
         )
         tmp_schedule._load_actions(old_schedule.pipeline_order)
         # test that csv round-trip works for compute_comms schedule
@@ -416,6 +421,7 @@ class ScheduleTest(MultiProcContinousTest):
             num_microbatches,
             loss_fn=loss_fn,
             stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+            use_full_backward=old_schedule.use_full_backward,
         )
         with tempfile.NamedTemporaryFile() as f:
             tmp_schedule._dump_csv(f.name)
@@ -426,6 +432,7 @@ class ScheduleTest(MultiProcContinousTest):
             num_microbatches,
             loss_fn=loss_fn,
             stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+            use_full_backward=old_schedule.use_full_backward,
         )
         one_more_schedule._load_actions(
             schedule.pipeline_order_with_comms, format="compute_comms"
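The three test hunks above each thread `use_full_backward` through when rebuilding a `_PipelineScheduleRuntime` from an existing schedule. Condensed, the round-trip pattern the test exercises looks like the sketch below. These are private, underscore-prefixed APIs used only as the test uses them; `stages`, `num_microbatches`, `loss_fn`, and `old_schedule` come from the surrounding test and are assumed here:

    import tempfile

    from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime

    # Rebuild a runtime schedule from an existing one; this commit additionally
    # propagates use_full_backward, which is relevant for zero-bubble schedules.
    runtime = _PipelineScheduleRuntime(
        stages,
        num_microbatches,
        loss_fn=loss_fn,
        stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
        use_full_backward=old_schedule.use_full_backward,
    )
    runtime._load_actions(old_schedule.pipeline_order)

    # Round-trip the schedule actions through CSV, as the test does.
    with tempfile.NamedTemporaryFile() as f:
        runtime._dump_csv(f.name)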
torch/distributed/pipelining/__init__.py

@@ -6,6 +6,7 @@ from .schedules import (
     ScheduleFlexibleInterleaved1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
+    ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
 from .stage import build_stage, PipelineStage
@@ -23,4 +24,5 @@ __all__ = [
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
+    "ScheduleInterleavedZeroBubble",
 ]
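With the export added above, the class is importable from the package root as well as from the schedules module; a small sanity sketch:

    # Both import paths resolve to the same class after this commit.
    from torch.distributed.pipelining import ScheduleInterleavedZeroBubble
    from torch.distributed.pipelining.schedules import (
        ScheduleInterleavedZeroBubble as _SameClass,
    )

    assert ScheduleInterleavedZeroBubble is _SameClass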
torch/distributed/pipelining/schedules.py

@@ -42,6 +42,7 @@ __all__ = [
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
+    "ScheduleInterleavedZeroBubble",
 ]

 logger = logging.getLogger(__name__)
@@ -2110,6 +2111,35 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
         return result


+class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
+    """
+    The Interleaved Zero Bubble schedule.
+    See https://arxiv.org/pdf/2401.10241 for details.
+    Will perform one forward and one backward on inputs for the microbatches in steady
+    state and supports multiple stages per rank. Uses the backward for weights to fill in
+    the pipeline bubble.
+    """
+
+    def __init__(
+        self,
+        stages: List[_PipelineStageBase],
+        n_microbatches: int,
+        loss_fn: Optional[Callable] = None,
+        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
+        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
+        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    ):
+        super().__init__(
+            stages=stages,
+            n_microbatches=n_microbatches,
+            loss_fn=loss_fn,
+            args_chunk_spec=args_chunk_spec,
+            kwargs_chunk_spec=kwargs_chunk_spec,
+            output_merge_spec=output_merge_spec,
+            enable_zero_bubble=True,
+        )
+
+
 def get_schedule_class(schedule_name: str):
     """
     Maps a schedule name to its corresponding class object.
@@ -2123,6 +2153,7 @@ def get_schedule_class(schedule_name: str):
         "GPipe": ScheduleGPipe,
         "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
         "LoopedBFS": ScheduleLoopedBFS,
+        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
         "PipelineScheduleSingle": PipelineScheduleSingle,
         "PipelineScheduleMulti": PipelineScheduleMulti,
     }
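With the new registry entry, the schedule can also be resolved by name; a minimal sketch:

    from torch.distributed.pipelining.schedules import get_schedule_class

    # The string key added in this commit maps to the new class.
    schedule_cls = get_schedule_class("InterleavedZeroBubble")
    assert schedule_cls.__name__ == "ScheduleInterleavedZeroBubble"

This keeps string-configured training scripts working without importing the class directly.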