diff --git a/torch/distributed/pipelining/_schedule_visualizer.py b/torch/distributed/pipelining/_schedule_visualizer.py
index ccda7177e889..b39a806fa776 100644
--- a/torch/distributed/pipelining/_schedule_visualizer.py
+++ b/torch/distributed/pipelining/_schedule_visualizer.py
@@ -56,14 +56,14 @@ def get_schedule_ops(
             num_stages_per_rank = 1
         assert num_stages_per_rank == 1
         stages = mock_pipeline_stage
-        stages.num_stages = num_stages_per_rank
+        stages.num_stages = num_stages_per_rank * pp_degree
     elif issubclass(schedule_class, PipelineScheduleMulti):
         if num_stages_per_rank is None:
            num_stages_per_rank = 2
         assert num_stages_per_rank >= 2
         stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)]
         for stage in stages:
-            stage.num_stages = num_stages_per_rank
+            stage.num_stages = num_stages_per_rank * pp_degree
     else:
         raise ValueError(f"Invalid schedule: {schedule_class}")
 
diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py
index c3b316577744..a77d58b2a4f8 100644
--- a/torch/distributed/pipelining/schedules.py
+++ b/torch/distributed/pipelining/schedules.py
@@ -485,6 +485,10 @@ class PipelineScheduleSingle(_PipelineSchedule):
 or equal to the number of stages ({self._num_stages})."
             )
 
+        self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = (
+            self._get_pipeline_order()
+        )
+
     def _initialize_stage(self, args, kwargs):
         self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
         if self._has_backward:
@@ -524,6 +528,24 @@ or equal to the number of stages ({self._num_stages})."
         else:
             return None
 
+    def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
+        """
+        Returns the pipeline execution order as a schedule IR.
+
+        The returned IR is a dictionary mapping rank IDs to lists of actions.
+        Each action is either an _Action object representing computation to perform,
+        or None representing a deliberate idle step.
+
+        The None values are used to represent pipeline bubbles where a rank
+        must wait for dependencies from other ranks before proceeding. However,
+        during execution with the _PipelineScheduleRuntime, these Nones are
+        skipped since the relevant communication (send/recv) will be scheduled and waited on.
+
+        Returns:
+            A dictionary mapping rank -> list of actions
+        """
+        return None
+
 
 class _ScheduleForwardOnly(PipelineScheduleSingle):
     """
@@ -666,6 +688,38 @@ class ScheduleGPipe(PipelineScheduleSingle):
         for work in bwd_sends_to_wait:
             _wait_batch_p2p(work)
 
+    def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
+        """
+        Returns the pipeline order for the GPipe schedule.
+
+        See base method in PipelineScheduleSingle for details on the schedule IR format.
+        """
+        pipeline_order = {}
+        pp_group_size = self._num_stages
+
+        for rank in range(pp_group_size):
+            actions: list[Optional[_Action]] = []
+
+            # 1. Initial delay based on rank position
+            warmup_delay = rank
+            actions.extend([None] * warmup_delay)
+
+            # 2. Forward passes for all microbatches
+            for mb_idx in range(self._n_microbatches):
+                actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx))
+
+            # 3. Wait period before backward passes can begin
+            backward_delay = 3 * (pp_group_size - 1 - rank)
+            actions.extend([None] * backward_delay)
+
+            # 4. Backward passes for all microbatches
+            for mb_idx in range(self._n_microbatches):
+                actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx))
+
+            pipeline_order[rank] = actions
+
+        return pipeline_order
+
 
 class Schedule1F1B(PipelineScheduleSingle):
     """
@@ -813,6 +867,74 @@
         # Return losses if there is a container passed in
         self._update_losses(self._stage, losses)
 
+    def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
+        """
+        Returns the pipeline order for the 1F1B schedule.
+
+        See base method in PipelineScheduleSingle for details on the schedule IR format.
+        """
+        pipeline_order = {}
+        pp_group_size = self._num_stages
+
+        for rank in range(pp_group_size):
+            actions: list[Optional[_Action]] = []
+
+            # 1. Warmup phase: initial delay based on rank
+            actions.extend([None] * rank)
+
+            # 2. Initial forward passes before the 1F1B phase
+            num_forward = (pp_group_size - 1) - rank
+            # Last microbatch forwarded during warmup (-1 when there is no warmup)
+            forward_mb = num_forward - 1
+            for i in range(num_forward):
+                actions.append(_Action(rank, _ComputationType.FORWARD, i))
+
+            # 3. Wait for backward to be ready
+            wait_for_1f1b = max(0, 2 * (pp_group_size - 1 - rank))
+            actions.extend([None] * wait_for_1f1b)
+
+            # 4. 1F1B steady state phase
+            backward_mb = 0
+            remaining_forward = self._n_microbatches - num_forward
+
+            while remaining_forward > 0:
+                # One forward
+                forward_mb += 1
+                actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb))
+                remaining_forward -= 1
+
+                # One backward
+                actions.append(
+                    _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
+                )
+                backward_mb += 1
+
+            # 5. Cooldown phase: remaining backward passes
+            remaining_backward = self._n_microbatches - backward_mb
+
+            while remaining_backward > 0:
+                # Add None and backward actions in an alternating pattern,
+                # based on the distance from the last stage
+                if (pp_group_size - rank) > 0:
+                    actions.append(None)
+                    # Only schedule a backward if there are still backward passes to do
+                    if remaining_backward > 0:
+                        actions.append(
+                            _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
+                        )
+                        backward_mb += 1
+                        remaining_backward -= 1
+                else:
+                    # If we're at the last stage, just add backward actions without None
+                    actions.append(
+                        _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
+                    )
+                    backward_mb += 1
+                    remaining_backward -= 1
+
+            pipeline_order[rank] = actions
+        return pipeline_order
+
 
 def _add_unshard_reshard(
     compute_actions: list[Optional[_Action]],
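
For reference, a minimal standalone sketch (not part of the patch) of the schedule IR that the new _get_pipeline_order() methods return, using the GPipe rule above for a 2-stage, 2-microbatch run. It substitutes plain tuples for torch's _Action so it runs without importing torch; the helper name gpipe_order is illustrative only.

from typing import Optional

def gpipe_order(pp_group_size: int, n_microbatches: int) -> dict[int, list[Optional[tuple]]]:
    # Mirrors ScheduleGPipe._get_pipeline_order: warmup bubble, all forwards,
    # a wait period, then all backwards, per rank.
    order: dict[int, list[Optional[tuple]]] = {}
    for rank in range(pp_group_size):
        actions: list[Optional[tuple]] = []
        actions.extend([None] * rank)  # 1. initial delay based on rank position
        for mb in range(n_microbatches):  # 2. forward passes for all microbatches
            actions.append((rank, "F", mb))
        actions.extend([None] * (3 * (pp_group_size - 1 - rank)))  # 3. wait before backwards
        for mb in range(n_microbatches):  # 4. backward passes for all microbatches
            actions.append((rank, "B", mb))
        order[rank] = actions
    return order

for rank, actions in gpipe_order(pp_group_size=2, n_microbatches=2).items():
    print(f"rank {rank}:", [a if a is None else f"{a[1]}{a[2]}" for a in actions])
# rank 0: ['F0', 'F1', None, None, None, 'B0', 'B1']
# rank 1: [None, 'F0', 'F1', 'B0', 'B1']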