Revert "support zb1p and zb2p algorithms (#130752)"

This reverts commit 8fe5b93667b60e37c12d288659a25cbd5ae53c79.

Reverted https://github.com/pytorch/pytorch/pull/130752 on behalf of https://github.com/atalman due to Broke Periodic CI: distributed/pipelining/test_composability.py::ComposabilityTest::test_manual_with_data_parallel_dp_type_DDP_ScheduleClass4 [GH job link](https://github.com/pytorch/pytorch/actions/runs/10131472868/job/28014900187) [HUD commit link](8fe5b93667) ([comment](https://github.com/pytorch/pytorch/pull/130752#issuecomment-2255819078))
PyTorch MergeBot
2024-07-29 12:39:59 +00:00
parent 9d497887b8
commit eb9409511e
5 changed files with 67 additions and 225 deletions
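
For context, the change being reverted had replaced the boolean `enable_zero_bubble` flag on `ScheduleFlexibleInterleaved1F1B` with a `ZeroBubbleAlgorithm` enum (ZB1P, ZB2P, and a not-yet-supported ZBV). A minimal before/after sketch of the two call shapes, where `stages`, `chunks`, and `loss_fn` are hypothetical placeholders for objects built elsewhere:

```python
from torch.distributed.pipelining import ScheduleFlexibleInterleaved1F1B

# Reverted API (PR #130752): pick a specific zero-bubble schedule via the enum.
# schedule = ScheduleFlexibleInterleaved1F1B(
#     stages, chunks, loss_fn=loss_fn,
#     zero_bubble_algorithm=ZeroBubbleAlgorithm.ZB2P,
# )

# Restored API after this revert: a single boolean that selects the ZB1P schedule.
schedule = ScheduleFlexibleInterleaved1F1B(
    stages, chunks, loss_fn=loss_fn, enable_zero_bubble=True
)
```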

View File

@@ -489,8 +489,6 @@ Pipeline Schedules
 .. autoclass:: ScheduleLoopedBFS
 
-.. autoclass:: ZeroBubbleAlgorithm
-
 .. autoclass:: PipelineScheduleSingle
   :members:

View File

@@ -7,7 +7,6 @@ from torch.distributed.pipelining import (
     ScheduleFlexibleInterleaved1F1B,
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
-    ZeroBubbleAlgorithm,
 )
 from torch.distributed.pipelining.schedules import (
     _format_pipeline_order,
@@ -128,21 +127,14 @@ class TestSchedulePlan(TestCase):
         warmup_ops = warmups_ops_last_stage + 2 * (group_size - 1)
         warmup_ops = min(warmup_ops, num_microbatches * num_local_stages)
-        zero_bubble_algorithms = [
-            None,
-            ZeroBubbleAlgorithm.ZB1P,
-            ZeroBubbleAlgorithm.ZB2P,
-        ]
-        for i in range(len(zero_bubble_algorithms)):
+        for i in range(2):
             num_stages = num_local_stages * group_size
             stages = [
                 MockPipelineStage(group_size=group_size, num_stages=num_stages)
                 for i in range(num_local_stages)
             ]
             schedule = ScheduleClass(
-                stages,
-                num_microbatches,
-                zero_bubble_algorithm=zero_bubble_algorithms[i],
+                stages, num_microbatches, enable_zero_bubble=(i == 0)
             )
             formatted_pipeline_order = _format_pipeline_order(
                 schedule.pipeline_order
@@ -152,7 +144,7 @@
                 schedule.pipeline_order,
                 num_microbatches,
                 num_stages,
-                enable_zero_bubble=(zero_bubble_algorithms[i] is not None),
+                enable_zero_bubble=(i == 0),
             )

View File

@@ -19,7 +19,6 @@ from torch.distributed.pipelining import (
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
-    ZeroBubbleAlgorithm,
 )
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -556,10 +555,7 @@ class ScheduleTest(MultiProcContinousTest):
         # Attach to a schedule
         schedule = ScheduleClass(
-            stages,
-            chunks,
-            loss_fn=full_loss_fn,
-            zero_bubble_algorithm=ZeroBubbleAlgorithm.ZB2P,
+            stages, chunks, loss_fn=full_loss_fn, enable_zero_bubble=True
         )
         for _ in range(2):

View File

@@ -6,7 +6,6 @@ from .schedules import (
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
-    ZeroBubbleAlgorithm,
 )
 from .stage import build_stage, PipelineStage
@@ -23,5 +22,4 @@ __all__ = [
     "ScheduleGPipe",
     "ScheduleInterleaved1F1B",
     "ScheduleLoopedBFS",
-    "ZeroBubbleAlgorithm",
 ]

View File

@@ -4,7 +4,6 @@
 import csv
 import itertools
 import logging
-import math
 import re
 from abc import ABC, abstractmethod
 from collections import defaultdict
@@ -65,20 +64,6 @@ W = _ComputationType.WEIGHT
 _action_regex = re.compile(r"(\d+)([F,B,W])(\d*)")
-
-
-class ZeroBubbleAlgorithm(Enum):
-    ZB1P = 1
-    ZB2P = 2
-    ZBV = 3
-
-    def __str__(self):
-        str_map = {
-            ZeroBubbleAlgorithm.ZB1P: "ZB1P",
-            ZeroBubbleAlgorithm.ZB2P: "ZB2P",
-            ZeroBubbleAlgorithm.ZBV: "ZBV",
-        }
-        return str_map[self]
 
 
 class _Action(NamedTuple):
     stage_index: int
     computation_type: _ComputationType
@@ -1199,10 +1184,14 @@ def _get_1f1b_rank_ops(
     rank,
     forward_stage_index,
     backward_stage_index,
+    num_1f1b_microbatches=0,
+    enable_zero_bubble=False,
 ):
     # All stages start with handling microbatch 0
     fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
     bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
+    weight_stage_mb_index: Dict[int, int] = defaultdict(int)
     # Store the list of operations used for that rank
     rank_ops: List[Optional[_Action]] = []
     # Pre-padding, rank starts with no-ops based on the warmup.
@@ -1219,8 +1208,14 @@
         n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
     ) - (warmup_ops + rank)
+    if enable_zero_bubble:
+        post_warmup_ops = pp_group_size - rank - 1
+
     total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
+    backward_op_ids = []
+    weight_op_count = 0
+
     for op in range(total_ops):
         # Warmup phase
         if op < warmup_ops:
@@ -1251,11 +1246,28 @@
             rank_ops.append(
                 _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
             )
+            backward_op_ids.append(op)
+
+            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
+                weight_stage_index = backward_stage_index(
+                    backward_op_ids[weight_op_count]
+                )
+                weight_stage_mb_index[weight_stage_index] = (
+                    weight_mb_index := weight_stage_mb_index[weight_stage_index]
+                ) + 1
+                rank_ops.append(
+                    _Action(
+                        weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
+                    )
+                )
+                weight_op_count += 1
         # Cooldown phase
         else:
             # During cooldown phase, we need steps to align with 1f1b happening in other ranks
             # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
-            rank_ops.append(None)
+            if not enable_zero_bubble:
+                rank_ops.append(None)
             bwd_stage_index = backward_stage_index(op)
             bwd_stage_mb_index[bwd_stage_index] = (
                 bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
@@ -1263,6 +1275,32 @@
             rank_ops.append(
                 _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
             )
+            backward_op_ids.append(op)
+
+            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
+                weight_stage_index = backward_stage_index(
+                    backward_op_ids[weight_op_count]
+                )
+                weight_stage_mb_index[weight_stage_index] = (
+                    weight_mb_index := weight_stage_mb_index[weight_stage_index]
+                ) + 1
+                rank_ops.append(
+                    _Action(
+                        weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
+                    )
+                )
+                weight_op_count += 1
+
+    while enable_zero_bubble and weight_op_count < len(backward_op_ids):
+        weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
+        weight_stage_mb_index[weight_stage_index] = (
+            weight_mb_index := weight_stage_mb_index[weight_stage_index]
+        ) + 1
+        rank_ops.append(
+            _Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
+        )
+        weight_op_count += 1
+
     return rank_ops
@@ -1381,8 +1419,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
     1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
     2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
-    When zero_bubble_algorithm is passed in, we will use the corresponding schedule in
-    https://openreview.net/pdf?id=tuzTN0eIO5
+    When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5
     """
 
     def __init__(
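
As an aside, to make the rounds arithmetic in the docstring above concrete, here is its example 1 worked through with the formulas from `__init__` in the hunks below (a sketch, not part of the diff):

```python
pp_group_size, n_microbatches = 4, 10
number_of_rounds = max(1, n_microbatches // pp_group_size)   # 10 // 4 -> 2
microbatches_per_round = n_microbatches // number_of_rounds  # 10 // 2 -> 5
assert n_microbatches % number_of_rounds == 0                # 10 % 2 == 0, valid
```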
@@ -1393,7 +1430,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
         args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
         kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
         output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
-        zero_bubble_algorithm: Optional[ZeroBubbleAlgorithm] = None,
+        enable_zero_bubble: bool = False,
     ):
         self.pp_group_size = stages[0].group_size
         super().__init__(
@@ -1403,16 +1440,13 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
             args_chunk_spec=args_chunk_spec,
             kwargs_chunk_spec=kwargs_chunk_spec,
             output_merge_spec=output_merge_spec,
-            use_full_backward=not zero_bubble_algorithm,
+            use_full_backward=not enable_zero_bubble,
         )
         self.n_local_stages = len(stages)
         self.rank = stages[0].group_rank
         self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
         self.microbatches_per_round = n_microbatches // self.number_of_rounds
-        self.zero_bubble_algorithm = zero_bubble_algorithm
-        if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZBV:
-            raise ValueError("ZBV is not yet supported")
+        self.enable_zero_bubble = enable_zero_bubble
         if n_microbatches % self.number_of_rounds != 0:
             raise ValueError(
                 "Flexible Interleaved 1F1B requires the number of microbatches to be a "
@@ -1441,9 +1475,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
             self.n_local_stages - 1
         ) * self.microbatches_per_round
         # Increment warmup operations by 2 for each hop away from the last stage
-        multiply_factor = 1
-        if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
-            multiply_factor = 2
+        multiply_factor = 1 if self.enable_zero_bubble else 2
         warmup_ops = warmups_ops_last_stage + multiply_factor * (
             (self.pp_group_size - 1) - rank
         )
@@ -1485,12 +1517,10 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
             )
             return (local_index * self.pp_group_size) + rank
 
-        if self.zero_bubble_algorithm:
+        if self.enable_zero_bubble:
             num_1f1b_microbatches = rank
-            if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
-                num_1f1b_microbatches = 2 * rank
-            return self._get_1f1b_rank_ops_zero_bubble(
+            return _get_1f1b_rank_ops(
                 self.n_local_stages,
                 self.pp_group_size,
                 warmup_ops,
@@ -1500,8 +1530,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
                 forward_stage_index,
                 backward_stage_index,
                 num_1f1b_microbatches,
-                zero_bubble_algorithm=self.zero_bubble_algorithm,
-                forward_local_stage_one_index=self.pp_group_size + rank,
+                enable_zero_bubble=True,
             )
 
         return _get_1f1b_rank_ops(
@@ -1517,7 +1546,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
     def _add_bubbles_to_actions(self, num_stages_global):
         actions = self.pipeline_order
-        if not self.zero_bubble_algorithm:
+        if not self.enable_zero_bubble:
             return actions
 
         def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
@@ -1580,174 +1609,3 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
                 f"Non zero bubbles added: {total_bubbles_added=} {bubbles_added=}" # noqa: G004
             )
         return result
-
-    def _get_1f1b_rank_ops_zero_bubble(
-        self,
-        n_local_stages,
-        pp_group_size,
-        warmup_ops,
-        fwd_bwd_ops,
-        cooldown_ops,
-        rank,
-        forward_stage_index,
-        backward_stage_index,
-        num_1f1b_microbatches,
-        zero_bubble_algorithm,
-        forward_local_stage_one_index,
-    ):
-        # All stages start with handling microbatch 0
-        fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
-        bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
-        weight_stage_mb_index: Dict[int, int] = defaultdict(int)
-        # Store the list of operations used for that rank
-        rank_ops: List[Optional[_Action]] = []
-        # Pre-padding, rank starts with no-ops based on the warmup.
-        for _ in range(rank):
-            rank_ops.append(None)
-        # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
-        # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
-        # Formula:
-        # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
-        # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
-        # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
-        # warmup_ops = calculated above
-        post_warmup_ops = (
-            n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
-        ) - (warmup_ops + rank)
-        if zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB1P:
-            post_warmup_ops = pp_group_size - rank - 1
-        elif zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
-            post_warmup_ops = 0
-        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
-        prefill_steps_1b1w = 0
-        if zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
-            prefill_steps_1b1w = max(0, math.ceil((pp_group_size - 4) / 2) - rank)
-        backward_op_ids = []
-        weight_op_count = 0
-        forward_op_id = 0
-        backward_op_id = warmup_ops
-        has_backfilled = False
-        for op in range(total_ops - prefill_steps_1b1w):
-            # Warmup phase
-            if op < warmup_ops:
-                fwd_stage_index = forward_stage_index(forward_op_id)
-                # This will assign the current microbatch index and update it as well
-                fwd_stage_mb_index[fwd_stage_index] = (
-                    mb_index := fwd_stage_mb_index[fwd_stage_index]
-                ) + 1
-                rank_ops.append(
-                    _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
-                )
-                if forward_op_id == warmup_ops - 1:
-                    # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
-                    rank_ops.extend([None] * post_warmup_ops)
-                forward_op_id += 1
-            # 1F1B Phase (forward and backward)
-            elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
-                fwd_stage_index = forward_stage_index(forward_op_id)
-                if (
-                    fwd_stage_index == forward_local_stage_one_index
-                    and not has_backfilled
-                ):
-                    has_backfilled = True
-                    for _ in range(prefill_steps_1b1w):
-                        bwd_stage_index = backward_stage_index(backward_op_id)
-                        bwd_stage_mb_index[bwd_stage_index] = (
-                            bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
-                        ) + 1
-                        rank_ops.append(
-                            _Action(
-                                bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index
-                            )
-                        )
-                        backward_op_ids.append(backward_op_id)
-                        backward_op_id += 1
-                        weight_stage_index = backward_stage_index(
-                            backward_op_ids[weight_op_count]
-                        )
-                        weight_stage_mb_index[weight_stage_index] = (
-                            weight_mb_index := weight_stage_mb_index[weight_stage_index]
-                        ) + 1
-                        rank_ops.append(
-                            _Action(
-                                weight_stage_index,
-                                _ComputationType.WEIGHT,
-                                weight_mb_index,
-                            )
-                        )
-                        weight_op_count += 1
-                fwd_stage_mb_index[fwd_stage_index] = (
-                    fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
-                ) + 1
-                rank_ops.append(
-                    _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
-                )
-                bwd_stage_index = backward_stage_index(backward_op_id)
-                bwd_stage_mb_index[bwd_stage_index] = (
-                    bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
-                ) + 1
-                rank_ops.append(
-                    _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
-                )
-                backward_op_ids.append(backward_op_id)
-                forward_op_id += 1
-                backward_op_id += 1
-                if op - warmup_ops >= num_1f1b_microbatches:
-                    weight_stage_index = backward_stage_index(
-                        backward_op_ids[weight_op_count]
-                    )
-                    weight_stage_mb_index[weight_stage_index] = (
-                        weight_mb_index := weight_stage_mb_index[weight_stage_index]
-                    ) + 1
-                    rank_ops.append(
-                        _Action(
-                            weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
-                        )
-                    )
-                    weight_op_count += 1
-            # Cooldown phase
-            else:
-                bwd_stage_index = backward_stage_index(backward_op_id)
-                bwd_stage_mb_index[bwd_stage_index] = (
-                    bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
-                ) + 1
-                rank_ops.append(
-                    _Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
-                )
-                backward_op_ids.append(backward_op_id)
-                backward_op_id += 1
-                if zero_bubble_algorithm and op - warmup_ops >= num_1f1b_microbatches:
-                    weight_stage_index = backward_stage_index(
-                        backward_op_ids[weight_op_count]
-                    )
-                    weight_stage_mb_index[weight_stage_index] = (
-                        weight_mb_index := weight_stage_mb_index[weight_stage_index]
-                    ) + 1
-                    rank_ops.append(
-                        _Action(
-                            weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
-                        )
-                    )
-                    weight_op_count += 1
-        while weight_op_count < len(backward_op_ids):
-            weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
-            weight_stage_mb_index[weight_stage_index] = (
-                weight_mb_index := weight_stage_mb_index[weight_stage_index]
-            ) + 1
-            rank_ops.append(
-                _Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
-            )
-            weight_op_count += 1
-        return rank_ops
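
Both the removed helper above and the restored `_get_1f1b_rank_ops` path rely on the same zero-bubble bookkeeping: each full backward is split into a backward (B) that unblocks the next stage and a deferred weight update (W), and any W actions still pending after the last B are drained at the end. A hypothetical, self-contained toy rendering of that pattern (the function name and tuple encoding are illustrative, not the module's API):

```python
def zero_bubble_tail(backward_ops, num_1f1b_microbatches):
    """Toy sketch: defer each weight update (W) by a fixed distance behind
    its backward (B), then drain the leftovers, mirroring the bookkeeping
    in _get_1f1b_rank_ops when enable_zero_bubble=True."""
    rank_ops, pending, weight_op_count = [], [], 0
    for i, (stage, mb) in enumerate(backward_ops):
        rank_ops.append(("B", stage, mb))   # backward w.r.t. inputs only
        pending.append((stage, mb))
        if i >= num_1f1b_microbatches:      # emit the oldest pending W
            rank_ops.append(("W",) + pending[weight_op_count])
            weight_op_count += 1
    while weight_op_count < len(pending):   # drain remaining W actions
        rank_ops.append(("W",) + pending[weight_op_count])
        weight_op_count += 1
    return rank_ops

# e.g. three backwards on stage 0 with a deferral distance of 1:
# zero_bubble_tail([(0, 0), (0, 1), (0, 2)], 1)
# -> [B mb0, B mb1, W mb0, B mb2, W mb1, W mb2]
```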