[PP] Forward only schedule (#132177)
`python test/distributed/pipelining/test_schedule_multiproc.py -k test_forward_only`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132177
Approved by: https://github.com/lessw2020
Committed by: PyTorch MergeBot
Parent: ee09d066d3
Commit: c59f3fff52
@@ -12,6 +12,7 @@ from schedule_registry import ScheduleUnbalanced, ScheduleVShaped
import torch
import torch.distributed as dist
from torch.distributed.pipelining import (
    _ScheduleForwardOnly,
    pipeline,
    PipelineStage,
    Schedule1F1B,
@@ -56,6 +57,56 @@ class ScheduleTest(MultiProcContinousTest):
        dev_id = cls.rank % torch.cuda.device_count()
        cls.device = torch.device(f"cuda:{dev_id}")

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [_ScheduleForwardOnly])
    def test_forward_only(self, ScheduleClass):
        mod = MultiMLP(d_hid, n_layers=self.world_size)
        mod.to(self.device)

        mod_ref = copy.deepcopy(mod)

        x = torch.randn(batch_size, d_hid, device=self.device)
        x_clone = x.clone()

        num_microbatches = 4
        x_mb = x.chunk(num_microbatches)[0]

        # Create a pipeline
        split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
        pipe = pipeline(
            mod,
            mb_args=(x_mb,),
            split_spec=split_spec,
        )

        stage = pipe.build_stage(
            self.rank,
            self.device,
        )

        # Attach to a schedule
        schedule = ScheduleClass(stage, num_microbatches)

        # Run
        num_iters = 20
        for _ in range(num_iters):
            if self.rank == 0:
                schedule.step(x)
                dist.recv(x, src=self.world_size - 1)
            elif self.rank == self.world_size - 1:
                out = schedule.step()
                dist.send(out, dst=0)
            else:
                schedule.step()

        # Validate pipelined output is the same as reference model
        if self.rank == self.world_size - 1:
            for _ in range(num_iters):
                x_clone = mod_ref(x_clone)

            torch.testing.assert_close(x_clone, out)

    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
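For orientation, the calls exercised by the test above can also be driven as a standalone two-rank inference script. The following is a minimal sketch, not part of this commit: the toy model, the dict form of `split_spec`, and the `torchrun` launch are illustrative assumptions; only the `pipeline()`, `build_stage()`, `_ScheduleForwardOnly()`, and `step()` calls mirror what the test does.

# forward_only_demo.py -- illustrative sketch only; assumes 2 ranks, 2 GPUs, NCCL.
# Launch (assumed): torchrun --nproc-per-node=2 forward_only_demo.py
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.pipelining import _ScheduleForwardOnly, pipeline, SplitPoint


class TwoStageMLP(nn.Module):
    def __init__(self, d_hid: int = 512):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_hid, d_hid), nn.ReLU(),
            nn.Linear(d_hid, d_hid), nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    d_hid, batch_size, num_microbatches = 512, 64, 4
    mod = TwoStageMLP(d_hid).to(device)
    x = torch.randn(batch_size, d_hid, device=device)

    # Trace and split the model into one stage per rank (split point is assumed).
    pipe = pipeline(
        mod,
        mb_args=(x.chunk(num_microbatches)[0],),
        split_spec={"layers.2": SplitPoint.BEGINNING},
    )
    stage = pipe.build_stage(rank, device)

    # Forward-only schedule: no loss, no backward.
    schedule = _ScheduleForwardOnly(stage, num_microbatches)
    if rank == 0:
        schedule.step(x)       # first stage feeds the input microbatches
    else:
        out = schedule.step()  # last stage (rank 1 here) returns the output
        print(out.shape)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()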
@@ -1,6 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from ._IR import Pipe, pipe_split, pipeline, SplitPoint
from .schedules import (
    _ScheduleForwardOnly,
    Schedule1F1B,
    ScheduleFlexibleInterleaved1F1B,
    ScheduleGPipe,
@@ -564,6 +564,57 @@ class PipelineScheduleSingle(_PipelineSchedule):
        return None


class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule
        """
        if target_mbs is not None or losses is not None:
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        # Delay send waits
        fwd_sends_to_wait: List[dist.Work] = []

        # Run microbatches
        for i in range(self._n_microbatches):
            with record_function(f"Forward {i}"):
                ops = self._stage.get_fwd_recv_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_recv")
                for work in works.values():
                    work.wait()

                self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]

                ops = self._stage.get_fwd_send_ops(i)
                works = _sorted_batch_p2p(ops, desc="fwd_send")
                fwd_sends_to_wait.extend(works.values())

            logger.debug(
                f"[{self._stage.stage_index}] Forwarded microbatch {i}"  # noqa: G004
            )

        # Wait for all forward sends to finish
        # This should not have performance impact because by the time the first
        # backward arrives all the forward sends should have been finished.
        for work in fwd_sends_to_wait:
            work.wait()


class ScheduleGPipe(PipelineScheduleSingle):
    """
    The GPipe schedule.
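A design note on `_step_microbatches` above: receives are waited on immediately (a stage cannot compute without its input), while forward sends are only collected and drained once after the microbatch loop, so the loop never blocks on send completion. The `_sorted_batch_p2p` helper is not shown in this diff; as a rough illustration of the same deferred-send pattern using plain `torch.distributed` batched P2P ops (a sketch for a middle stage only, with illustrative names):

import torch
import torch.distributed as dist


def relay_stage(n_microbatches, mb_shape, prev_rank, next_rank, stage_fn, device):
    """Illustrative middle-stage loop: recv -> compute -> async send,
    waiting on all deferred sends once after the loop."""
    pending_sends = []
    for _ in range(n_microbatches):
        buf = torch.empty(mb_shape, device=device)
        # The input must arrive before this stage can compute, so wait here.
        for work in dist.batch_isend_irecv([dist.P2POp(dist.irecv, buf, prev_rank)]):
            work.wait()
        out = stage_fn(buf)
        # Queue the send without blocking; the next microbatch can start right away.
        pending_sends += dist.batch_isend_irecv([dist.P2POp(dist.isend, out, next_rank)])
    # Drain all outstanding sends at the end, as _ScheduleForwardOnly does.
    for work in pending_sends:
        work.wait()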