Add get_pipeline_order() for Gpipe and 1F1B (#155935)

The [schedule visualizer](https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/_schedule_visualizer.py) relies on `self.pipeline_order` to be populated. The `_PipelineScheduleRuntime` also depends on this to run the IR.

The single-stage schedules do not implement this, so this PR adds it. This also fixes a bug in the schedule visualizer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155935
Approved by: https://github.com/wconstab
This commit is contained in:
Howard Huang
2025-06-16 12:47:52 -07:00
committed by PyTorch MergeBot
parent 5435e75399
commit 38e1e5d54c
2 changed files with 124 additions and 2 deletions

View File

@ -56,14 +56,14 @@ def get_schedule_ops(
num_stages_per_rank = 1
assert num_stages_per_rank == 1
stages = mock_pipeline_stage
stages.num_stages = num_stages_per_rank
stages.num_stages = num_stages_per_rank * pp_degree
elif issubclass(schedule_class, PipelineScheduleMulti):
if num_stages_per_rank is None:
num_stages_per_rank = 2
assert num_stages_per_rank >= 2
stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)]
for stage in stages:
stage.num_stages = num_stages_per_rank
stage.num_stages = num_stages_per_rank * pp_degree
else:
raise ValueError(f"Invalid schedule: {schedule_class}")

View File

@ -485,6 +485,10 @@ class PipelineScheduleSingle(_PipelineSchedule):
or equal to the number of stages ({self._num_stages})."
)
self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = (
self._get_pipeline_order()
)
def _initialize_stage(self, args, kwargs):
self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
if self._has_backward:
@ -524,6 +528,24 @@ or equal to the number of stages ({self._num_stages})."
else:
return None
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
    """
    Return this schedule's execution order as schedule IR, or None.

    The IR maps each rank to an ordered list of steps, where a step is
    either an _Action (computation to run) or None (a deliberate idle
    slot). The None entries model pipeline bubbles — points where a rank
    is blocked on another rank's output. The _PipelineScheduleRuntime
    skips these Nones at execution time, since the corresponding
    send/recv communication is scheduled and waited on explicitly.

    Subclasses override this to expose their concrete order; the base
    implementation reports that no IR is available.

    Returns:
        dict mapping rank -> list of actions, or None if unimplemented.
    """
    # Base class has no schedule of its own; subclasses provide one.
    return None
class _ScheduleForwardOnly(PipelineScheduleSingle):
"""
@ -666,6 +688,38 @@ class ScheduleGPipe(PipelineScheduleSingle):
for work in bwd_sends_to_wait:
_wait_batch_p2p(work)
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
    """
    Build the GPipe schedule IR: rank -> ordered list of actions, where
    None entries are idle bubble slots.
    See the base method in PipelineScheduleSingle for the IR format.
    """
    num_ranks = self._num_stages
    order: dict[int, list[Optional[_Action]]] = {}
    for rank in range(num_ranks):
        # Idle until the first activation reaches this rank.
        schedule: list[Optional[_Action]] = [None] * rank
        # Classic GPipe: every forward, then a bubble while gradients
        # propagate back from the later ranks, then every backward.
        schedule += [
            _Action(rank, _ComputationType.FORWARD, mb)
            for mb in range(self._n_microbatches)
        ]
        schedule += [None] * (3 * (num_ranks - 1 - rank))
        schedule += [
            _Action(rank, _ComputationType.FULL_BACKWARD, mb)
            for mb in range(self._n_microbatches)
        ]
        order[rank] = schedule
    return order
class Schedule1F1B(PipelineScheduleSingle):
"""
@ -813,6 +867,74 @@ class Schedule1F1B(PipelineScheduleSingle):
# Return losses if there is a container passed in
self._update_losses(self._stage, losses)
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
    """
    Returns the pipeline order for the 1F1B schedule.
    See base method in PipelineScheduleSingle for details on the schedule IR format.

    Fixes relative to the previous version:
    - When a rank performed zero warmup forwards (the last rank), the
      steady-state loop pre-incremented ``forward_mb`` from 0 and so
      emitted FORWARD for microbatch 1 first, silently skipping
      microbatch 0. We now track the *next* forward index directly.
    - The cooldown's ``if (pp_group_size - rank) > 0`` guard was always
      true (rank < pp_group_size), and the nested
      ``if remaining_backward > 0`` was guaranteed by the loop
      condition, making the entire ``else`` branch unreachable. The
      dead code is removed; observable output is unchanged.
    """
    pipeline_order: dict[int, list[Optional[_Action]]] = {}
    pp_group_size = self._num_stages
    for rank in range(pp_group_size):
        actions: list[Optional[_Action]] = []

        # 1. Warmup phase: initial idle slots based on rank position.
        actions.extend([None] * rank)

        # 2. Initial forward passes before the 1F1B steady state.
        # Clamped so a tiny microbatch count never emits forwards for
        # microbatches that do not exist.
        num_warmup = min(pp_group_size - 1 - rank, self._n_microbatches)
        for mb_idx in range(num_warmup):
            actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx))

        # 3. Bubble while waiting for the first backward to be ready.
        actions.extend([None] * (2 * (pp_group_size - 1 - rank)))

        # 4. 1F1B steady state: alternate one forward and one backward.
        # forward_mb is the index of the NEXT forward microbatch, so the
        # zero-warmup case correctly starts at microbatch 0.
        forward_mb = num_warmup
        backward_mb = 0
        while forward_mb < self._n_microbatches:
            actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb))
            forward_mb += 1
            actions.append(
                _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
            )
            backward_mb += 1

        # 5. Cooldown phase: remaining backwards, each preceded by one
        # idle slot while the gradient arrives from the next rank.
        while backward_mb < self._n_microbatches:
            actions.append(None)
            actions.append(
                _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
            )
            backward_mb += 1

        pipeline_order[rank] = actions
    return pipeline_order
def _add_unshard_reshard(
compute_actions: list[Optional[_Action]],