Changed the API of `pipeline()` to take a microbatch instead of the full batch as example args. Main purposes:
- make this API more atomic;
- decouple the tracing frontend from runtime info like `num_chunks`.

Side effects:
- Creates the opportunity to vary `num_chunks` across schedules using the same `pipe` object.
- The user has to create an example microbatch input.
- Chunk spec handling is now moved entirely to the runtime side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128163
Approved by: https://github.com/H-Huang
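A minimal sketch of the new call pattern, mirroring the test below (the toy `ExampleModel`, its sizes, and `num_chunks` are illustrative assumptions, not part of this PR): the full batch is split into microbatches on the runtime side, and `pipeline()` receives a single example microbatch for tracing.

import torch
from torch.distributed.pipelining import pipeline
from torch.distributed.pipelining.microbatch import split_args_kwargs_into_chunks


class ExampleModel(torch.nn.Module):  # hypothetical toy module for illustration
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(512, 512)

    def forward(self, x):
        return self.lin(x)


mod = ExampleModel()
full_batch = torch.randn(64, 512)

# Runtime side: chunking (and chunk specs) now live here, not in the frontend.
num_chunks = 4
args_split, _ = split_args_kwargs_into_chunks((full_batch,), {}, num_chunks)

# Tracing frontend: pass one example *microbatch*, not the full batch, so the
# resulting `pipe` is independent of `num_chunks`.
pipe = pipeline(mod, mb_args=args_split[0])

The same `pipe` can then back schedules with different `num_chunks`, per the commit message above.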
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from model_registry import ModelWithKwargs

import torch
from torch.distributed.pipelining import pipeline
from torch.distributed.pipelining.microbatch import (
    merge_chunks,
    split_args_kwargs_into_chunks,
    TensorChunkSpec,
)
from torch.testing._internal.common_utils import run_tests, TestCase


d_hid = 512
torch.manual_seed(0)

class MicrobatchTests(TestCase):
    def test_split_and_merge(self):
        x0 = torch.randn(128, d_hid)
        x1 = torch.randn(256, d_hid)
        x2 = torch.randn(512, d_hid)

        args = (x0, x1, x2)
        kwargs = {"x0": x0, "x1": x1, "x2": x2}

        # Default chunking: dim 0
        arg_chunks, kwarg_chunks = split_args_kwargs_into_chunks(args, kwargs, 2)
        assert len(arg_chunks) == 2
        assert len(kwarg_chunks) == 2
        assert arg_chunks[0][0].shape == torch.Size([64, d_hid])
        assert arg_chunks[1][0].shape == torch.Size([64, d_hid])
        assert arg_chunks[0][1].shape == torch.Size([128, d_hid])
        assert arg_chunks[0][2].shape == torch.Size([256, d_hid])
        assert kwarg_chunks[0]["x0"].shape == torch.Size([64, d_hid])
        assert kwarg_chunks[0]["x1"].shape == torch.Size([128, d_hid])
        assert kwarg_chunks[1]["x2"].shape == torch.Size([256, d_hid])
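        # With no chunk specs passed, each tensor defaults to TensorChunkSpec(0),
        # i.e. it is sliced along dim 0 into `num_chunks` pieces; hence
        # 128 -> 2 x 64, 256 -> 2 x 128, and 512 -> 2 x 256 above.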

        # Merge chunks back together
        merged_args = merge_chunks(
            arg_chunks,
            (TensorChunkSpec(0), TensorChunkSpec(0), TensorChunkSpec(0)),
        )
        torch.testing.assert_close(merged_args, args)
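        # merge_chunks is the inverse of the split: it concatenates the chunks
        # back along the dims named in the spec, so split -> merge round-trips
        # to the original tensors.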

        merged_kwargs = merge_chunks(
            kwarg_chunks,
            {
                "x0": TensorChunkSpec(0),
                "x1": TensorChunkSpec(0),
                "x2": TensorChunkSpec(0),
            },
        )
        torch.testing.assert_close(merged_kwargs, kwargs)
        print("Microbatch test passed")

    def test_chunk_spec(self):
        mod = ModelWithKwargs()
        batch_size = ModelWithKwargs.DEFAULT_BATCH_SIZE

        x = torch.randn(batch_size, d_hid)
        y = torch.randn(batch_size, d_hid)

        num_chunks = 4

        args_chunk_spec = TensorChunkSpec.from_tuple((0,))
        kwargs_chunk_spec = TensorChunkSpec.from_dict({"y": 0})
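        # Each spec entry maps a positional arg (by index) or a kwarg (by name)
        # to the dim along which that tensor is split; here both x and y are
        # chunked along dim 0 (the batch dim).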

        args_split, kwargs_split = split_args_kwargs_into_chunks(
            (x,),
            {"y": y},
            num_chunks,
            args_chunk_spec,
            kwargs_chunk_spec,
        )

        pipe = pipeline(
            mod,
            mb_args=args_split[0],
            mb_kwargs=kwargs_split[0],
        )
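        # Per the new API, pipeline() traces `mod` with a single example
        # microbatch (chunk 0 of the split), so the resulting `pipe` stays
        # decoupled from `num_chunks`.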

        ref = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


if __name__ == "__main__":
    run_tests()