support zb1p and zb2p algorithms (#130752)

Previously, we had shown that ZB2P is not truly zero bubble when num_local_stages exceeds 4, so only ZB1P was supported.

We made a few tweaks to ZB2P so that it is truly zero bubble. The algorithm and proof are attached:
[zero_bubble.pdf](https://github.com/user-attachments/files/16238738/zero_bubble.pdf)
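For context, a minimal usage sketch of the new API surface introduced here (the `stages` list and `loss_fn` are assumed to be built as usual with torch.distributed.pipelining; only the enum argument is new in this commit):

from torch.distributed.pipelining import (
    ScheduleFlexibleInterleaved1F1B,
    ZeroBubbleAlgorithm,
)

# zero_bubble_algorithm replaces the old enable_zero_bubble=True flag;
# pass None (the default) to keep the plain flexible interleaved 1F1B schedule.
schedule = ScheduleFlexibleInterleaved1F1B(
    stages,          # assumed: list of PipelineStage objects for this rank
    n_microbatches,  # assumed: satisfies the schedule's divisibility check
    loss_fn=loss_fn, # assumed: per-microbatch loss callable
    zero_bubble_algorithm=ZeroBubbleAlgorithm.ZB2P,  # or ZB1P; ZBV raises ValueError
)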

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130752
Approved by: https://github.com/H-Huang
Haoci Zhang
2024-07-24 17:58:46 +00:00
committed by PyTorch MergeBot
parent 5e6cfb7db5
commit 8fe5b93667
5 changed files with 225 additions and 67 deletions

View File

@@ -489,6 +489,8 @@ Pipeline Schedules
.. autoclass:: ScheduleLoopedBFS
.. autoclass:: ZeroBubbleAlgorithm
.. autoclass:: PipelineScheduleSingle
:members:

View File

@@ -7,6 +7,7 @@ from torch.distributed.pipelining import (
ScheduleFlexibleInterleaved1F1B,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ZeroBubbleAlgorithm,
)
from torch.distributed.pipelining.schedules import (
_format_pipeline_order,
@@ -127,14 +128,21 @@ class TestSchedulePlan(TestCase):
warmup_ops = warmups_ops_last_stage + 2 * (group_size - 1)
warmup_ops = min(warmup_ops, num_microbatches * num_local_stages)
for i in range(2):
zero_bubble_algorithms = [
None,
ZeroBubbleAlgorithm.ZB1P,
ZeroBubbleAlgorithm.ZB2P,
]
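# Cover the legacy path (None) and both supported zero-bubble algorithms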
for i in range(len(zero_bubble_algorithms)):
num_stages = num_local_stages * group_size
stages = [
MockPipelineStage(group_size=group_size, num_stages=num_stages)
for i in range(num_local_stages)
]
schedule = ScheduleClass(
stages, num_microbatches, enable_zero_bubble=(i == 0)
stages,
num_microbatches,
zero_bubble_algorithm=zero_bubble_algorithms[i],
)
formatted_pipeline_order = _format_pipeline_order(
schedule.pipeline_order
@@ -144,7 +152,7 @@
schedule.pipeline_order,
num_microbatches,
num_stages,
enable_zero_bubble=(i == 0),
enable_zero_bubble=(zero_bubble_algorithms[i] is not None),
)

View File

@@ -19,6 +19,7 @@ from torch.distributed.pipelining import (
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ZeroBubbleAlgorithm,
)
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
@@ -555,7 +556,10 @@ class ScheduleTest(MultiProcContinousTest):
# Attach to a schedule
schedule = ScheduleClass(
stages, chunks, loss_fn=full_loss_fn, enable_zero_bubble=True
stages,
chunks,
loss_fn=full_loss_fn,
zero_bubble_algorithm=ZeroBubbleAlgorithm.ZB2P,
)
for _ in range(2):

View File

@@ -6,6 +6,7 @@ from .schedules import (
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
ZeroBubbleAlgorithm,
)
from .stage import build_stage, PipelineStage
@@ -22,4 +23,5 @@ __all__ = [
"ScheduleGPipe",
"ScheduleInterleaved1F1B",
"ScheduleLoopedBFS",
"ZeroBubbleAlgorithm",
]

View File

@@ -4,6 +4,7 @@
import csv
import itertools
import logging
import math
import re
from abc import ABC, abstractmethod
from collections import defaultdict
@@ -64,6 +65,20 @@ W = _ComputationType.WEIGHT
_action_regex = re.compile(r"(\d+)([F,B,W])(\d*)")
class ZeroBubbleAlgorithm(Enum):
ZB1P = 1
ZB2P = 2
ZBV = 3
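# ZBV (the V-schedule variant) is declared but not yet supported; selecting it raises ValueError (see below)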
def __str__(self):
str_map = {
ZeroBubbleAlgorithm.ZB1P: "ZB1P",
ZeroBubbleAlgorithm.ZB2P: "ZB2P",
ZeroBubbleAlgorithm.ZBV: "ZBV",
}
return str_map[self]
class _Action(NamedTuple):
stage_index: int
computation_type: _ComputationType
@@ -1184,14 +1199,10 @@ def _get_1f1b_rank_ops(
rank,
forward_stage_index,
backward_stage_index,
num_1f1b_microbatches=0,
enable_zero_bubble=False,
):
# All stages start with handling microbatch 0
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
# Store the list of operations used for that rank
rank_ops: List[Optional[_Action]] = []
# Pre-padding, rank starts with no-ops based on the warmup.
@@ -1208,14 +1219,8 @@
n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
) - (warmup_ops + rank)
if enable_zero_bubble:
post_warmup_ops = pp_group_size - rank - 1
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
backward_op_ids = []
weight_op_count = 0
for op in range(total_ops):
# Warmup phase
if op < warmup_ops:
@@ -1246,28 +1251,11 @@
rank_ops.append(
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
)
backward_op_ids.append(op)
if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
weight_stage_index = backward_stage_index(
backward_op_ids[weight_op_count]
)
weight_stage_mb_index[weight_stage_index] = (
weight_mb_index := weight_stage_mb_index[weight_stage_index]
) + 1
rank_ops.append(
_Action(
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
)
)
weight_op_count += 1
# Cooldown phase
else:
# During the cooldown phase, we need steps to align with the 1F1B still running on other ranks
# TODO: we don't need to always append; once all 1F1B steps have finished we can stop appending None
if not enable_zero_bubble:
rank_ops.append(None)
rank_ops.append(None)
bwd_stage_index = backward_stage_index(op)
bwd_stage_mb_index[bwd_stage_index] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
@@ -1275,32 +1263,6 @@
rank_ops.append(
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
)
backward_op_ids.append(op)
if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
weight_stage_index = backward_stage_index(
backward_op_ids[weight_op_count]
)
weight_stage_mb_index[weight_stage_index] = (
weight_mb_index := weight_stage_mb_index[weight_stage_index]
) + 1
rank_ops.append(
_Action(
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
)
)
weight_op_count += 1
while enable_zero_bubble and weight_op_count < len(backward_op_ids):
weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
weight_stage_mb_index[weight_stage_index] = (
weight_mb_index := weight_stage_mb_index[weight_stage_index]
) + 1
rank_ops.append(
_Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
)
weight_op_count += 1
return rank_ops
@@ -1419,7 +1381,8 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5
When zero_bubble_algorithm is passed in, we will use the corresponding schedule in
https://openreview.net/pdf?id=tuzTN0eIO5
"""
def __init__(
@@ -1430,7 +1393,7 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
enable_zero_bubble: bool = False,
zero_bubble_algorithm: Optional[ZeroBubbleAlgorithm] = None,
):
self.pp_group_size = stages[0].group_size
super().__init__(
@@ -1440,13 +1403,16 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
args_chunk_spec=args_chunk_spec,
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
use_full_backward=not enable_zero_bubble,
use_full_backward=not zero_bubble_algorithm,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
self.microbatches_per_round = n_microbatches // self.number_of_rounds
self.enable_zero_bubble = enable_zero_bubble
self.zero_bubble_algorithm = zero_bubble_algorithm
if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZBV:
raise ValueError("ZBV is not yet supported")
if n_microbatches % self.number_of_rounds != 0:
raise ValueError(
"Flexible Interleaved 1F1B requires the number of microbatches to be a "
@@ -1475,7 +1441,9 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
self.n_local_stages - 1
) * self.microbatches_per_round
# Increment warmup operations by multiply_factor for each hop away from the last stage
multiply_factor = 1 if self.enable_zero_bubble else 2
multiply_factor = 1
if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
multiply_factor = 2
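# ZB2P schedules two extra warmup forwards per hop away from the last stage; other variants schedule one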
warmup_ops = warmups_ops_last_stage + multiply_factor * (
(self.pp_group_size - 1) - rank
)
@@ -1517,10 +1485,12 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
)
return (local_index * self.pp_group_size) + rank
if self.enable_zero_bubble:
if self.zero_bubble_algorithm:
num_1f1b_microbatches = rank
if self.zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
num_1f1b_microbatches = 2 * rank
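# Weight (W) ops start only after num_1f1b_microbatches 1F1B steps have run; ZB2P defers them twice as long as ZB1P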
return _get_1f1b_rank_ops(
return self._get_1f1b_rank_ops_zero_bubble(
self.n_local_stages,
self.pp_group_size,
warmup_ops,
@@ -1530,7 +1500,8 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
forward_stage_index,
backward_stage_index,
num_1f1b_microbatches,
enable_zero_bubble=True,
zero_bubble_algorithm=self.zero_bubble_algorithm,
forward_local_stage_one_index=self.pp_group_size + rank,
)
return _get_1f1b_rank_ops(
@@ -1546,7 +1517,7 @@
def _add_bubbles_to_actions(self, num_stages_global):
actions = self.pipeline_order
if not self.enable_zero_bubble:
if not self.zero_bubble_algorithm:
return actions
def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
@@ -1609,3 +1580,174 @@ class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti):
f"Non zero bubbles added: {total_bubbles_added=} {bubbles_added=}" # noqa: G004
)
return result
def _get_1f1b_rank_ops_zero_bubble(
self,
n_local_stages,
pp_group_size,
warmup_ops,
fwd_bwd_ops,
cooldown_ops,
rank,
forward_stage_index,
backward_stage_index,
num_1f1b_microbatches,
zero_bubble_algorithm,
forward_local_stage_one_index,
):
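# Computes this rank's full action list (FORWARD/BACKWARD/WEIGHT plus None padding)
# for the ZB1P and ZB2P zero-bubble variants.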
# All stages start with handling microbatch 0
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
# Store the list of operations used for that rank
rank_ops: List[Optional[_Action]] = []
# Pre-padding, rank starts with no-ops based on the warmup.
for _ in range(rank):
rank_ops.append(None)
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
# Formula:
# pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
# post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
# earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
# warmup_ops = calculated above
post_warmup_ops = (
n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
) - (warmup_ops + rank)
if zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB1P:
post_warmup_ops = pp_group_size - rank - 1
elif zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
post_warmup_ops = 0
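# ZB1P idles (pp_group_size - rank - 1) slots waiting for the first backward; ZB2P's deeper warmup removes the wait entirely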
total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
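# Number of backward+weight (1B1W) pairs ZB2P pre-fills before the steady state; zero when pp_group_size <= 4, decreasing with rank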
prefill_steps_1b1w = 0
if zero_bubble_algorithm is ZeroBubbleAlgorithm.ZB2P:
prefill_steps_1b1w = max(0, math.ceil((pp_group_size - 4) / 2) - rank)
backward_op_ids = []
weight_op_count = 0
forward_op_id = 0
backward_op_id = warmup_ops
has_backfilled = False
for op in range(total_ops - prefill_steps_1b1w):
# Warmup phase
if op < warmup_ops:
fwd_stage_index = forward_stage_index(forward_op_id)
# This will assign the current microbatch index and update it as well
fwd_stage_mb_index[fwd_stage_index] = (
mb_index := fwd_stage_mb_index[fwd_stage_index]
) + 1
rank_ops.append(
_Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
)
if forward_op_id == warmup_ops - 1:
# This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
rank_ops.extend([None] * post_warmup_ops)
forward_op_id += 1
# 1F1B Phase (forward and backward)
elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
fwd_stage_index = forward_stage_index(forward_op_id)
if (
fwd_stage_index == forward_local_stage_one_index
and not has_backfilled
):
has_backfilled = True
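# Backfill the deferred 1B1W pairs now that this rank's second local stage (forward_local_stage_one_index) runs its first forward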
for _ in range(prefill_steps_1b1w):
bwd_stage_index = backward_stage_index(backward_op_id)
bwd_stage_mb_index[bwd_stage_index] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
) + 1
rank_ops.append(
_Action(
bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index
)
)
backward_op_ids.append(backward_op_id)
backward_op_id += 1
weight_stage_index = backward_stage_index(
backward_op_ids[weight_op_count]
)
weight_stage_mb_index[weight_stage_index] = (
weight_mb_index := weight_stage_mb_index[weight_stage_index]
) + 1
rank_ops.append(
_Action(
weight_stage_index,
_ComputationType.WEIGHT,
weight_mb_index,
)
)
weight_op_count += 1
fwd_stage_mb_index[fwd_stage_index] = (
fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
) + 1
rank_ops.append(
_Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
)
bwd_stage_index = backward_stage_index(backward_op_id)
bwd_stage_mb_index[bwd_stage_index] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
) + 1
rank_ops.append(
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
)
backward_op_ids.append(backward_op_id)
forward_op_id += 1
backward_op_id += 1
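# Past the initial 1F1B window, pair each step with a deferred weight update (W) for the oldest pending backward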
if op - warmup_ops >= num_1f1b_microbatches:
weight_stage_index = backward_stage_index(
backward_op_ids[weight_op_count]
)
weight_stage_mb_index[weight_stage_index] = (
weight_mb_index := weight_stage_mb_index[weight_stage_index]
) + 1
rank_ops.append(
_Action(
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
)
)
weight_op_count += 1
# Cooldown phase
else:
bwd_stage_index = backward_stage_index(backward_op_id)
bwd_stage_mb_index[bwd_stage_index] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
) + 1
rank_ops.append(
_Action(bwd_stage_index, _ComputationType.BACKWARD, bwd_mb_index)
)
backward_op_ids.append(backward_op_id)
backward_op_id += 1
if zero_bubble_algorithm and op - warmup_ops >= num_1f1b_microbatches:
weight_stage_index = backward_stage_index(
backward_op_ids[weight_op_count]
)
weight_stage_mb_index[weight_stage_index] = (
weight_mb_index := weight_stage_mb_index[weight_stage_index]
) + 1
rank_ops.append(
_Action(
weight_stage_index, _ComputationType.WEIGHT, weight_mb_index
)
)
weight_op_count += 1
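# Cooldown drain: flush any weight updates still pending after all backwards have been emitted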
while weight_op_count < len(backward_op_ids):
weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
weight_stage_mb_index[weight_stage_index] = (
weight_mb_index := weight_stage_mb_index[weight_stage_index]
) + 1
rank_ops.append(
_Action(weight_stage_index, _ComputationType.WEIGHT, weight_mb_index)
)
weight_op_count += 1
return rank_ops