mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add get_pipeline_order() for Gpipe and 1F1B (#155935)
The [schedule visualizer](https://github.com/pytorch/pytorch/blob/main/torch/distributed/pipelining/_schedule_visualizer.py) relies on `self.pipeline_order` being populated, and `_PipelineScheduleRuntime` also depends on it to run the IR. The single-stage schedules did not implement this, so this PR adds it. It also fixes a bug in the schedule visualizer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155935 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
5435e75399
commit
38e1e5d54c
@ -56,14 +56,14 @@ def get_schedule_ops(
|
||||
num_stages_per_rank = 1
|
||||
assert num_stages_per_rank == 1
|
||||
stages = mock_pipeline_stage
|
||||
stages.num_stages = num_stages_per_rank
|
||||
stages.num_stages = num_stages_per_rank * pp_degree
|
||||
elif issubclass(schedule_class, PipelineScheduleMulti):
|
||||
if num_stages_per_rank is None:
|
||||
num_stages_per_rank = 2
|
||||
assert num_stages_per_rank >= 2
|
||||
stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)]
|
||||
for stage in stages:
|
||||
stage.num_stages = num_stages_per_rank
|
||||
stage.num_stages = num_stages_per_rank * pp_degree
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid schedule: {schedule_class}")
|
||||
|
@ -485,6 +485,10 @@ class PipelineScheduleSingle(_PipelineSchedule):
|
||||
or equal to the number of stages ({self._num_stages})."
|
||||
)
|
||||
|
||||
self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = (
|
||||
self._get_pipeline_order()
|
||||
)
|
||||
|
||||
def _initialize_stage(self, args, kwargs):
|
||||
self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
|
||||
if self._has_backward:
|
||||
@ -524,6 +528,24 @@ or equal to the number of stages ({self._num_stages})."
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
    """
    Produce this schedule's execution order expressed as schedule IR.

    The IR is a dictionary keyed by rank ID. Each value is the ordered list
    of steps for that rank, where a step is either:

    * an ``_Action`` describing a computation to run, or
    * ``None``, marking an intentional idle slot (a pipeline bubble) in
      which the rank is waiting on dependencies from other ranks.

    When the IR is executed by ``_PipelineScheduleRuntime`` the ``None``
    entries are skipped, because the corresponding send/recv communication
    is scheduled and waited on instead.

    This base implementation has no IR to offer and returns ``None``;
    schedules that can describe themselves as IR override this method.

    Returns:
        A dictionary mapping rank -> list of actions, or ``None`` when the
        schedule does not provide an IR.
    """
    return None
|
||||
|
||||
|
||||
class _ScheduleForwardOnly(PipelineScheduleSingle):
|
||||
"""
|
||||
@ -666,6 +688,38 @@ class ScheduleGPipe(PipelineScheduleSingle):
|
||||
for work in bwd_sends_to_wait:
|
||||
_wait_batch_p2p(work)
|
||||
|
||||
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
    """
    Build the schedule IR for the GPipe schedule.

    For every rank the action list is laid out as four consecutive segments:
    ``rank`` idle slots, one FORWARD per microbatch, an idle gap, then one
    FULL_BACKWARD per microbatch.

    See base method in PipelineScheduleSingle for details on the schedule IR format.

    Returns:
        A dictionary mapping rank -> list of actions (``None`` = idle slot).
    """
    # One entry is generated per stage, i.e. this treats each stage as its
    # own rank.
    n_ranks = self._num_stages
    n_microbatches = self._n_microbatches
    schedule: dict[int, list[Optional[_Action]]] = {}

    for stage_rank in range(n_ranks):
        # Later ranks start later: one idle slot per preceding rank.
        leading_bubble: list[Optional[_Action]] = [None] * stage_rank

        forward_steps = [
            _Action(stage_rank, _ComputationType.FORWARD, mb)
            for mb in range(n_microbatches)
        ]

        # Idle gap between the forward and backward segments: three slots
        # for every rank that sits after this one in the pipeline, so the
        # last rank has no gap at all.
        ranks_after = n_ranks - 1 - stage_rank
        middle_bubble: list[Optional[_Action]] = [None] * (3 * ranks_after)

        backward_steps = [
            _Action(stage_rank, _ComputationType.FULL_BACKWARD, mb)
            for mb in range(n_microbatches)
        ]

        schedule[stage_rank] = [
            *leading_bubble,
            *forward_steps,
            *middle_bubble,
            *backward_steps,
        ]

    return schedule
|
||||
|
||||
|
||||
class Schedule1F1B(PipelineScheduleSingle):
|
||||
"""
|
||||
@ -813,6 +867,74 @@ class Schedule1F1B(PipelineScheduleSingle):
|
||||
# Return losses if there is a container passed in
|
||||
self._update_losses(self._stage, losses)
|
||||
|
||||
def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
    """
    Returns the pipeline order for 1F1B schedule.

    For every rank the action list is built from five phases:

    1. ``rank`` idle slots.
    2. Warmup: ``pp_group_size - 1 - rank`` FORWARD actions, for
       microbatches ``0 .. num_forward - 1``.
    3. An idle gap of two slots per warmup forward.
    4. Steady state: alternating FORWARD / FULL_BACKWARD until every
       microbatch has been forwarded.
    5. Cooldown: the remaining FULL_BACKWARD actions, each preceded by one
       idle slot.

    Every rank therefore emits exactly one FORWARD and one FULL_BACKWARD for
    each microbatch index in ``range(self._n_microbatches)``, provided the
    number of microbatches is at least the number of warmup forwards.

    See base method in PipelineScheduleSingle for details on the schedule IR format.

    Returns:
        A dictionary mapping rank -> list of actions (``None`` = idle slot).
    """
    pipeline_order: dict[int, list[Optional[_Action]]] = {}
    pp_group_size = self._num_stages
    n_microbatches = self._n_microbatches

    for rank in range(pp_group_size):
        actions: list[Optional[_Action]] = []

        # 1. Warmup phase: initial delay based on rank
        actions.extend([None] * rank)

        # 2. Initial forward passes before 1F1B phase
        num_forward = (pp_group_size - 1) - rank
        actions.extend(
            _Action(rank, _ComputationType.FORWARD, mb_idx)
            for mb_idx in range(num_forward)
        )
        # Index of the NEXT microbatch to forward. Tracking "next" instead
        # of "last issued" matters for the last rank (num_forward == 0):
        # its first steady-state forward must be microbatch 0, not 1.
        forward_mb = num_forward

        # 3. Wait for backward to be ready. num_forward is never negative
        # because rank < pp_group_size, so no clamping is required.
        actions.extend([None] * (2 * num_forward))

        # 4. 1F1B steady state phase: one forward followed by one backward
        # until all microbatches have been forwarded.
        backward_mb = 0
        while forward_mb < n_microbatches:
            actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb))
            forward_mb += 1

            actions.append(
                _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
            )
            backward_mb += 1

        # 5. Cooldown phase: remaining backward passes, each preceded by an
        # idle slot. The last rank has already issued every backward in the
        # steady state, so this loop does not run for it.
        for mb_idx in range(backward_mb, n_microbatches):
            actions.append(None)
            actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx))

        pipeline_order[rank] = actions

    return pipeline_order
|
||||
|
||||
|
||||
def _add_unshard_reshard(
|
||||
compute_actions: list[Optional[_Action]],
|
||||
|
Reference in New Issue
Block a user