[pipelining] pipeline() taking microbatch as example input (#128163)

Changed the API of `pipeline()` to take a microbatch, rather than the full batch, as its example args.

The main purposes are to:
- make this API more atomic;
- decouple the tracing frontend from runtime information such as `num_chunks`.

Side effects:
- Creates the opportunity for schedules sharing the same `pipe` object to use different `num_chunks`.
- The user now has to create an example microbatch input.
- Chunk spec handling has all moved to the runtime side.
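
Illustrative before/after of the call site (a minimal sketch; `mod`, `x`, `y`, and `num_chunks` are placeholder names, not from this diff):

    # Before: the tracing frontend received the full batch plus num_chunks
    pipe = pipeline(mod, num_chunks, example_args=(x, y))

    # After: the caller derives an example microbatch and passes only that;
    # the number of chunks is chosen later, on the runtime side
    x_mb = x.chunk(num_chunks)[0]
    y_mb = y.chunk(num_chunks)[0]
    pipe = pipeline(mod, mb_args=(x_mb, y_mb))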

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128163
Approved by: https://github.com/H-Huang
Ke Wen
2024-06-07 08:42:25 -07:00
committed by PyTorch MergeBot
parent 224b4339e5
commit 3090667cf9
14 changed files with 168 additions and 251 deletions

View File

@ -317,14 +317,8 @@ The following set of APIs transform your model into a pipeline representation.
.. autoclass:: Pipe
.. autofunction:: annotate_split_points
.. autofunction:: pipe_split
.. autoclass:: ArgsChunkSpec
.. autoclass:: KwargsChunkSpec
Microbatch Utilities
====================

View File

@ -2604,12 +2604,9 @@
"TensorPipeRpcBackendOptions"
],
"torch.distributed.pipelining": [
"ArgsChunkSpec",
"KwargsChunkSpec",
"Pipe",
"PipelineStage",
"SplitPoint",
"annotate_split_points",
"pipe_split",
"pipeline"
],

View File

@ -1,72 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.pipelining import (
ArgsChunkSpec,
KwargsChunkSpec,
pipe_split,
pipeline,
)
from torch.testing._internal.common_utils import run_tests, TestCase
d_hid = 512
batch_size = 256
torch.manual_seed(0)
class ModelWithKwargs(torch.nn.Module):
def __init__(self):
super().__init__()
self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
self.lin1 = torch.nn.Linear(d_hid, d_hid)
self.lin2 = torch.nn.Linear(d_hid, d_hid)
def forward(self, x, y, z=torch.zeros(batch_size, d_hid)):
x = torch.mm(x, self.mm_param0)
x = x + y
x = torch.relu(x)
x = x + z
pipe_split()
x = torch.mm(x, self.mm_param1)
x = self.lin1(x)
pipe_split()
x = torch.relu(x)
x = torch.mm(x, self.mm_param2)
pipe_split()
x = self.lin2(x)
x = torch.relu(x)
return x
class ChunkSpecTests(TestCase):
def test_chunk_spec(self):
mod = ModelWithKwargs()
x = torch.randn(batch_size, d_hid)
y = torch.randn(batch_size, d_hid)
z = torch.randn(batch_size, d_hid)
chunks = 4
with ArgsChunkSpec((0, 0)), KwargsChunkSpec({"z": 0}):
pipe = pipeline(
mod,
chunks,
example_args=(x, y),
example_kwargs={"z": z},
)
assert pipe.num_stages == 4
ref = mod(x, y, z)
out = pipe(x, y, z)[0]
torch.testing.assert_close(out, ref)
print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
if __name__ == "__main__":
run_tests()

View File

@ -1,6 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from model_registry import ModelWithKwargs
import torch
from torch.distributed.pipelining import pipeline
from torch.distributed.pipelining.microbatch import (
merge_chunks,
split_args_kwargs_into_chunks,
@ -10,6 +13,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
d_hid = 512
torch.manual_seed(0)
class MicrobatchTests(TestCase):
@ -49,9 +53,39 @@ class MicrobatchTests(TestCase):
},
)
torch.testing.assert_close(merged_kwargs, kwargs)
print("Microbatch test passed")
def test_chunk_spec(self):
mod = ModelWithKwargs()
batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE
x = torch.randn(batch_size, d_hid)
y = torch.randn(batch_size, d_hid)
num_chunks = 4
args_chunk_spec = TensorChunkSpec.from_tuple((0,))
kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})
args_split, kwargs_split = split_args_kwargs_into_chunks(
(x,),
{"y": y},
num_chunks,
args_chunk_spec,
kwargs_chunk_spec,
)
pipe = pipeline(
mod,
mb_args=args_split[0],
mb_kwargs=kwargs_split[0],
)
ref = mod(x, y)
out = pipe(x, y)[0]
torch.testing.assert_close(out, ref)
print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
if __name__ == "__main__":
run_tests()

View File

@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import (
d_hid = 512
batch_size = 256
microbatch_size = 16
torch.manual_seed(0)
@ -81,13 +81,12 @@ class PipeTests(TestCase):
@parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias])
def test_model_split(self, ModelClass):
mod = ModelClass()
x = torch.randn(batch_size, d_hid)
y = torch.randn(batch_size, d_hid)
x = torch.randn(microbatch_size, d_hid)
y = torch.randn(microbatch_size, d_hid)
pipe = pipeline(
mod,
num_chunks=4,
example_args=(x, y),
mb_args=(x, y),
)
assert (

View File

@ -81,20 +81,22 @@ class ScheduleTest(MultiProcContinousTest):
target = torch.randn(batch_size, d_hid, device=self.device)
loss_fn = torch.nn.MSELoss(reduction="sum")
chunks = 4
x_mb = x.chunk(chunks)[0]
# Create a pipeline
split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
pipe = pipeline(
mod,
chunks,
example_args=(x,),
mb_args=(x_mb,),
split_spec=split_spec,
)
stage = TracerPipelineStage(
pipe,
self.rank,
device=self.device,
self.device,
chunks, # to be cleaned
)
# Attach to a schedule
@ -123,17 +125,20 @@ class ScheduleTest(MultiProcContinousTest):
loss_fn = torch.nn.MSELoss(reduction="sum")
chunks = 4
x_mb = x.chunk(chunks)[0]
y_mb = y.chunk(chunks)[0]
pipe = pipeline(
mod,
chunks,
example_args=(x,),
example_kwargs={"y": y},
mb_args=(x_mb,),
mb_kwargs={"y": y_mb},
)
stage = TracerPipelineStage(
pipe,
self.rank,
device=self.device,
self.device,
chunks, # to be cleaned
)
# Attach to a schedule
@ -184,18 +189,19 @@ class ScheduleTest(MultiProcContinousTest):
# Create a pipeline
chunks = 4
x_mb = x.chunk(chunks)[0]
split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
pipe = pipeline(
mod,
chunks,
example_args=(x,),
mb_args=(x_mb,),
split_spec=split_spec,
)
stage = TracerPipelineStage(
pipe,
self.rank,
device=self.device,
self.device,
chunks, # to be cleaned
)
# Attach to a schedule

View File

@ -82,19 +82,20 @@ class StageTest(MultiProcContinousTest):
mod.to(self.device)
x = torch.randn(batch_size, d_hid, device=self.device)
x_mb = x.chunk(chunks)[0]
split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
pipe = pipeline(
mod,
chunks,
example_args=(x,),
mb_args=(x_mb,),
split_spec=split_spec,
)
stage = TracerPipelineStage(
pipe,
self.rank,
device=self.device,
self.device,
chunks, # to be cleaned
)
# Attach to a schedule
@ -150,17 +151,20 @@ class StageTest(MultiProcContinousTest):
x = torch.randn(batch_size, d_hid, device=self.device)
y = torch.randn(batch_size, d_hid, device=self.device)
x_mb = x.chunk(chunks)[0]
y_mb = y.chunk(chunks)[0]
pipe = pipeline(
mod,
chunks,
example_args=(x,),
example_kwargs={"y": y},
mb_args=(x_mb,),
mb_kwargs={"y": y_mb},
)
stage = TracerPipelineStage(
pipe,
self.rank,
device=self.device,
self.device,
chunks,
)
# Attach to a schedule

View File

@ -7,7 +7,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
d_hid = 16
n_layers = 8
batch_size = 4
microbatch_size = 4
class MLPModule(torch.nn.Module):
@ -36,8 +36,7 @@ class TransformerLike(torch.nn.Module):
class TransformerTests(TestCase):
def test_ir(self):
transformer = TransformerLike()
print("Original model:\n", transformer)
x = torch.randn(batch_size, d_hid)
x = torch.randn(microbatch_size, d_hid)
# Split into 2 stages
num_stages = 2
@ -45,7 +44,6 @@ class TransformerTests(TestCase):
pipe = pipeline(
transformer,
1,
(x,),
split_spec=split_spec,
)
@ -59,19 +57,18 @@ class TransformerTests(TestCase):
layers = []
for stage_idx in range(pipe.num_stages):
stage_mod = pipe.get_stage_module(stage_idx)
print(f"\nStage {stage_idx}: \n", stage_mod)
layers += get_layers(stage_mod)
# Check layer completeness
orig_layers = get_layers(transformer)
assert sorted(layers) == sorted(orig_layers), f"{layers} != {orig_layers}"
print("Layers matched! ", layers)
print("Layers matched!")
# Check equivalence
ref = transformer(x)
out = pipe(x)[0]
torch.testing.assert_close(out, ref)
print(f"\nEquivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")
if __name__ == "__main__":

View File

@ -48,7 +48,6 @@ class UnflattenTests(TestCase):
pipe = pipeline(
mod,
1,
(x,),
{"constant": constant},
)

View File

@ -20,7 +20,7 @@ import torch
import torch.distributed as dist
from torch.profiler import record_function
from .microbatch import merge_chunks, split_args_kwargs_into_chunks
from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .PipelineStage import _PipelineStageBase
if TYPE_CHECKING:
@ -64,12 +64,24 @@ class _PipelineSchedule(ABC):
self,
n_microbatches: int,
loss_fn: Optional[Callable[..., torch.Tensor]] = None,
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
# From arguments
self._n_microbatches = n_microbatches
self._loss_fn = loss_fn
# Chunking specification for positional inputs. (default: `None`)
self._args_chunk_spec = args_chunk_spec
# Chunking specification for keyword inputs. (default: `None`)
self._kwargs_chunk_spec = kwargs_chunk_spec
self._output_merge_spec = output_merge_spec
"""
# args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
# They are used to convert batch to microbatches in `step(x)`. See
# `TensorChunkSpec` for helper methods for creating them.
"""
# Derived
self._has_backward = self._loss_fn is not None
# To be filled by subclasses
@ -201,22 +213,13 @@ class _PipelineSchedule(ABC):
Splits a full-batch input into chunks (i.e. microbatches) and returns
the chunks
"""
if self._pipe_info is not None:
# Use spec from `pipe_info`
args_chunk_spec = self._pipe_info.args_chunk_spec
kwargs_chunk_spec = self._pipe_info.kwargs_chunk_spec
else:
# Use default spec from `microbatch.py` (i.e. chunk dim 0 for each arg/kwarg)
args_chunk_spec = None
kwargs_chunk_spec = None
if args or kwargs:
args_split, kwargs_split = split_args_kwargs_into_chunks(
args,
kwargs,
self._n_microbatches,
args_chunk_spec,
kwargs_chunk_spec,
self._args_chunk_spec,
self._kwargs_chunk_spec,
)
return args_split, kwargs_split
else:
@ -285,12 +288,16 @@ class PipelineScheduleSingle(_PipelineSchedule):
stage: _PipelineStageBase,
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
# Init parent
super().__init__(
n_microbatches=n_microbatches,
loss_fn=loss_fn,
args_chunk_spec=args_chunk_spec,
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
)
self._pipe_info = (
@ -567,6 +574,8 @@ class PipelineScheduleMulti(_PipelineSchedule):
stages: List[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
if len(stages) <= 1:
@ -577,6 +586,8 @@ class PipelineScheduleMulti(_PipelineSchedule):
super().__init__(
n_microbatches=n_microbatches,
loss_fn=loss_fn,
args_chunk_spec=args_chunk_spec,
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
)
self._pipe_info = (
@ -712,6 +723,8 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
stages: List[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
self.pp_group_size = stages[0].group_size
@ -726,6 +739,8 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
args_chunk_spec=args_chunk_spec,
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
)
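
With the specs now plumbed through the schedules, chunking is configured at schedule-construction time. A sketch of the new runtime-side flow, assuming `ScheduleGPipe` forwards the new keyword args and `stage`, `loss_fn`, `x`, and `y` come from the surrounding tests:

    from torch.distributed.pipelining import ScheduleGPipe
    from torch.distributed.pipelining.microbatch import TensorChunkSpec

    schedule = ScheduleGPipe(
        stage,
        n_microbatches=4,
        loss_fn=loss_fn,
        args_chunk_spec=TensorChunkSpec.from_tuple((0,)),
        kwargs_chunk_spec=TensorChunkSpec.from_dict({"y": 0}),
    )
    # On the first-stage rank, step() takes the full batch and splits it
    # into microbatches according to the specs above
    schedule.step(x, y=y)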

View File

@ -630,6 +630,7 @@ class _PipelineStage(_PipelineStageBase):
stage_index: int,
pipe_info: Pipe.PipeInfo,
device: torch.device,
num_chunks: int,
group: Optional[dist.ProcessGroup] = None,
):
"""
@ -642,7 +643,7 @@ class _PipelineStage(_PipelineStageBase):
stage_index,
pipe_info.num_stages,
device,
pipe_info.num_chunks,
num_chunks,
group,
)
self.pipe_info = pipe_info
@ -901,6 +902,7 @@ class TracerPipelineStage(_PipelineStage):
pipe: Pipe,
stage_index: int,
device: torch.device,
num_chunks: int, # To be cleaned
group: Optional[dist.ProcessGroup] = None,
):
"""
@ -910,7 +912,9 @@ class TracerPipelineStage(_PipelineStage):
stage_module = pipe.get_stage_module(stage_index)
# Get my pipe info
pipe_info = pipe.info()
super().__init__(stage_module, stage_index, pipe_info, device, group)
super().__init__(
stage_module, stage_index, pipe_info, device, num_chunks, group
)
# Manual PipelineStage functions and definition
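
Since `num_chunks` is no longer recorded in `PipeInfo`, the stage now receives it explicitly (marked "to be cleaned" above). Per the updated tests, with `rank`, `device`, and `num_chunks` as placeholders:

    from torch.distributed.pipelining import TracerPipelineStage

    stage = TracerPipelineStage(pipe, rank, device, num_chunks)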

View File

@ -23,7 +23,6 @@ from torch.fx.passes.split_module import split_module
from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
from .microbatch import split_args_kwargs_into_chunks, TensorChunkSpec
logger = logging.getLogger(__name__)
@ -486,28 +485,11 @@ def _direct_serialization_reduce(self):
class Pipe(torch.nn.Module):
# Class variables
# args_chunk_spec and kwargs_chunk_spec are used to specify how to chunk
# inputs. They are used to create microbatched examples before tracing.
# See context managers `ArgsChunkSpec` and `KwargsChunkSpec`.
# TODO: Do we need to support `_Replicate`? It's unclear, dropping for now.
# args_chunk_spec:
# Chunking specification for positional inputs. (default: `None`)
args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None
# kwargs_chunk_spec:
# Chunking specification for keyword inputs. (default: `None`)
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None
@dataclass
class PipeInfo:
graph: fx.Graph
num_stages: int
num_chunks: int
has_loss_and_backward: bool
args_chunk_spec: Optional[Tuple[Any, ...]] = None
kwargs_chunk_spec: Optional[Dict[str, Any]] = None
def __init__(
self,
@ -1000,7 +982,6 @@ class Pipe(torch.nn.Module):
@staticmethod
def from_tracing(
mod: torch.nn.Module,
num_chunks: int,
example_args: Tuple[Any, ...],
example_kwargs: Optional[Dict[str, Any]] = None,
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
@ -1019,19 +1000,11 @@ class Pipe(torch.nn.Module):
)
"""
args_split, kwargs_split = split_args_kwargs_into_chunks(
example_args,
example_kwargs,
num_chunks,
Pipe.args_chunk_spec,
Pipe.kwargs_chunk_spec,
)
# Trace with export
exported_program = Pipe._trace_with_export(
mod,
example_args=args_split[0],
example_kwargs=kwargs_split[0],
example_args,
example_kwargs,
)
pipe = Pipe._from_traced(
@ -1075,10 +1048,7 @@ class Pipe(torch.nn.Module):
pipe.pipe_info = Pipe.PipeInfo(
graph=pipe.split_gm.graph,
num_stages=pipe.num_stages,
num_chunks=num_chunks,
has_loss_and_backward=pipe.has_loss_and_backward,
args_chunk_spec=Pipe.args_chunk_spec,
kwargs_chunk_spec=Pipe.kwargs_chunk_spec,
)
return pipe
@ -1145,29 +1115,26 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
def pipeline(
module: torch.nn.Module,
num_chunks: int,
example_args: Tuple[Any, ...],
example_kwargs: Optional[Dict[str, Any]] = None,
mb_args: Tuple[Any, ...],
mb_kwargs: Optional[Dict[str, Any]] = None,
split_spec: Optional[Dict[str, SplitPoint]] = None,
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
) -> Pipe:
"""
Creates a pipeline representation for the provided module.
Splits a module based on a specification.
See `Pipe` for more details.
Arguments
---------
module:
The module to be transformed into a `Pipe`.
num_chunks:
The number of microbatches to be run with this pipeline.
example_args:
Example positional inputs to be used with this pipeline.
example_kwargs:
Example keyword inputs to be used with this pipeline. (default: `None`)
The module to be split.
mb_args:
Example positional inputs, in micro-batch form.
mb_kwargs:
Example keyword inputs, in micro-batch form. (default: `None`)
split_spec:
A dictionary mapping module names to `SplitPoint`s. (default: `None`)
A dictionary using submodule names as split markers. (default: `None`)
split_policy:
The policy to use for splitting the module. (default: `None`)
@ -1185,77 +1152,14 @@ def pipeline(
annotate_split_points(module, split_spec)
return Pipe.from_tracing(
mod=module,
num_chunks=num_chunks,
example_args=example_args,
example_kwargs=example_kwargs,
example_args=mb_args,
example_kwargs=mb_kwargs,
)
else:
# Use split policy
return Pipe.from_tracing(
mod=module,
num_chunks=num_chunks,
example_args=example_args,
example_kwargs=example_kwargs,
example_args=mb_args,
example_kwargs=mb_kwargs,
split_policy=split_policy,
)
class ArgsChunkSpec:
"""
Context manager for setting `args_chunk_spec` during creation of Pipe
Example:
>>> # xdoctest: +SKIP
>>> # There are three positional arguments to the model, and
>>> # we are chunking them along dimension 0, 0 and 1, respectively
>>> with ArgsChunkSpec((0, 0, 1)):
>>> pipe = pipeline(model, num_chunks, example_args)
"""
def __init__(
self,
chunk_dims: Tuple[int, ...],
):
self.args_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim),
)
def __enter__(self):
# Inject into the Pipe class
Pipe.args_chunk_spec = self.args_chunk_spec
return self.args_chunk_spec
def __exit__(self, exc_type, exc_val, traceback):
# Remove from the Pipe class
Pipe.args_chunk_spec = None
class KwargsChunkSpec:
"""
Context manager for setting `kwargs_chunk_spec` during creation of Pipe
Example:
>>> # xdoctest: +SKIP
>>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
>>> with KwargsChunkSpec({"id": 0, "mask": 1}):
>>> pipe = pipeline(model, num_chunks, (), example_kwargs)
"""
def __init__(
self,
chunk_dims: Dict[str, int],
):
self.kwargs_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim),
)
def __enter__(self):
# Inject into the Pipe class
Pipe.kwargs_chunk_spec = self.kwargs_chunk_spec
return self.kwargs_chunk_spec
def __exit__(self, exc_type, exc_val, traceback):
# Remove from the Pipe class
Pipe.kwargs_chunk_spec = None
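
A migration sketch for code that used the removed context managers, following the flow in the updated microbatch test (`mod`, `x`, `y`, `z`, and `num_chunks` are illustrative):

    from torch.distributed.pipelining import pipeline
    from torch.distributed.pipelining.microbatch import (
        split_args_kwargs_into_chunks,
        TensorChunkSpec,
    )

    # Old: with ArgsChunkSpec((0, 0)), KwargsChunkSpec({"z": 0}):
    #          pipe = pipeline(mod, num_chunks, (x, y), {"z": z})

    # New: build the specs with the static helpers, split on the runtime
    # side, and trace with the first microbatch only
    args_chunk_spec = TensorChunkSpec.from_tuple((0, 0))
    kwargs_chunk_spec = TensorChunkSpec.from_dict({"z": 0})
    args_split, kwargs_split = split_args_kwargs_into_chunks(
        (x, y), {"z": z}, num_chunks, args_chunk_spec, kwargs_chunk_spec
    )
    pipe = pipeline(mod, mb_args=args_split[0], mb_kwargs=kwargs_split[0])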

View File

@ -1,13 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from ._IR import (
annotate_split_points,
ArgsChunkSpec,
KwargsChunkSpec,
Pipe,
pipe_split,
pipeline,
SplitPoint,
)
from ._IR import Pipe, pipe_split, pipeline, SplitPoint
from .PipelineSchedule import (
Schedule1F1B,
ScheduleGPipe,
@ -20,10 +12,7 @@ __all__ = [
"Pipe",
"pipe_split",
"SplitPoint",
"annotate_split_points",
"pipeline",
"ArgsChunkSpec",
"KwargsChunkSpec",
"TracerPipelineStage",
"PipelineStage",
"Schedule1F1B",

View File

@ -3,9 +3,16 @@ import logging
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.fx.node import map_aggregate
from torch.utils._pytree import tree_flatten, tree_unflatten
__all__ = [
"TensorChunkSpec",
"split_args_kwargs_into_chunks",
"merge_chunks",
]
logger = logging.getLogger(__name__)
"""
@ -45,8 +52,11 @@ sum_reducer = _LossReducer(torch.tensor(0.0), lambda a, b: a + b)
DEFAULT_CHUNK_DIM = 0
class TensorChunkSpec:
"""
Class used to specify chunking of inputs
"""
def __init__(self, split_dim):
self.split_dim = split_dim
@ -60,6 +70,43 @@ class TensorChunkSpec:
def __str__(self):
return f"TensorChunkSpec({self.split_dim})"
@staticmethod
def from_tuple(
chunk_dims: Tuple[int, ...],
):
"""
A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
dimensions (ints).
Example:
>>> # xdoctest: +SKIP
>>> # There are three positional arguments to the model, and
>>> # we are chunking them along dimension 0, 0 and 1, respectively
>>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
"""
args_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim),
)
return args_chunk_spec
@staticmethod
def from_dict(
chunk_dims: Dict[str, int],
):
"""
A helper for creating a dictionary of `TensorChunkSpec` from a
dictionary of chunk dimensions (ints).
Example:
>>> # xdoctest: +SKIP
>>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
>>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
"""
kwargs_chunk_spec = map_aggregate(
chunk_dims,
lambda dim: TensorChunkSpec(dim),
)
return kwargs_chunk_spec
# Class used to specify replication of inputs
class _Replicate: