Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[pipelining] pipeline() taking microbatch as example input (#128163)
Changed the API of `pipeline()` to take a microbatch instead of the full batch as example args. Main purpose is to:
- make this API more atomic;
- decouple the tracing frontend from runtime info like `num_chunks`.

Side effects:
- Creates the opportunity to vary `num_chunks` across schedules built from the same `pipe` object.
- User has to create an example microbatch input.
- Chunk spec stuff is now all moved to the runtime side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128163
Approved by: https://github.com/H-Huang
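As a quick orientation (editor's sketch, not part of the commit): under the new API, the user chunks the full batch with the runtime-side microbatch utilities and hands a single microbatch to `pipeline()`. The toy `TwoStage` module and the sizes below are assumptions, chosen to mirror the updated `test_microbatch.py` in this diff.

    import torch
    from torch.distributed.pipelining import pipe_split, pipeline
    from torch.distributed.pipelining.microbatch import (
        split_args_kwargs_into_chunks,
        TensorChunkSpec,
    )

    d_hid = 512  # assumed hidden size, matching the tests below


    class TwoStage(torch.nn.Module):
        # Hypothetical module with one split point
        def __init__(self):
            super().__init__()
            self.lin1 = torch.nn.Linear(d_hid, d_hid)
            self.lin2 = torch.nn.Linear(d_hid, d_hid)

        def forward(self, x, y):
            x = torch.relu(self.lin1(x + y))
            pipe_split()
            return self.lin2(x)


    mod = TwoStage()

    # Full batch; chunking now happens on the runtime side
    x = torch.randn(256, d_hid)
    y = torch.randn(256, d_hid)
    num_chunks = 4

    args_chunk_spec = TensorChunkSpec.from_tuple((0,))
    kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})
    args_split, kwargs_split = split_args_kwargs_into_chunks(
        (x,), {"y": y}, num_chunks, args_chunk_spec, kwargs_chunk_spec
    )

    # The tracing frontend sees only one microbatch; no num_chunks argument
    pipe = pipeline(mod, mb_args=args_split[0], mb_kwargs=kwargs_split[0])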
@@ -317,14 +317,8 @@ The following set of APIs transform your model into a pipeline representation.

 .. autoclass:: Pipe

 .. autofunction:: annotate_split_points

 .. autofunction:: pipe_split

-.. autoclass:: ArgsChunkSpec
-
-.. autoclass:: KwargsChunkSpec
-
 Microbatch Utilities
 ====================
@@ -2604,12 +2604,9 @@
         "TensorPipeRpcBackendOptions"
     ],
     "torch.distributed.pipelining": [
-        "ArgsChunkSpec",
-        "KwargsChunkSpec",
         "Pipe",
         "PipelineStage",
         "SplitPoint",
         "annotate_split_points",
         "pipe_split",
         "pipeline"
     ],
@@ -1,72 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# Owner(s): ["oncall: distributed"]
-import torch
-from torch.distributed.pipelining import (
-    ArgsChunkSpec,
-    KwargsChunkSpec,
-    pipe_split,
-    pipeline,
-)
-from torch.testing._internal.common_utils import run_tests, TestCase
-
-
-d_hid = 512
-batch_size = 256
-
-torch.manual_seed(0)
-
-
-class ModelWithKwargs(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
-        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
-        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
-        self.lin1 = torch.nn.Linear(d_hid, d_hid)
-        self.lin2 = torch.nn.Linear(d_hid, d_hid)
-
-    def forward(self, x, y, z=torch.zeros(batch_size, d_hid)):
-        x = torch.mm(x, self.mm_param0)
-        x = x + y
-        x = torch.relu(x)
-        x = x + z
-        pipe_split()
-        x = torch.mm(x, self.mm_param1)
-        x = self.lin1(x)
-        pipe_split()
-        x = torch.relu(x)
-        x = torch.mm(x, self.mm_param2)
-        pipe_split()
-        x = self.lin2(x)
-        x = torch.relu(x)
-        return x
-
-
-class ChunkSpecTests(TestCase):
-    def test_chunk_spec(self):
-        mod = ModelWithKwargs()
-
-        x = torch.randn(batch_size, d_hid)
-        y = torch.randn(batch_size, d_hid)
-        z = torch.randn(batch_size, d_hid)
-
-        chunks = 4
-
-        with ArgsChunkSpec((0, 0)), KwargsChunkSpec({"z": 0}):
-            pipe = pipeline(
-                mod,
-                chunks,
-                example_args=(x, y),
-                example_kwargs={"z": z},
-            )
-
-        assert pipe.num_stages == 4
-
-        ref = mod(x, y, z)
-        out = pipe(x, y, z)[0]
-        torch.testing.assert_close(out, ref)
-        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
-
-
-if __name__ == "__main__":
-    run_tests()
@@ -1,6 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
+from model_registry import ModelWithKwargs
+
 import torch
+from torch.distributed.pipelining import pipeline
 from torch.distributed.pipelining.microbatch import (
     merge_chunks,
     split_args_kwargs_into_chunks,
@@ -10,6 +13,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase


 d_hid = 512
 torch.manual_seed(0)


 class MicrobatchTests(TestCase):
@@ -49,9 +53,39 @@ class MicrobatchTests(TestCase):
             },
         )
         torch.testing.assert_close(merged_kwargs, kwargs)

         print("Microbatch test passed")

+    def test_chunk_spec(self):
+        mod = ModelWithKwargs()
+        batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE
+
+        x = torch.randn(batch_size, d_hid)
+        y = torch.randn(batch_size, d_hid)
+
+        num_chunks = 4
+
+        args_chunk_spec = TensorChunkSpec.from_tuple((0,))
+        kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})
+
+        args_split, kwargs_split = split_args_kwargs_into_chunks(
+            (x,),
+            {"y": y},
+            num_chunks,
+            args_chunk_spec,
+            kwargs_chunk_spec,
+        )
+
+        pipe = pipeline(
+            mod,
+            mb_args=args_split[0],
+            mb_kwargs=kwargs_split[0],
+        )
+
+        ref = mod(x, y)
+        out = pipe(x, y)[0]
+        torch.testing.assert_close(out, ref)
+        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


 if __name__ == "__main__":
     run_tests()
@@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import (


 d_hid = 512
-batch_size = 256
+microbatch_size = 16

 torch.manual_seed(0)
@@ -81,13 +81,12 @@ class PipeTests(TestCase):
     @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias])
     def test_model_split(self, ModelClass):
         mod = ModelClass()
-        x = torch.randn(batch_size, d_hid)
-        y = torch.randn(batch_size, d_hid)
+        x = torch.randn(microbatch_size, d_hid)
+        y = torch.randn(microbatch_size, d_hid)

         pipe = pipeline(
             mod,
-            num_chunks=4,
-            example_args=(x, y),
+            mb_args=(x, y),
         )

         assert (
@@ -81,20 +81,22 @@ class ScheduleTest(MultiProcContinousTest):
         target = torch.randn(batch_size, d_hid, device=self.device)
         loss_fn = torch.nn.MSELoss(reduction="sum")

-        # Create a pipeline
         chunks = 4
+        x_mb = x.chunk(chunks)[0]
+
+        # Create a pipeline
         split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
         pipe = pipeline(
             mod,
-            chunks,
-            example_args=(x,),
+            mb_args=(x_mb,),
             split_spec=split_spec,
         )

         stage = TracerPipelineStage(
             pipe,
             self.rank,
-            device=self.device,
+            self.device,
+            chunks,  # to be cleaned
         )

         # Attach to a schedule
@@ -123,17 +125,20 @@ class ScheduleTest(MultiProcContinousTest):
         loss_fn = torch.nn.MSELoss(reduction="sum")

         chunks = 4
+        x_mb = x.chunk(chunks)[0]
+        y_mb = y.chunk(chunks)[0]
+
         pipe = pipeline(
             mod,
-            chunks,
-            example_args=(x,),
-            example_kwargs={"y": y},
+            mb_args=(x_mb,),
+            mb_kwargs={"y": y_mb},
         )

         stage = TracerPipelineStage(
             pipe,
             self.rank,
-            device=self.device,
+            self.device,
+            chunks,  # to be cleaned
         )

         # Attach to a schedule
@@ -184,18 +189,19 @@ class ScheduleTest(MultiProcContinousTest):

         # Create a pipeline
         chunks = 4
+        x_mb = x.chunk(chunks)[0]
         split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
         pipe = pipeline(
             mod,
-            chunks,
-            example_args=(x,),
+            mb_args=(x_mb,),
             split_spec=split_spec,
         )

         stage = TracerPipelineStage(
             pipe,
             self.rank,
-            device=self.device,
+            self.device,
+            chunks,  # to be cleaned
         )

         # Attach to a schedule
@@ -82,19 +82,20 @@ class StageTest(MultiProcContinousTest):
         mod.to(self.device)

         x = torch.randn(batch_size, d_hid, device=self.device)
+        x_mb = x.chunk(chunks)[0]

         split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
         pipe = pipeline(
             mod,
-            chunks,
-            example_args=(x,),
+            mb_args=(x_mb,),
             split_spec=split_spec,
         )

         stage = TracerPipelineStage(
             pipe,
             self.rank,
-            device=self.device,
+            self.device,
+            chunks,  # to be cleaned
         )

         # Attach to a schedule
@@ -150,17 +151,20 @@ class StageTest(MultiProcContinousTest):
         x = torch.randn(batch_size, d_hid, device=self.device)
         y = torch.randn(batch_size, d_hid, device=self.device)

+        x_mb = x.chunk(chunks)[0]
+        y_mb = y.chunk(chunks)[0]
+
         pipe = pipeline(
             mod,
-            chunks,
-            example_args=(x,),
-            example_kwargs={"y": y},
+            mb_args=(x_mb,),
+            mb_kwargs={"y": y_mb},
         )

         stage = TracerPipelineStage(
             pipe,
             self.rank,
-            device=self.device,
+            self.device,
+            chunks,
         )

         # Attach to a schedule
@@ -7,7 +7,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase

 d_hid = 16
 n_layers = 8
-batch_size = 4
+microbatch_size = 4


 class MLPModule(torch.nn.Module):
@@ -36,8 +36,7 @@ class TransformerLike(torch.nn.Module):
 class TransformerTests(TestCase):
     def test_ir(self):
         transformer = TransformerLike()
-        print("Original model:\n", transformer)
-        x = torch.randn(batch_size, d_hid)
+        x = torch.randn(microbatch_size, d_hid)

         # Split into 2 stages
         num_stages = 2
@@ -45,7 +44,6 @@ class TransformerTests(TestCase):

         pipe = pipeline(
             transformer,
-            1,
             (x,),
             split_spec=split_spec,
         )
@@ -59,19 +57,18 @@ class TransformerTests(TestCase):
         layers = []
         for stage_idx in range(pipe.num_stages):
             stage_mod = pipe.get_stage_module(stage_idx)
-            print(f"\nStage {stage_idx}: \n", stage_mod)
             layers += get_layers(stage_mod)

         # Check layer completeness
         orig_layers = get_layers(transformer)
         assert sorted(layers) == sorted(orig_layers), f"{layers} != {orig_layers}"
-        print("Layers matched! ", layers)
+        print("Layers matched!")

         # Check equivalence
         ref = transformer(x)
         out = pipe(x)[0]
         torch.testing.assert_close(out, ref)
-        print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
+        print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


 if __name__ == "__main__":
@@ -48,7 +48,6 @@ class UnflattenTests(TestCase):

         pipe = pipeline(
             mod,
-            1,
             (x,),
             {"constant": constant},
         )
@@ -20,7 +20,7 @@ import torch
 import torch.distributed as dist
 from torch.profiler import record_function

-from .microbatch import merge_chunks, split_args_kwargs_into_chunks
+from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
 from .PipelineStage import _PipelineStageBase

 if TYPE_CHECKING:
@@ -64,12 +64,24 @@ class _PipelineSchedule(ABC):
         self,
         n_microbatches: int,
         loss_fn: Optional[Callable[..., torch.Tensor]] = None,
+        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
+        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
         output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     ):
         # From arguments
         self._n_microbatches = n_microbatches
         self._loss_fn = loss_fn
+        # Chunking specification for positional inputs. (default: `None`)
+        self._args_chunk_spec = args_chunk_spec
+        # Chunking specification for keyword inputs. (default: `None`)
+        self._kwargs_chunk_spec = kwargs_chunk_spec
         self._output_merge_spec = output_merge_spec
+        """
+        # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
+        # They are used to convert batch to microbatches in `step(x)`. See
+        # `TensorChunkSpec` for helper methods for creating them.
+        """

         # Derived
         self._has_backward = self._loss_fn is not None
         # To be filled by subclasses
@@ -201,22 +213,13 @@
         Splits a full-batch input into chunks (i.e. microbatches) and returns
         the chunks
         """
-        if self._pipe_info is not None:
-            # Use spec from `pipe_info`
-            args_chunk_spec = self._pipe_info.args_chunk_spec
-            kwargs_chunk_spec = self._pipe_info.kwargs_chunk_spec
-        else:
-            # Use default spec from `microbatch.py` (i.e. chunk dim 0 for each arg/kwarg)
-            args_chunk_spec = None
-            kwargs_chunk_spec = None
-
         if args or kwargs:
             args_split, kwargs_split = split_args_kwargs_into_chunks(
                 args,
                 kwargs,
                 self._n_microbatches,
-                args_chunk_spec,
-                kwargs_chunk_spec,
+                self._args_chunk_spec,
+                self._kwargs_chunk_spec,
             )
             return args_split, kwargs_split
         else:
@@ -285,12 +288,16 @@ class PipelineScheduleSingle(_PipelineSchedule):
         stage: _PipelineStageBase,
         n_microbatches: int,
         loss_fn: Optional[Callable] = None,
+        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
+        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
         output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     ):
         # Init parent
         super().__init__(
             n_microbatches=n_microbatches,
             loss_fn=loss_fn,
+            args_chunk_spec=args_chunk_spec,
+            kwargs_chunk_spec=kwargs_chunk_spec,
             output_merge_spec=output_merge_spec,
         )
         self._pipe_info = (
@@ -567,6 +574,8 @@ class PipelineScheduleMulti(_PipelineSchedule):
         stages: List[_PipelineStageBase],
         n_microbatches: int,
         loss_fn: Optional[Callable] = None,
+        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
+        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
         output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     ):
         if len(stages) <= 1:
@@ -577,6 +586,8 @@ class PipelineScheduleMulti(_PipelineSchedule):
         super().__init__(
             n_microbatches=n_microbatches,
             loss_fn=loss_fn,
+            args_chunk_spec=args_chunk_spec,
+            kwargs_chunk_spec=kwargs_chunk_spec,
             output_merge_spec=output_merge_spec,
         )
         self._pipe_info = (
@@ -712,6 +723,8 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
         stages: List[_PipelineStageBase],
         n_microbatches: int,
         loss_fn: Optional[Callable] = None,
+        args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
+        kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
         output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     ):
         self.pp_group_size = stages[0].group_size
@@ -726,6 +739,8 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
             stages=stages,
             n_microbatches=n_microbatches,
             loss_fn=loss_fn,
+            args_chunk_spec=args_chunk_spec,
+            kwargs_chunk_spec=kwargs_chunk_spec,
             output_merge_spec=output_merge_spec,
         )
@@ -630,6 +630,7 @@ class _PipelineStage(_PipelineStageBase):
         stage_index: int,
         pipe_info: Pipe.PipeInfo,
         device: torch.device,
+        num_chunks: int,
         group: Optional[dist.ProcessGroup] = None,
     ):
         """
@@ -642,7 +643,7 @@ class _PipelineStage(_PipelineStageBase):
             stage_index,
             pipe_info.num_stages,
             device,
-            pipe_info.num_chunks,
+            num_chunks,
             group,
         )
         self.pipe_info = pipe_info
@@ -901,6 +902,7 @@ class TracerPipelineStage(_PipelineStage):
         pipe: Pipe,
         stage_index: int,
         device: torch.device,
+        num_chunks: int,  # To be cleaned
         group: Optional[dist.ProcessGroup] = None,
     ):
         """
@@ -910,7 +912,9 @@ class TracerPipelineStage(_PipelineStage):
         stage_module = pipe.get_stage_module(stage_index)
         # Get my pipe info
         pipe_info = pipe.info()
-        super().__init__(stage_module, stage_index, pipe_info, device, group)
+        super().__init__(
+            stage_module, stage_index, pipe_info, device, num_chunks, group
+        )


 # Manual PipelineStage functions and definition
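For the runtime half of the change, a hedged sketch of wiring a stage and schedule together after this commit. It assumes an initialized default process group (e.g. under torchrun), a `pipe` traced as sketched earlier, and that `ScheduleGPipe` picks up the new keyword arguments through `PipelineScheduleSingle`.

    import torch
    import torch.distributed as dist
    from torch.distributed.pipelining import ScheduleGPipe, TracerPipelineStage
    from torch.distributed.pipelining.microbatch import TensorChunkSpec

    rank = dist.get_rank()
    device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")

    num_chunks = 4

    # The stage now takes num_chunks directly ("to be cleaned" per this commit)
    # instead of reading pipe_info.num_chunks.
    stage = TracerPipelineStage(pipe, rank, device, num_chunks)

    # Chunking specs moved off Pipe and onto the schedule (runtime side).
    schedule = ScheduleGPipe(
        stage,
        n_microbatches=num_chunks,
        loss_fn=torch.nn.MSELoss(reduction="sum"),
        args_chunk_spec=TensorChunkSpec.from_tuple((0,)),
        kwargs_chunk_spec=TensorChunkSpec.from_dict({"y": 0}),
    )
    # schedule.step(x, y=y) then splits the full batch into num_chunks
    # microbatches using these specs before running the pipeline.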
@@ -23,7 +23,6 @@ from torch.fx.passes.split_module import split_module

 from ._backward import _null_coalesce_accumulate, stage_backward
 from ._unflatten import _outline_submodules
-from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec


 logger = logging.getLogger(__name__)
@@ -486,28 +485,11 @@ def _direct_serialization_reduce(self):


 class Pipe(torch.nn.Module):
-    # Class variables
-    # args_chunk_spec and kwargs_chunk_spec are used to specify how to chunk
-    # inputs. They are used to create microbatched examples before tracing.
-    # See context managers `ArgsChunkSpec` and `KwargsChunkSpec`.
-    # TODO: Do we need to support `_Replicate`? It's unclear, dropping for now.
-
-    # args_chunk_spec:
-    #   Chunking specification for positional inputs. (default: `None`)
-    args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None
-
-    # kwargs_chunk_spec:
-    #   Chunking specification for keyword inputs. (default: `None`)
-    kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None
-
     @dataclass
     class PipeInfo:
         graph: fx.Graph
         num_stages: int
-        num_chunks: int
         has_loss_and_backward: bool
-        args_chunk_spec: Optional[Tuple[Any, ...]] = None
-        kwargs_chunk_spec: Optional[Dict[str, Any]] = None

     def __init__(
         self,
@@ -1000,7 +982,6 @@ class Pipe(torch.nn.Module):
     @staticmethod
     def from_tracing(
         mod: torch.nn.Module,
-        num_chunks: int,
         example_args: Tuple[Any, ...],
         example_kwargs: Optional[Dict[str, Any]] = None,
         split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
@@ -1019,19 +1000,11 @@
         )
         """

-        args_split, kwargs_split = split_args_kwargs_into_chunks(
-            example_args,
-            example_kwargs,
-            num_chunks,
-            Pipe.args_chunk_spec,
-            Pipe.kwargs_chunk_spec,
-        )
-
         # Trace with export
         exported_program = Pipe._trace_with_export(
             mod,
-            example_args=args_split[0],
-            example_kwargs=kwargs_split[0],
+            example_args,
+            example_kwargs,
         )

         pipe = Pipe._from_traced(
@@ -1075,10 +1048,7 @@
         pipe.pipe_info = Pipe.PipeInfo(
             graph=pipe.split_gm.graph,
             num_stages=pipe.num_stages,
-            num_chunks=num_chunks,
             has_loss_and_backward=pipe.has_loss_and_backward,
-            args_chunk_spec=Pipe.args_chunk_spec,
-            kwargs_chunk_spec=Pipe.kwargs_chunk_spec,
         )
         return pipe

@@ -1145,29 +1115,26 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):

 def pipeline(
     module: torch.nn.Module,
-    num_chunks: int,
-    example_args: Tuple[Any, ...],
-    example_kwargs: Optional[Dict[str, Any]] = None,
+    mb_args: Tuple[Any, ...],
+    mb_kwargs: Optional[Dict[str, Any]] = None,
     split_spec: Optional[Dict[str, SplitPoint]] = None,
     split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
 ) -> Pipe:
     """
-    Creates a pipeline representation for the provided module.
+    Split a module based on a specification.

     See `Pipe` for more details.

     Arguments
     ---------
     module:
-        The module to be transformed into a `Pipe`.
-    num_chunks:
-        The number of microbatches to be run with this pipeline.
-    example_args:
-        Example positional inputs to be used with this pipeline.
-    example_kwargs:
-        Example keyword inputs to be used with this pipeline. (default: `None`)
+        The module to be splitted.
+    mb_args:
+        Example positional inputs, in micro-batch form.
+    mb_kwargs:
+        Example keyword inputs, in micro-batch form. (default: `None`)
     split_spec:
-        A dictionary mapping module names to `SplitPoint`s. (default: `None`)
+        A dictionary using submodule names as split marker. (default: `None`)
     split_policy:
         The policy to use for splitting the module. (default: `None`)
@@ -1185,77 +1152,14 @@ def pipeline(
         annotate_split_points(module, split_spec)
         return Pipe.from_tracing(
             mod=module,
-            num_chunks=num_chunks,
-            example_args=example_args,
-            example_kwargs=example_kwargs,
+            example_args=mb_args,
+            example_kwargs=mb_kwargs,
         )
     else:
         # Use split policy
         return Pipe.from_tracing(
             mod=module,
-            num_chunks=num_chunks,
-            example_args=example_args,
-            example_kwargs=example_kwargs,
+            example_args=mb_args,
+            example_kwargs=mb_kwargs,
             split_policy=split_policy,
         )
-
-
-class ArgsChunkSpec:
-    """
-    Context manager for setting `args_chunk_spec` during creation of Pipe
-
-    Example:
-    >>> # xdoctest: +SKIP
-    >>> # There are three positional arguments to the model, and
-    >>> # we are chunking them along dimension 0, 0 and 1, respectively
-    >>> with ArgsChunkSpec((0, 0, 1)):
-    >>>     pipe = pipeline(model, num_chunks, example_args)
-    """
-
-    def __init__(
-        self,
-        chunk_dims: Tuple[int, ...],
-    ):
-        self.args_chunk_spec = map_aggregate(
-            chunk_dims,
-            lambda dim: TensorChunkSpec(dim),
-        )
-
-    def __enter__(self):
-        # Inject into the Pipe class
-        Pipe.args_chunk_spec = self.args_chunk_spec
-        return self.args_chunk_spec
-
-    def __exit__(self, exc_type, exc_val, traceback):
-        # Remove from the Pipe class
-        Pipe.args_chunk_spec = None
-
-
-class KwargsChunkSpec:
-    """
-    Context manager for setting `kwargs_chunk_spec` during creation of Pipe
-
-    Example:
-    >>> # xdoctest: +SKIP
-    >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
-    >>> with KwargsChunkSpec({"id": 0, "mask": 1}):
-    >>>     pipe = pipeline(model, num_chunks, (), example_kwargs)
-    """
-
-    def __init__(
-        self,
-        chunk_dims: Dict[str, int],
-    ):
-        self.kwargs_chunk_spec = map_aggregate(
-            chunk_dims,
-            lambda dim: TensorChunkSpec(dim),
-        )
-
-    def __enter__(self):
-        # Inject into the Pipe class
-        Pipe.kwargs_chunk_spec = self.kwargs_chunk_spec
-        return self.kwargs_chunk_spec
-
-    def __exit__(self, exc_type, exc_val, traceback):
-        # Remove from the Pipe class
-        Pipe.kwargs_chunk_spec = None
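Putting the new `pipeline()` signature in one place (editor's sketch; `model`, its `layers.4` submodule, and the microbatch `x_mb` are hypothetical):

    from torch.distributed.pipelining import pipeline, SplitPoint

    # No num_chunks anymore: the tracer only needs a microbatch-shaped example
    pipe = pipeline(
        model,
        mb_args=(x_mb,),
        split_spec={"layers.4": SplitPoint.BEGINNING},
    )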
@@ -1,13 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-from ._IR import (
-    annotate_split_points,
-    ArgsChunkSpec,
-    KwargsChunkSpec,
-    Pipe,
-    pipe_split,
-    pipeline,
-    SplitPoint,
-)
+from ._IR import Pipe, pipe_split, pipeline, SplitPoint
 from .PipelineSchedule import (
     Schedule1F1B,
     ScheduleGPipe,
@@ -20,10 +12,7 @@ __all__ = [
     "Pipe",
     "pipe_split",
     "SplitPoint",
-    "annotate_split_points",
     "pipeline",
-    "ArgsChunkSpec",
-    "KwargsChunkSpec",
     "TracerPipelineStage",
     "PipelineStage",
     "Schedule1F1B",
@@ -3,9 +3,16 @@ import logging
 from typing import Any, Dict, List, Optional, Tuple

 import torch
+from torch.fx.node import map_aggregate
 from torch.utils._pytree import tree_flatten, tree_unflatten


+__all__ = [
+    "TensorChunkSpec",
+    "split_args_kwargs_into_chunks",
+    "merge_chunks",
+]
+
 logger = logging.getLogger(__name__)

 """
@@ -45,8 +52,11 @@ sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)
 DEFAULT_CHUNK_DIM = 0


-# Class used to specify chunking of inputs
 class TensorChunkSpec:
+    """
+    Class used to specify chunking of inputs
+    """
+
     def __init__(self, split_dim):
         self.split_dim = split_dim
@@ -60,6 +70,43 @@ class TensorChunkSpec:
     def __str__(self):
         return f"TensorChunkSpec({self.split_dim})"

+    @staticmethod
+    def from_tuple(
+        chunk_dims: Tuple[int, ...],
+    ):
+        """
+        A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
+        dimensions (int's).
+        Example:
+        >>> # xdoctest: +SKIP
+        >>> # There are three positional arguments to the model, and
+        >>> # we are chunking them along dimension 0, 0 and 1, respectively
+        >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
+        """
+        args_chunk_spec = map_aggregate(
+            chunk_dims,
+            lambda dim: TensorChunkSpec(dim),
+        )
+        return args_chunk_spec
+
+    @staticmethod
+    def from_dict(
+        chunk_dims: Dict[str, int],
+    ):
+        """
+        A helper for creating a dictionary of `TensorChunkSpec` from a
+        dictionary of chunk dimensions (int's).
+        Example:
+        >>> # xdoctest: +SKIP
+        >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
+        >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
+        """
+        kwargs_chunk_spec = map_aggregate(
+            chunk_dims,
+            lambda dim: TensorChunkSpec(dim),
+        )
+        return kwargs_chunk_spec
+

 # Class used to specify replication of inputs
 class _Replicate: