[PP] Add ZeroBubble schedule (#133467)
Zero bubble can already be expressed through `ScheduleFlexibleInterleaved1F1B` by setting `enable_zero_bubble=True`. Rather than requiring that flag at schedule initialization, this PR creates a separate `ScheduleInterleavedZeroBubble` class and transitions `Interleaved1F1B` to derive from `ScheduleFlexibleInterleaved1F1B`. We then don't need to expose `ScheduleFlexibleInterleaved1F1B` itself, since its name is not self-explanatory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133467
Approved by: https://github.com/wconstab
ghstack dependencies: #132691
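For illustration, a minimal usage sketch of the new class. The setup here is hypothetical and not part of this PR: `stages` (a list of `PipelineStage` objects, e.g. two per rank), `loss_fn`, and the microbatched inputs are assumed to already exist, with `torch.distributed` initialized:

```python
from torch.distributed.pipelining import ScheduleInterleavedZeroBubble

# Assumed to exist (not part of this PR): `stages` is a list of
# PipelineStage objects (e.g. 2 per rank), `loss_fn` is a loss callable,
# and `x`/`target` are the microbatched inputs/labels for this rank.
schedule = ScheduleInterleavedZeroBubble(
    stages,
    n_microbatches=8,  # example value; subject to the schedule's divisibility checks
    loss_fn=loss_fn,
)

# Ranks holding the first stage feed inputs; ranks holding the last
# stage pass `target` and collect per-microbatch losses.
losses = []
schedule.step(x, target=target, losses=losses)
```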
Committed by: PyTorch MergeBot
Parent: cedfac20c7
Commit: 108a75b454
docs/source/distributed.pipelining.rst
@@ -489,6 +489,8 @@ Pipeline Schedules
 
 .. autoclass:: ScheduleLoopedBFS
 
+.. autoclass:: ScheduleInterleavedZeroBubble
+
 .. autoclass:: PipelineScheduleSingle
   :members:
 
test/distributed/pipelining/test_schedule_multiproc.py
@@ -19,6 +19,7 @@ from torch.distributed.pipelining import (
     ScheduleFlexibleInterleaved1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
+    ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
 from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime
@@ -348,7 +349,10 @@ class ScheduleTest(MultiProcContinousTest):
 
     @requires_nccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
-    @parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS])
+    @parametrize(
+        "ScheduleClass",
+        [ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble],
+    )
     @parametrize("use_new_runtime", [False, True])
     def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
         stages_per_rank = 2
@@ -408,6 +412,7 @@ class ScheduleTest(MultiProcContinousTest):
             num_microbatches,
             loss_fn=loss_fn,
             stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+            use_full_backward=old_schedule.use_full_backward,
         )
         tmp_schedule._load_actions(old_schedule.pipeline_order)
         # test that csv round-trip works for compute_comms schedule
@@ -416,6 +421,7 @@ class ScheduleTest(MultiProcContinousTest):
             num_microbatches,
             loss_fn=loss_fn,
             stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+            use_full_backward=old_schedule.use_full_backward,
         )
         with tempfile.NamedTemporaryFile() as f:
             tmp_schedule._dump_csv(f.name)
@@ -426,6 +432,7 @@ class ScheduleTest(MultiProcContinousTest):
             num_microbatches,
             loss_fn=loss_fn,
             stage_index_to_group_rank=old_schedule.stage_index_to_group_rank,
+            use_full_backward=old_schedule.use_full_backward,
         )
         one_more_schedule._load_actions(
             schedule.pipeline_order_with_comms, format="compute_comms"
torch/distributed/pipelining/__init__.py
@@ -6,6 +6,7 @@ from .schedules import (
     ScheduleFlexibleInterleaved1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
+    ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
 from .stage import build_stage, PipelineStage
@@ -23,4 +24,5 @@ __all__ = [
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
+    "ScheduleInterleavedZeroBubble",
 ]
torch/distributed/pipelining/schedules.py
@@ -42,6 +42,7 @@ __all__ = [
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
+    "ScheduleInterleavedZeroBubble",
 ]
 
 logger = logging.getLogger(__name__)
@@ -2110,6 +2111,35 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
         return result
 
 
+class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B):
+    """
+    The Interleaved Zero Bubble schedule.
+    See https://arxiv.org/pdf/2401.10241 for details.
+    Will perform one forward and one backward on inputs for the microbatches in steady
+    state and supports multiple stages per rank. Uses the backward for weights to fill in
+    the pipeline bubble.
+    """
+
+    def __init__(
+        self,
+        stages: List[_PipelineStageBase],
+        n_microbatches: int,
+        loss_fn: Optional[Callable] = None,
+        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
+        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
+        output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    ):
+        super().__init__(
+            stages=stages,
+            n_microbatches=n_microbatches,
+            loss_fn=loss_fn,
+            args_chunk_spec=args_chunk_spec,
+            kwargs_chunk_spec=kwargs_chunk_spec,
+            output_merge_spec=output_merge_spec,
+            enable_zero_bubble=True,
+        )
+
+
 def get_schedule_class(schedule_name: str):
     """
     Maps a schedule name to its corresponding class object.
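The docstring above captures the core idea from the paper: backward is split into a B step (gradients w.r.t. the stage inputs, which the upstream stage is waiting on) and a W step (gradients w.r.t. the weights, which nothing downstream depends on), so W steps can be deferred into otherwise-idle bubble slots. Below is a self-contained autograd illustration of that split; it is a conceptual sketch, not the schedule's internal mechanism:

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
loss = lin(x).sum()

# B step: gradient w.r.t. the stage input only; this is what the
# previous pipeline stage is blocked on, so it runs first.
(dx,) = torch.autograd.grad(loss, (x,), retain_graph=True)

# W step: gradients w.r.t. the weights; no downstream stage depends on
# these, so the schedule is free to slot them into pipeline bubbles.
dw, db = torch.autograd.grad(loss, (lin.weight, lin.bias))
```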
@@ -2123,6 +2153,7 @@ def get_schedule_class(schedule_name: str):
         "GPipe": ScheduleGPipe,
         "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B,
         "LoopedBFS": ScheduleLoopedBFS,
+        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
         "PipelineScheduleSingle": PipelineScheduleSingle,
         "PipelineScheduleMulti": PipelineScheduleMulti,
     }
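Since the class is also registered in the `get_schedule_class` mapping, it can be selected by its string name, e.g. from a training config:

```python
from torch.distributed.pipelining.schedules import get_schedule_class

schedule_cls = get_schedule_class("InterleavedZeroBubble")
assert schedule_cls.__name__ == "ScheduleInterleavedZeroBubble"
```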