# pytorch/test/distributed/pipelining/schedule_registry.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a schedule zoo for testing torch.distributed.pipelining.
# It contains schedules designed purely for testing purposes.
from typing import Callable, Optional
from torch.distributed.pipelining.schedules import (
_Action,
_ComputationType,
_PipelineScheduleRuntime,
PipelineScheduleMulti,
RECV_B,
RECV_F,
SEND_B,
SEND_F,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
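
# Single-letter aliases for the computation types used by the hand-written
# schedules below: F = forward, B = full backward (input and weight grads in
# one step), I = backward w.r.t. inputs only, W = backward w.r.t. weights only.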
F = _ComputationType.FORWARD
B = _ComputationType.FULL_BACKWARD
W = _ComputationType.BACKWARD_WEIGHT
I = _ComputationType.BACKWARD_INPUT
class ScheduleVShaped(PipelineScheduleMulti):
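    """
    Two-rank, four-stage "V" schedule: rank 0 owns the outer stages (0 and 3)
    and rank 1 owns the inner stages (1 and 2), so a single microbatch runs
    forward 0 -> 1 -> 2 -> 3 and then backward 3 -> 2 -> 1 -> 0.
    """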
n_stages = 4
rank_stages = {
0: [0, 3],
1: [1, 2],
}
def __init__(
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,
):
super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
scale_grads=scale_grads,
)
# Go through one microbatch
        # Note(whc) - it might be easier to work with these schedules by writing them as a list of
        # ["0F0", ...] and then parsing them in the test infra to turn them into actions.
self.pipeline_order = {
0: [
_Action(0, F, 0),
None,
None,
_Action(3, F, 0),
_Action(3, B, 0),
None,
None,
_Action(0, B, 0),
],
1: [
None,
_Action(1, F, 0),
_Action(2, F, 0),
None,
None,
_Action(2, B, 0),
_Action(1, B, 0),
None,
],
}
self._validate_and_set_stage_mapping(self.pipeline_order)
class ScheduleUnbalanced(PipelineScheduleMulti):
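    """
    Two-rank, five-stage schedule with an uneven split: rank 0 owns stages
    0, 1, and 4 while rank 1 owns stages 2 and 3, so the ranks hold different
    numbers of stages.
    """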
n_stages = 5
rank_stages = {
0: [0, 1, 4],
1: [2, 3],
}
def __init__(
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,
):
super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
scale_grads=scale_grads,
)
self.pipeline_order = {
0: [
_Action(0, F, 0),
_Action(1, F, 0),
None,
None,
_Action(4, F, 0),
_Action(4, B, 0),
None,
None,
_Action(1, B, 0),
_Action(0, B, 0),
],
1: [
None,
None,
_Action(2, F, 0),
_Action(3, F, 0),
None,
None,
_Action(3, B, 0),
_Action(2, B, 0),
None,
None,
],
}
self._validate_and_set_stage_mapping(self.pipeline_order)
class ScheduleWithW(PipelineScheduleMulti):
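    """
    Two ranks, four stages (rank 0 owns stages 0 and 2, rank 1 owns 1 and 3),
    two microbatches. Splits each backward into separate input-grad (I) and
    weight-grad (W) actions, in the style of zero-bubble schedules.
    """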
n_stages = 4
num_microbatches = 2
rank_stages = {
0: [0, 2],
1: [1, 3],
}
def __init__(
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
enable_zero_bubble: bool = True,
scale_grads: bool = True,
):
super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
scale_grads=scale_grads,
)
        # use_full_backward must be disabled for any schedule that splits the
        # backward into separate I (input-grad) and W (weight-grad) actions
self.use_full_backward = False
# Go through two microbatches
self.pipeline_order = {
0: [
_Action(0, F, 0),
_Action(0, F, 1),
_Action(2, F, 0),
_Action(2, F, 1),
None,
_Action(2, I, 0),
_Action(2, W, 0),
_Action(0, I, 0),
_Action(2, I, 1),
_Action(0, W, 0),
_Action(0, I, 1),
_Action(2, W, 1),
_Action(0, W, 1),
],
1: [
None,
_Action(1, F, 0),
_Action(1, F, 1),
_Action(3, F, 0),
_Action(3, I, 0),
_Action(3, F, 1),
_Action(1, I, 0),
_Action(3, I, 1),
_Action(3, W, 0),
_Action(1, I, 1),
_Action(1, W, 0),
_Action(3, W, 1),
_Action(1, W, 1),
],
}
self._validate_and_set_stage_mapping(self.pipeline_order)
class ScheduleWithReorderedB(_PipelineScheduleRuntime):
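    """
    Two ranks, one stage each, two microbatches, expressed as a runtime
    schedule with explicit communication actions. Both full backwards on
    stage 1 run before either SEND_B, so the backward sends are batched
    (reordered) after the compute.
    """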
n_stages = 2
num_microbatches = 2
rank_stages = {
0: [0],
1: [1],
}
def __init__(
self,
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
scale_grads: bool = True,
):
super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
scale_grads=scale_grads,
)
# Go through two microbatches
self.pipeline_order_with_comms = {
0: [
_Action(0, F, 0),
_Action(0, F, 1),
_Action(0, SEND_F, 0),
_Action(0, SEND_F, 1),
_Action(0, RECV_B, 0),
_Action(0, RECV_B, 1),
_Action(0, B, 0),
_Action(0, B, 1),
],
1: [
_Action(1, RECV_F, 0),
_Action(1, RECV_F, 1),
_Action(1, F, 0),
_Action(1, F, 1),
_Action(1, B, 0),
_Action(1, B, 1),
_Action(1, SEND_B, 0),
_Action(1, SEND_B, 1),
],
}
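

# Illustrative sketch (not executed by the tests): how a two-rank test might
# drive one of these schedules. The stage modules, device, and loss_fn below
# are hypothetical placeholders.
#
#   from torch.distributed.pipelining import PipelineStage
#
#   # On rank 0, which owns stages 0 and 3 of ScheduleVShaped:
#   stages = [
#       PipelineStage(submod0, stage_index=0, num_stages=4, device=device),
#       PipelineStage(submod3, stage_index=3, num_stages=4, device=device),
#   ]
#   schedule = ScheduleVShaped(stages, n_microbatches=1, loss_fn=loss_fn)
#   schedule.step(x, target=y)  # rank 1 builds its own stages and calls schedule.step()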