mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "support zb1p and zb2p algorithms (#130752)"
This reverts commit 8fe5b93667b60e37c12d288659a25cbd5ae53c79.
Reverted https://github.com/pytorch/pytorch/pull/130752 on behalf of https://github.com/atalman due to Broke Periodic CI: distributed/pipelining/test_composability.py::ComposabilityTest::test_manual_with_data_parallel_dp_type_DDP_ScheduleClass4 [GH job link](https://github.com/pytorch/pytorch/actions/runs/10131472868/job/28014900187) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/8fe5b93667b60e37c12d288659a25cbd5ae53c79) ([comment](https://github.com/pytorch/pytorch/pull/130752#issuecomment-2255819078))
This commit is contained in:
@ -489,8 +489,6 @@ Pipeline Schedules
|
||||
|
||||
.. autoclass:: ScheduleLoopedBFS
|
||||
|
||||
.. autoclass:: ZeroBubbleAlgorithm
|
||||
|
||||
.. autoclass:: PipelineScheduleSingle
|
||||
:members:
|
||||
|
||||
|
@ -7,7 +7,6 @@ from torch.distributed.pipelining import (
|
||||
ScheduleFlexibleInterleaved1F1B,
|
||||
ScheduleInterleaved1F1B,
|
||||
ScheduleLoopedBFS,
|
||||
ZeroBubbleAlgorithm,
|
||||
)
|
||||
from torch.distributed.pipelining.schedules import (
|
||||
_format_pipeline_order,
|
||||
@ -128,21 +127,14 @@ class TestSchedulePlan(TestCase):
|
||||
warmup_ops = warmups_ops_last_stage + 2 * (group_size - 1)
|
||||
warmup_ops = min(warmup_ops, num_microbatches * num_local_stages)
|
||||
|
||||
zero_bubble_algorithms = [
|
||||
None,
|
||||
ZeroBubbleAlgorithm.ZB1P,
|
||||
ZeroBubbleAlgorithm.ZB2P,
|
||||
]
|
||||
for i in range(len(zero_bubble_algorithms)):
|
||||
for i in range(2):
|
||||
num_stages = num_local_stages * group_size
|
||||
stages = [
|
||||
MockPipelineStage(group_size=group_size, num_stages=num_stages)
|
||||
for i in range(num_local_stages)
|
||||
]
|
||||
schedule = ScheduleClass(
|
||||
stages,
|
||||
num_microbatches,
|
||||
zero_bubble_algorithm=zero_bubble_algorithms[i],
|
||||
stages, num_microbatches, enable_zero_bubble=(i == 0)
|
||||
)
|
||||
formatted_pipeline_order = _format_pipeline_order(
|
||||
schedule.pipeline_order
|
||||
@ -152,7 +144,7 @@ class TestSchedulePlan(TestCase):
|
||||
schedule.pipeline_order,
|
||||
num_microbatches,
|
||||
num_stages,
|
||||
enable_zero_bubble=(zero_bubble_algorithms[i] is not None),
|
||||
enable_zero_bubble=(i == 0),
|
||||
)
|
||||
|
||||
|
||||
|
@ -19,7 +19,6 @@ from torch.distributed.pipelining import (
|
||||
ScheduleGPipe,
|
||||
ScheduleInterleaved1F1B,
|
||||
ScheduleLoopedBFS,
|
||||
ZeroBubbleAlgorithm,
|
||||
)
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_distributed import (
|
||||
@ -556,10 +555,7 @@ class ScheduleTest(MultiProcContinousTest):
|
||||
|
||||
# Attach to a schedule
|
||||
schedule = ScheduleClass(
|
||||
stages,
|
||||
chunks,
|
||||
loss_fn=full_loss_fn,
|
||||
zero_bubble_algorithm=ZeroBubbleAlgorithm.ZB2P,
|
||||
stages, chunks, loss_fn=full_loss_fn, enable_zero_bubble=True
|
||||
)
|
||||
|
||||
for _ in range(2):
|
||||
|
@ -6,7 +6,6 @@ from .schedules import (
|
||||
ScheduleGPipe,
|
||||
ScheduleInterleaved1F1B,
|
||||
ScheduleLoopedBFS,
|
||||
ZeroBubbleAlgorithm,
|
||||
)
|
||||
from .stage import build_stage, PipelineStage
|
||||
|
||||
@ -23,5 +22,4 @@ __all__ = [
|
||||
"ScheduleGPipe",
|
||||
"ScheduleInterleaved1F1B",
|
||||
"ScheduleLoopedBFS",
|
||||
"ZeroBubbleAlgorithm",
|
||||
]
|
||||
|
@ -4,7 +4,6 @@
|
||||
import csv
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
@ -65,20 +64,6 @@ W = _ComputationType.WEIGHT
|
||||
_action_regex = re.compile(r"(\d+)([F,B,W])(\d*)")
|
||||
|
||||
|
||||
class ZeroBubbleAlgorithm(Enum):
|
||||
ZB1P = 1
|
||||
ZB2P = 2
|
||||
ZBV = 3
|
||||
|
||||
def __str__(self):
|
||||
str_map = {
|
||||
ZeroBubbleAlgorithm.ZB1P: "ZB1P",
|
||||
ZeroBubbleAlgorithm.ZB2P: "ZB2P",
|
||||
ZeroBubbleAlgorithm.ZBV: "ZBV",
|
||||
}
|
||||
return str_map[self]
|
||||
|
||||
|
||||
class _Action(NamedTuple):
|
||||
stage_index: int
|
||||
computation_type: _ComputationType
|
||||
@ -1199,10 +1184,14 @@ def _get_1f1b_rank_ops(
|
||||
rank,
|
||||
forward_stage_index,
|
||||
backward_stage_index,
|
||||
num_1f1b_microbatches=0,
|
||||
enable_zero_bubble=False,
|
||||
):
|
||||
# All stages start with handling microbatch 0
|
||||
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
|
||||
# Store the list of operations used for that rank
|
||||
rank_ops: List[Optional[_Action]] = []
|
||||
# Pre-padding, rank starts with no-ops based on the warmup.
|
||||
@ -1219,8 +1208,14 @@ def _get_1f1b_rank_ops(
|
||||
n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
|
||||
) - (warmup_ops + rank)
|
||||
|
||||
if enable_zero_bubble:
|
||||
post_warmup_ops = pp_group_size - rank - 1
|
||||
|
||||
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
|
||||
|
||||
backward_op_ids = []
|
||||
weight_op_count = 0
|
||||
|
||||
for op in range(total_ops):
|
||||
# Warmup phase
|
||||
if op < warmup_ops:
|
||||
@ -1251,11 +1246,28 @@ def _get_1f1b_rank_ops(
|
||||
rank_ops.append(
|
||||
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
|
||||
)
|
||||
backward_op_ids.append(op)
|
||||
|
||||
if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
|
||||
weight_stage_index = backward_stage_index(
|
||||
backward_op_ids[weight_op_count]
|
||||
)
|
||||
weight_stage_mb_index[weight_stage_index] = (
|
||||
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(
|
||||
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
|
||||
)
|
||||
)
|
||||
weight_op_count += 1
|
||||
# Cooldown phase
|
||||
else:
|
||||
# During cooldown phase, we need steps to align with 1f1b happening in other ranks
|
||||
# TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
|
||||
rank_ops.append(None)
|
||||
if not enable_zero_bubble:
|
||||
rank_ops.append(None)
|
||||
|
||||
bwd_stage_index = backward_stage_index(op)
|
||||
bwd_stage_mb_index[bwd_stage_index] = (
|
||||
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
|
||||
@ -1263,6 +1275,32 @@ def _get_1f1b_rank_ops(
|
||||
rank_ops.append(
|
||||
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
|
||||
)
|
||||
backward_op_ids.append(op)
|
||||
|
||||
if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
|
||||
weight_stage_index = backward_stage_index(
|
||||
backward_op_ids[weight_op_count]
|
||||
)
|
||||
weight_stage_mb_index[weight_stage_index] = (
|
||||
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(
|
||||
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
|
||||
)
|
||||
)
|
||||
weight_op_count += 1
|
||||
|
||||
while enable_zero_bubble and weight_op_count < len(backward_op_ids):
|
||||
weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
|
||||
weight_stage_mb_index[weight_stage_index] = (
|
||||
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
|
||||
)
|
||||
weight_op_count += 1
|
||||
|
||||
return rank_ops
|
||||
|
||||
|
||||
@ -1381,8 +1419,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
||||
1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
|
||||
2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
|
||||
|
||||
When zero_bubble_algorithm is passed in, we will use the corresponding schedule in
|
||||
https://openreview.net/pdf?id=tuzTN0eIO5
|
||||
When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -1393,7 +1430,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
||||
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
|
||||
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
|
||||
zero_bubble_algorithm: Optional[ZeroBubbleAlgorithm] = None,
|
||||
enable_zero_bubble: bool = False,
|
||||
):
|
||||
self.pp_group_size = stages[0].group_size
|
||||
super().__init__(
|
||||
@ -1403,16 +1440,13 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
||||
args_chunk_spec=args_chunk_spec,
|
||||
kwargs_chunk_spec=kwargs_chunk_spec,
|
||||
output_merge_spec=output_merge_spec,
|
||||
use_full_backward=not zero_bubble_algorithm,
|
||||
use_full_backward=not enable_zero_bubble,
|
||||
)
|
||||
self.n_local_stages = len(stages)
|
||||
self.rank = stages[0].group_rank
|
||||
self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
|
||||
self.microbatches_per_round = n_microbatches // self.number_of_rounds
|
||||
self.zero_bubble_algorithm = zero_bubble_algorithm
|
||||
if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZBV:
|
||||
raise ValueError("ZBV is not yet supported")
|
||||
|
||||
self.enable_zero_bubble = enable_zero_bubble
|
||||
if n_microbatches % self.number_of_rounds != 0:
|
||||
raise ValueError(
|
||||
"Flexible Interleaved 1F1B requires the number of microbatches to be a "
|
||||
@ -1441,9 +1475,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
||||
self.n_local_stages - 1
|
||||
) * self.microbatches_per_round
|
||||
# Increment warmup operations by 2 for each hop away from the last stage
|
||||
multiply_factor = 1
|
||||
if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
|
||||
multiply_factor = 2
|
||||
multiply_factor = 1 if self.enable_zero_bubble else 2
|
||||
warmup_ops = warmups_ops_last_stage + multiply_factor * (
|
||||
(self.pp_group_size - 1) - rank
|
||||
)
|
||||
@ -1485,12 +1517,10 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
||||
)
|
||||
return (local_index * self.pp_group_size) + rank
|
||||
|
||||
if self.zero_bubble_algorithm:
|
||||
if self.enable_zero_bubble:
|
||||
num_1f1b_microbatches = rank
|
||||
if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
|
||||
num_1f1b_microbatches = 2 * rank
|
||||
|
||||
return self._get_1f1b_rank_ops_zero_bubble(
|
||||
return _get_1f1b_rank_ops(
|
||||
self.n_local_stages,
|
||||
self.pp_group_size,
|
||||
warmup_ops,
|
||||
@ -1500,8 +1530,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
||||
forward_stage_index,
|
||||
backward_stage_index,
|
||||
num_1f1b_microbatches,
|
||||
zero_bubble_algorithm=self.zero_bubble_algorithm,
|
||||
forward_local_stage_one_index=self.pp_group_size + rank,
|
||||
enable_zero_bubble=True,
|
||||
)
|
||||
|
||||
return _get_1f1b_rank_ops(
|
||||
@ -1517,7 +1546,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
||||
|
||||
def _add_bubbles_to_actions(self, num_stages_global):
|
||||
actions = self.pipeline_order
|
||||
if not self.zero_bubble_algorithm:
|
||||
if not self.enable_zero_bubble:
|
||||
return actions
|
||||
|
||||
def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
|
||||
@ -1580,174 +1609,3 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
|
||||
f"Non zero bubbles added: {total_bubbles_added=} {bubbles_added=}" # noqa: G004
|
||||
)
|
||||
return result
|
||||
|
||||
def _get_1f1b_rank_ops_zero_bubble(
|
||||
self,
|
||||
n_local_stages,
|
||||
pp_group_size,
|
||||
warmup_ops,
|
||||
fwd_bwd_ops,
|
||||
cooldown_ops,
|
||||
rank,
|
||||
forward_stage_index,
|
||||
backward_stage_index,
|
||||
num_1f1b_microbatches,
|
||||
zero_bubble_algorithm,
|
||||
forward_local_stage_one_index,
|
||||
):
|
||||
# All stages start with handling microbatch 0
|
||||
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
|
||||
# Store the list of operations used for that rank
|
||||
rank_ops: List[Optional[_Action]] = []
|
||||
# Pre-padding, rank starts with no-ops based on the warmup.
|
||||
for _ in range(rank):
|
||||
rank_ops.append(None)
|
||||
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
|
||||
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
|
||||
# Formula:
|
||||
# pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
|
||||
# post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
|
||||
# earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
|
||||
# warmup_ops = calculated above
|
||||
post_warmup_ops = (
|
||||
n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
|
||||
) - (warmup_ops + rank)
|
||||
|
||||
if zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB1P:
|
||||
post_warmup_ops = pp_group_size - rank - 1
|
||||
elif zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
|
||||
post_warmup_ops = 0
|
||||
|
||||
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
|
||||
|
||||
prefill_steps_1b1w = 0
|
||||
if zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
|
||||
prefill_steps_1b1w = max(0, math.ceil((pp_group_size - 4) / 2) - rank)
|
||||
|
||||
backward_op_ids = []
|
||||
weight_op_count = 0
|
||||
forward_op_id = 0
|
||||
backward_op_id = warmup_ops
|
||||
has_backfilled = False
|
||||
|
||||
for op in range(total_ops - prefill_steps_1b1w):
|
||||
# Warmup phase
|
||||
if op < warmup_ops:
|
||||
fwd_stage_index = forward_stage_index(forward_op_id)
|
||||
# This will assign the current microbatch index and update it as well
|
||||
fwd_stage_mb_index[fwd_stage_index] = (
|
||||
mb_index := fwd_stage_mb_index[fwd_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
|
||||
)
|
||||
if forward_op_id == warmup_ops - 1:
|
||||
# This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
|
||||
rank_ops.extend([None] * post_warmup_ops)
|
||||
forward_op_id += 1
|
||||
|
||||
# 1F1B Phase (forward and backward)
|
||||
elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
|
||||
fwd_stage_index = forward_stage_index(forward_op_id)
|
||||
if (
|
||||
fwd_stage_index == forward_local_stage_one_index
|
||||
and not has_backfilled
|
||||
):
|
||||
has_backfilled = True
|
||||
for _ in range(prefill_steps_1b1w):
|
||||
bwd_stage_index = backward_stage_index(backward_op_id)
|
||||
bwd_stage_mb_index[bwd_stage_index] = (
|
||||
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(
|
||||
bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index
|
||||
)
|
||||
)
|
||||
backward_op_ids.append(backward_op_id)
|
||||
backward_op_id += 1
|
||||
weight_stage_index = backward_stage_index(
|
||||
backward_op_ids[weight_op_count]
|
||||
)
|
||||
weight_stage_mb_index[weight_stage_index] = (
|
||||
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(
|
||||
weight_stage_index,
|
||||
_ComputationType.WEIGHT,
|
||||
weight_mb_index,
|
||||
)
|
||||
)
|
||||
weight_op_count += 1
|
||||
|
||||
fwd_stage_mb_index[fwd_stage_index] = (
|
||||
fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
|
||||
)
|
||||
bwd_stage_index = backward_stage_index(backward_op_id)
|
||||
bwd_stage_mb_index[bwd_stage_index] = (
|
||||
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
|
||||
)
|
||||
backward_op_ids.append(backward_op_id)
|
||||
forward_op_id += 1
|
||||
backward_op_id += 1
|
||||
|
||||
if op - warmup_ops >= num_1f1b_microbatches:
|
||||
weight_stage_index = backward_stage_index(
|
||||
backward_op_ids[weight_op_count]
|
||||
)
|
||||
weight_stage_mb_index[weight_stage_index] = (
|
||||
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(
|
||||
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
|
||||
)
|
||||
)
|
||||
weight_op_count += 1
|
||||
# Cooldown phase
|
||||
else:
|
||||
bwd_stage_index = backward_stage_index(backward_op_id)
|
||||
bwd_stage_mb_index[bwd_stage_index] = (
|
||||
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
|
||||
)
|
||||
backward_op_ids.append(backward_op_id)
|
||||
backward_op_id += 1
|
||||
|
||||
if zero_bubble_algorithm and op - warmup_ops >= num_1f1b_microbatches:
|
||||
weight_stage_index = backward_stage_index(
|
||||
backward_op_ids[weight_op_count]
|
||||
)
|
||||
weight_stage_mb_index[weight_stage_index] = (
|
||||
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(
|
||||
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
|
||||
)
|
||||
)
|
||||
weight_op_count += 1
|
||||
|
||||
while weight_op_count < len(backward_op_ids):
|
||||
weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
|
||||
weight_stage_mb_index[weight_stage_index] = (
|
||||
weight_mb_index := weight_stage_mb_index[weight_stage_index]
|
||||
) + 1
|
||||
rank_ops.append(
|
||||
_Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
|
||||
)
|
||||
weight_op_count += 1
|
||||
|
||||
return rank_ops
|
||||
|
Reference in New Issue
Block a user