[PP] Forward only schedule (#132177)

`python test/distributed/pipelining/test_schedule_multiproc.py -k test_forward_only`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132177
Approved by: https://github.com/lessw2020
Author: Howard Huang
Date: 2024-08-01 06:46:56 -07:00
Committed by: PyTorch MergeBot
Parent: ee09d066d3
Commit: c59f3fff52

3 changed files with 103 additions and 0 deletions
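
For orientation, here is a minimal usage sketch of the schedule this PR adds, distilled from the `test_forward_only` diff below. It is not part of the patch: the model, sizes, split point, and launch line are illustrative assumptions; the only APIs used are the ones the test itself exercises (`pipeline`, `Pipe.build_stage`, `_ScheduleForwardOnly(...).step`). It assumes two ranks, each with its own GPU, matching the test's NCCL requirement.

```python
# Usage sketch, not part of the patch. Model, sizes, and split point are
# illustrative assumptions. Run with two GPUs, e.g.:
#   torchrun --nproc-per-node 2 forward_only_example.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.pipelining import _ScheduleForwardOnly, pipeline, SplitPoint

d_hid, batch_size, num_microbatches = 16, 8, 4


class TwoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_hid, d_hid)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(d_hid, d_hid)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")

    mod = TwoLayer().to(device)
    x = torch.randn(batch_size, d_hid, device=device)
    x_mb = x.chunk(num_microbatches)[0]  # one example microbatch for tracing

    # Trace and split the model into one stage per rank; splitting before fc2
    # is an arbitrary choice for this two-rank example.
    pipe = pipeline(mod, mb_args=(x_mb,), split_spec={"fc2": SplitPoint.BEGINNING})
    stage = pipe.build_stage(rank, device)

    # Attach the stage to the forward-only schedule and run one inference step.
    schedule = _ScheduleForwardOnly(stage, num_microbatches)
    out = schedule.step(x) if rank == 0 else schedule.step()

    if rank == world_size - 1:
        print(out.shape)  # only the last stage returns the assembled output

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

As in the test, only the first stage passes real inputs to `step()`; intermediate stages receive activations over P2P, and only the last stage returns the concatenated microbatch output.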


@@ -12,6 +12,7 @@ from schedule_registry import ScheduleUnbalanced, ScheduleVShaped
import torch
import torch.distributed as dist
from torch.distributed.pipelining import (
    _ScheduleForwardOnly,
    pipeline,
    PipelineStage,
    Schedule1F1B,
@@ -56,6 +57,56 @@ class ScheduleTest(MultiProcContinousTest):
        dev_id = cls.rank % torch.cuda.device_count()
        cls.device = torch.device(f"cuda:{dev_id}")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [_ScheduleForwardOnly])
    def test_forward_only(self, ScheduleClass):
        mod = MultiMLP(d_hid, n_layers=self.world_size)
        mod.to(self.device)

        mod_ref = copy.deepcopy(mod)

        x = torch.randn(batch_size, d_hid, device=self.device)
        x_clone = x.clone()

        num_microbatches = 4
        x_mb = x.chunk(num_microbatches)[0]

        # Create a pipeline
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, num_microbatches)

        # Run
        num_iters = 20
        for _ in range(num_iters):
            if self.rank == 0:
                schedule.step(x)
                dist.recv(x, src=self.world_size - 1)
            elif self.rank == self.world_size - 1:
                out = schedule.step()
                dist.send(out, dst=0)
            else:
                schedule.step()

        # Validate pipelined output is the same as reference model
        if self.rank == self.world_size - 1:
            for _ in range(num_iters):
                x_clone = mod_ref(x_clone)
            torch.testing.assert_close(x_clone, out)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])


@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from ._IR import Pipe, pipe_split, pipeline, SplitPoint
from .schedules import (
    _ScheduleForwardOnly,
    Schedule1F1B,
    ScheduleFlexibleInterleaved1F1B,
    ScheduleGPipe,
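
With the export above, the new schedule is importable from the package root. A one-line check (an illustrative snippet, not from the patch) shows where it sits in the class hierarchy defined in `schedules.py` below.

```python
# Illustrative only; the leading underscore marks the schedule as private by convention.
from torch.distributed.pipelining import _ScheduleForwardOnly

# Prints the hierarchy: _ScheduleForwardOnly -> PipelineScheduleSingle -> _PipelineSchedule -> ...
print(_ScheduleForwardOnly.__mro__)
```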


@@ -564,6 +564,57 @@ class PipelineScheduleSingle(_PipelineSchedule):
        return None


class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule
        """
        if target_mbs is not None or losses is not None:
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug(
                f"[{self._stage.stage_index}] Forwarded microbatch {i}"  # noqa: G004
            )

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()


class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
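
To see the control flow of `_step_microbatches` in isolation, the following single-process sketch (an illustration with no distributed calls, not the real implementation) reproduces the per-stage operation order it generates: each microbatch waits on its receive, runs forward, and enqueues its send without waiting; the send queue is only drained after the loop.

```python
# Single-process illustration of the forward-only schedule's operation order.
from typing import List


def forward_only_ops(stage_index: int, num_stages: int, n_microbatches: int) -> List[str]:
    """Return the ordered operations one stage performs under the forward-only schedule."""
    ops: List[str] = []
    pending_sends: List[str] = []
    for i in range(n_microbatches):
        if stage_index > 0:
            ops.append(f"wait recv mb{i} from stage {stage_index - 1}")
        ops.append(f"forward mb{i}")
        if stage_index < num_stages - 1:
            pending_sends.append(f"wait send mb{i} to stage {stage_index + 1}")
    # Sends were issued asynchronously inside the loop; they are only waited on here.
    ops.extend(pending_sends)
    return ops


if __name__ == "__main__":
    for stage in range(2):
        print(f"stage {stage}:", forward_only_ops(stage, num_stages=2, n_microbatches=4))
```

Running it with two stages and four microbatches prints the same recv/forward/deferred-send pattern that the NCCL test above exercises end to end.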