Pull in fairscale.nn.Pipe into PyTorch. (#44090)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44090

This is an initial commit pulling in the torchgpipe fork at
https://github.com/facebookresearch/fairscale.

The purpose of this commit is to just pull in the code and ensure all tests and
builds work fine. We will gradually modify it to match our intended API
mentioned in https://fb.quip.com/txurAV3zIFox#RPZACAfAKMq. Follow-up PRs will
address further changes needed on top of this initial commit.

We're pulling the code into the `torch.distributed._pipeline.sync` package. The
package is intentionally private since a lot of work (e.g., docs, API changes)
still needs to happen before we can officially support this.
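
For reference, here is a minimal sketch of the torchgpipe-style API as exercised
by the tests added in this commit; the module path and keyword arguments mirror
the test code in this diff, while the model and shapes are made up for
illustration, and the API may change before it is officially supported.

import torch
from torch import nn
from torch.distributed._pipeline.sync import Pipe

# Split a two-layer model into two partitions, one layer each.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 1))
model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=4)

# The input is scattered into `chunks` micro-batches, pipelined across
# the partitions, and gathered back into a single output batch.
x = torch.rand(8, 16)
y = model(x)
y.sum().backward()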
ghstack-source-id: 114864254

Test Plan:
1) waitforbuildbot
2) Ran all tests on my devgpu

Reviewed By: mrshenli

Differential Revision: D23493316

fbshipit-source-id: fe3c8b7dadeeb86abdc00e8a8652491b0b16743a
Author: Pritam Damania
Date: 2020-10-22 10:53:07 -07:00
Committed by: Facebook GitHub Bot
Parent: b63ddd6f57
Commit: 06d50b5eb0

53 changed files with 6532 additions and 7 deletions

LICENSE

@@ -16,23 +16,26 @@ Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further

NOTICE

@@ -22,6 +22,9 @@ All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.

@@ -0,0 +1,27 @@
Copyright 2019-2020 Kakao Brain
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

@@ -0,0 +1,8 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# tests/__init__.py makes it possible for pytest to import the application without a custom sys.path or PYTHONPATH.
# See also: https://docs.pytest.org/en/latest/goodpractices.html

@@ -0,0 +1,37 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
@pytest.fixture(autouse=True)
def manual_seed_zero():
torch.manual_seed(0)
@pytest.fixture(scope="session")
def cuda_sleep():
# Warm-up CUDA.
torch.empty(1, device="cuda")
# From test/test_cuda.py in PyTorch.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1000000)
end.record()
end.synchronize()
cycles_per_ms = 1000000 / start.elapsed_time(end)
def cuda_sleep(seconds):
torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
return cuda_sleep
def pytest_report_header():
return f"torch: {torch.__version__}"

@@ -0,0 +1,6 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

@@ -0,0 +1,45 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import copy
from torch import nn
from torch.distributed._pipeline.sync.skip import Namespace, skippable, stash
def test_namespace_difference():
ns1 = Namespace()
ns2 = Namespace()
assert ns1 != ns2
def test_namespace_copy():
ns = Namespace()
assert copy.copy(ns) == ns
assert copy.copy(ns) is not ns
def test_skippable_repr():
@skippable(stash=["hello"])
class Hello(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 1)
def forward(self, x):
yield stash("hello", x)
return self.conv(x) # noqa
m = Hello()
assert (
repr(m)
== """
@skippable(Hello(
(conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
))
""".strip()
)

@@ -0,0 +1,106 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch import nn
from torch.distributed._pipeline.sync import Pipe
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_1to3(balance, checkpoint):
if torch.cuda.device_count() < len(balance):
pytest.skip("at least %d cuda devices required" % len(balance))
@skippable(stash=["1to3"])
class Layer1(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
yield stash("1to3", input)
output = self.conv(input)
return output # noqa
class Layer2(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
output = self.conv(input)
return output
@skippable(pop=["1to3"])
class Layer3(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 3, 1)
def forward(self, input):
skip_1to3 = yield pop("1to3")
output = self.conv(input) + skip_1to3
return output
model = nn.Sequential(Layer1(), Layer2(), Layer3())
model = Pipe(model, balance, chunks=3, checkpoint=checkpoint)
in_device = model.devices[0]
out_device = model.devices[-1]
input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
output = model(input)
loss = output.mean()
loss.backward()
assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1)
assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device))
def test_none_skip():
@skippable(stash=["none"])
class Stash(nn.Module):
def forward(self, input):
yield stash("none", None)
return input # noqa
@skippable(pop=["none"])
class Pop(nn.Module):
def forward(self, input):
none = yield pop("none")
assert none is None
return input
model = nn.Sequential(Stash(), Pop())
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=5)
input = torch.rand(10, requires_grad=True)
output = model(input)
def assert_grad_fn_is_not_portal(grad_fn, visited=None):
if visited is None:
visited = set()
if grad_fn in visited or grad_fn is None:
return
assert not isinstance(grad_fn, PortalBlue._backward_cls)
assert not isinstance(grad_fn, PortalCopy._backward_cls)
assert not isinstance(grad_fn, PortalOrange._backward_cls)
visited.add(grad_fn)
for next_grad_fn, _ in grad_fn.next_functions:
assert_grad_fn_is_not_portal(next_grad_fn, visited)
assert_grad_fn_is_not_portal(output.grad_fn)
output.sum().backward()
assert input.grad.mean().item() == 1

@@ -0,0 +1,111 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn
from torch.distributed._pipeline.sync.skip import Namespace, pop, skippable, stash
from torch.distributed._pipeline.sync.skip.layout import inspect_skip_layout
class Pass(nn.Module):
def forward(self, input):
return input
@skippable(stash=["foo"])
class StashFoo(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input # noqa
@skippable(pop=["foo"])
class PopFoo(nn.Module):
def forward(self, input):
foo = yield pop("foo")
return input + foo
@skippable(stash=["bar"])
class StashBar(nn.Module):
def forward(self, input):
yield stash("bar", input)
return input # noqa
@skippable(pop=["bar"])
class PopBar(nn.Module):
def forward(self, input):
bar = yield pop("bar")
return input + bar
def test_no_skippables():
p1 = nn.Sequential(Pass())
p2 = nn.Sequential(Pass())
layout = inspect_skip_layout([p1, p2])
policy = [list(layout.copy_policy(i)) for i in range(2)]
assert policy == [[], []]
def test_inner_partition():
p1 = nn.Sequential(StashFoo(), PopFoo())
p2 = nn.Sequential(Pass())
layout = inspect_skip_layout([p1, p2])
policy = [list(layout.copy_policy(i)) for i in range(2)]
assert policy == [[], []]
def test_adjoining_partitions():
p1 = nn.Sequential(StashFoo())
p2 = nn.Sequential(PopFoo())
layout = inspect_skip_layout([p1, p2])
policy = [list(layout.copy_policy(i)) for i in range(2)]
assert policy == [[], [(0, None, "foo")]]
def test_far_partitions():
p1 = nn.Sequential(StashFoo())
p2 = nn.Sequential(Pass())
p3 = nn.Sequential(PopFoo())
layout = inspect_skip_layout([p1, p2, p3])
policy = [list(layout.copy_policy(i)) for i in range(3)]
assert policy == [[], [], [(0, None, "foo")]]
def test_pop_2_from_different_partitions():
p1 = nn.Sequential(StashFoo())
p2 = nn.Sequential(StashBar())
p3 = nn.Sequential(PopBar(), PopFoo())
layout = inspect_skip_layout([p1, p2, p3])
policy = [list(layout.copy_policy(i)) for i in range(3)]
# p3 pops 'bar' before 'foo', but the plan is sorted by source partition index.
assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]]
def test_namespace():
ns1 = Namespace()
ns2 = Namespace()
p1 = nn.Sequential(StashFoo().isolate(ns1))
p2 = nn.Sequential(StashFoo().isolate(ns2))
p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))
layout = inspect_skip_layout([p1, p2, p3])
policy = [list(layout.copy_policy(i)) for i in range(3)]
# p3 pops ns2's 'foo' before ns1's 'foo', but the plan is sorted by source partition index.
assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]]

@@ -0,0 +1,126 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch import nn
from torch.distributed._pipeline.sync import Pipe, is_checkpointing, is_recomputing
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
from torch.distributed._pipeline.sync.skip.tracker import current_skip_tracker
@skippable(stash=["skip"])
class Stash(nn.Module):
def forward(self, input):
yield stash("skip", input)
return input # noqa
@skippable(pop=["skip"])
class Pop(nn.Module):
def forward(self, input):
skip = yield pop("skip")
return input + skip
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
def test_delete_portal_tensor(train, checkpoint):
# Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
# +----------+ +------------+
#
# With checkpointing:
# +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
def portal_tensor_life_is(tensor_life, skip_tracker=None):
if skip_tracker is None:
skip_tracker = current_skip_tracker()
# Get the current portal.
portal = list(skip_tracker.portals.values())[0]
if tensor_life == 0:
return portal.tensor_life == 0 and portal.tensor is None
else:
return portal.tensor_life == tensor_life and portal.tensor is not None
# Check the portal tensor after 'Stash'.
stash_ = Stash()
@stash_.register_forward_hook
def check_portal_tensor_after_stash(*_):
if is_checkpointing():
assert portal_tensor_life_is(2)
elif is_recomputing():
assert portal_tensor_life_is(0)
else:
assert portal_tensor_life_is(1)
pop_ = Pop()
@pop_.register_forward_hook
def check_portal_tensor_after_pop(*_):
if is_checkpointing():
assert portal_tensor_life_is(1)
elif is_recomputing():
assert portal_tensor_life_is(0)
else:
assert portal_tensor_life_is(0)
class NoPortalTensorAtBackward(nn.Module):
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.skip_tracker = current_skip_tracker()
return input.detach()
@staticmethod
def backward(ctx, grad):
assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
return grad
def forward(self, input):
return self.F.apply(input)
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
model = Pipe(model, balance=[2, 1], devices=["cpu", "cpu"], chunks=2, checkpoint=checkpoint)
input = torch.rand(10, requires_grad=True)
if train:
model.train()
output = model(input)
output.norm().backward()
else:
model.eval()
with torch.no_grad():
model(input)
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
def test_no_portal_without_pipe(train, monkeypatch):
def deny(*args, **kwargs):
raise AssertionError("tried to create Portal without Pipe")
monkeypatch.setattr("torch.distributed._pipeline.sync.skip.portal.Portal.__init__", deny)
model = nn.Sequential(Stash(), Pop())
input = torch.rand(10, requires_grad=True)
if train:
model.train()
output = model(input)
output.norm().backward()
else:
model.eval()
with torch.no_grad():
model(input)

@@ -0,0 +1,155 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch.distributed._pipeline.sync.dependency import fork, join
from torch.distributed._pipeline.sync.skip.portal import Portal
from torch.distributed._pipeline.sync.stream import default_stream
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_copy_returns_on_next_device():
portal = Portal(torch.rand(1), tensor_life=1)
prev_stream = default_stream(torch.device("cpu"))
next_stream = default_stream(torch.device("cuda"))
phony = torch.zeros(0, requires_grad=True)
assert phony.device.type == "cpu"
phony = portal.copy(prev_stream, next_stream, phony)
assert phony.device.type == "cuda"
def test_blue_orange():
tensor1 = torch.rand(1, requires_grad=True)
tensor2 = torch.rand(1, requires_grad=True)
# Same as: output = tensor1 * 2 + tensor2
#
# +----------------------+
# | |
# tensor2 -- PortalBlue -+ +- PortalOrange -+
# | | |
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output
#
main = tensor1
portal = Portal(tensor2, tensor_life=2)
phony = portal.blue()
main = join(main, phony)
main, phony = fork(main)
sub = portal.orange(phony)
output = main * 2 + sub
output.backward()
assert torch.allclose(tensor1.grad, torch.tensor([2.0]))
assert torch.allclose(tensor2.grad, torch.tensor([1.0]))
def test_blue_orange_not_requires_grad():
tensor1 = torch.rand(1, requires_grad=True)
tensor2 = torch.rand(1)
# Same as: output = tensor1 * 2 + tensor2
#
# +----------------------+
# | |
# tensor2 -- PortalBlue -+ +- PortalOrange -+
# | | |
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output
#
main = tensor1
portal = Portal(tensor2, tensor_life=2)
phony = portal.blue()
main = join(main, phony)
main, phony = fork(main)
sub = portal.orange(phony)
output = main * 2 + sub
output.backward()
assert torch.allclose(tensor1.grad, torch.tensor([2.0]))
assert tensor2.grad is None
def test_use_grad():
tensor = torch.rand(1, requires_grad=True)
portal = Portal(tensor, tensor_life=1)
portal.put_grad(tensor)
assert portal.use_grad() is tensor
# Gradient in a portal is ephemeral.
with pytest.raises(RuntimeError):
portal.use_grad()
class TestTensorLife:
@pytest.fixture
def new_portal(self):
portal = None
def new_portal(tensor_life):
nonlocal portal
tensor = torch.rand(1, requires_grad=True)
portal = Portal(tensor, tensor_life)
return portal, tensor
yield new_portal
# A test using this fixture must exhaust the tensor in the portal.
with pytest.raises(RuntimeError):
portal.check_tensor_life()
assert portal.tensor is None
def test_tensor_life_0(self, new_portal):
portal, tensor = new_portal(0)
assert portal.tensor is None
def test_tensor_life_1(self, new_portal):
portal, tensor = new_portal(1)
assert portal.tensor is tensor
portal.blue()
def test_tensor_life_2(self, new_portal):
portal, tensor = new_portal(2)
assert portal.tensor is tensor
phony = portal.blue()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
def test_tensor_life_3(self, new_portal):
portal, tensor = new_portal(3)
assert portal.tensor is tensor
phony = portal.blue()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
def test_tensor_life_4(self, new_portal):
portal, tensor = new_portal(4)
assert portal.tensor is tensor
phony = portal.blue()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
portal.blue()
def test_tensor_life_3_plus_1(self, new_portal):
portal, tensor = new_portal(3)
assert portal.tensor is tensor
phony = portal.blue()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
another_tensor = torch.rand(1, requires_grad=True)
portal.put_tensor(another_tensor, tensor_life=1)
portal.blue()

@@ -0,0 +1,136 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch import nn
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
from torch.distributed._pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker
@pytest.fixture(autouse=True)
def skip_tracker():
skip_tracker = SkipTracker()
with use_skip_tracker(skip_tracker):
yield skip_tracker
def test_stash(skip_tracker):
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2 # noqa
l1 = Stash()
assert len(skip_tracker.tensors) == 0
with use_skip_tracker(skip_tracker):
l1(torch.tensor(42))
assert len(skip_tracker.tensors) == 1
def test_pop():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2 # noqa
@skippable(pop=["foo"])
class Pop(nn.Module):
def forward(self, input):
foo = yield pop("foo")
return foo # noqa
l1 = Stash()
l2 = Pop()
output = l2(l1(torch.tensor(42)))
assert output.item() == 42
def test_declare_but_not_use():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
return input * 2
@skippable(pop=["foo"])
class Pop(nn.Module):
def forward(self, input):
return input * 3
l1 = Stash()
l2 = Pop()
with pytest.raises(RuntimeError):
l1(torch.tensor(42))
with pytest.raises(RuntimeError):
l2(torch.tensor(42))
def test_stash_not_declared():
@skippable()
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2 # noqa
l1 = Stash()
with pytest.raises(RuntimeError):
l1(torch.tensor(42))
def test_pop_not_declared():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2 # noqa
@skippable()
class Pop(nn.Module):
def forward(self, input):
foo = yield pop("foo")
return foo # noqa
l1 = Stash()
l2 = Pop()
latent = l1(torch.tensor(42))
with pytest.raises(RuntimeError):
l2(latent)
def test_pop_not_stashed():
@skippable(pop=["foo"])
class Pop(nn.Module):
def forward(self, input):
yield pop("foo")
l1 = Pop()
with pytest.raises(RuntimeError):
l1(torch.tensor(42))
def test_stash_none():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", None)
return input * 2 # noqa
l1 = Stash()
l1(torch.tensor(42))

@@ -0,0 +1,127 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from queue import Queue
import threading
import pytest
import torch
from torch import nn
from torch.distributed._pipeline.sync.checkpoint import enable_checkpointing, enable_recomputing
from torch.distributed._pipeline.sync.microbatch import Batch
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
from torch.distributed._pipeline.sync.skip.layout import SkipLayout
from torch.distributed._pipeline.sync.skip.tracker import SkipTracker, SkipTrackerThroughPotals, current_skip_tracker
def test_default_skip_tracker():
q = Queue()
def f():
q.put(current_skip_tracker())
t = threading.Thread(target=f)
t.start()
t.join()
skip_tracker = q.get()
assert type(skip_tracker) is SkipTracker
assert type(skip_tracker) is not SkipTrackerThroughPotals
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_default_skip_tracker_by_data_parallel():
@skippable(stash=["foo"])
class Stash(nn.Module):
def forward(self, input):
yield stash("foo", input)
return input * 2 # noqa
@skippable(pop=["foo"])
class Pop(nn.Module):
def forward(self, input):
foo = yield pop("foo")
return foo
model = nn.Sequential(Stash(), Pop())
model = nn.DataParallel(model, device_ids=[0, 0], output_device=0)
input = torch.rand(10, device=0)
output = model(input)
assert torch.allclose(output, input)
def test_reuse_portal():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
batch = Batch(torch.tensor([1.0]))
a = torch.tensor([2.0])
b = torch.tensor([2.0])
skip_tracker.save(batch, None, "test", a)
portal = skip_tracker.portals[(None, "test")]
skip_tracker.save(batch, None, "test", b)
assert portal is skip_tracker.portals[(None, "test")]
def test_no_copy_no_portal():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
batch = Batch(torch.tensor([1.0]))
a = torch.tensor([2.0])
b = torch.tensor([2.0])
skip_tracker.save(batch, None, "copy", a)
skip_tracker.save(batch, None, "not_copy", b)
assert (None, "copy") in skip_tracker.portals
assert (None, "copy") not in skip_tracker.tensors
assert (None, "not_copy") in skip_tracker.tensors
assert (None, "not_copy") not in skip_tracker.portals
def test_tensor_life_without_checkpointing():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
batch = Batch(torch.tensor([1.0]))
tensor = torch.tensor([2.0])
skip_tracker.save(batch, None, "test", tensor)
assert skip_tracker.portals[(None, "test")].tensor_life == 1
skip_tracker.load(batch, None, "test")
assert skip_tracker.portals[(None, "test")].tensor_life == 0
def test_tensor_life_with_checkpointing():
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
skip_tracker = SkipTrackerThroughPotals(skip_layout)
batch = Batch(torch.tensor([1.0]))
tensor = torch.tensor([2.0])
with enable_checkpointing():
skip_tracker.save(batch, None, "test", tensor)
assert skip_tracker.portals[(None, "test")].tensor_life == 2
with enable_checkpointing():
skip_tracker.load(batch, None, "test")
assert skip_tracker.portals[(None, "test")].tensor_life == 1
with enable_recomputing():
skip_tracker.load(batch, None, "test")
assert skip_tracker.portals[(None, "test")].tensor_life == 0
with enable_recomputing():
skip_tracker.save(batch, None, "test", tensor)
assert skip_tracker.portals[(None, "test")].tensor_life == 0

@@ -0,0 +1,152 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
from torch import nn
from torch.distributed._pipeline.sync.skip import Namespace, skippable, verify_skippables
def test_matching():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"])
class Layer2(nn.Module):
pass
verify_skippables(nn.Sequential(Layer1(), Layer2()))
def test_stash_not_pop():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1()))
assert "no module declared 'foo' as poppable but stashed" in str(e.value)
def test_pop_unknown():
@skippable(pop=["foo"])
class Layer1(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1()))
assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value)
def test_stash_again():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(stash=["foo"])
class Layer2(nn.Module):
pass
@skippable(pop=["foo"])
class Layer3(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
assert "'1' redeclared 'foo' as stashable" in str(e.value)
def test_pop_again():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"])
class Layer2(nn.Module):
pass
@skippable(pop=["foo"])
class Layer3(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
assert "'2' redeclared 'foo' as poppable" in str(e.value)
def test_stash_pop_together_different_names():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"], stash=["bar"])
class Layer2(nn.Module):
pass
@skippable(pop=["bar"])
class Layer3(nn.Module):
pass
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
def test_stash_pop_together_same_name():
@skippable(stash=["foo"], pop=["foo"])
class Layer1(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1()))
assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value)
def test_double_stash_pop():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"])
class Layer2(nn.Module):
pass
@skippable(stash=["foo"])
class Layer3(nn.Module):
pass
@skippable(pop=["foo"])
class Layer4(nn.Module):
pass
with pytest.raises(TypeError) as e:
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4()))
assert "'2' redeclared 'foo' as stashable" in str(e.value)
assert "'3' redeclared 'foo' as poppable" in str(e.value)
def test_double_stash_pop_but_isolated():
@skippable(stash=["foo"])
class Layer1(nn.Module):
pass
@skippable(pop=["foo"])
class Layer2(nn.Module):
pass
@skippable(stash=["foo"])
class Layer3(nn.Module):
pass
@skippable(pop=["foo"])
class Layer4(nn.Module):
pass
ns1 = Namespace()
ns2 = Namespace()
verify_skippables(
nn.Sequential(Layer1().isolate(ns1), Layer2().isolate(ns1), Layer3().isolate(ns2), Layer4().isolate(ns2),)
)

@@ -0,0 +1,225 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import time
import pytest
import torch
from torch import nn
from torch.distributed._pipeline.sync.balance import balance_by_size, balance_by_time, blockpartition
from torch.distributed._pipeline.sync.balance.profile import layerwise_sandbox
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
def test_blockpartition():
assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [[1, 2, 3, 4], [5, 6]]
def test_blockpartition_zeros():
assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]]
def test_blockpartition_non_positive_partitions():
with pytest.raises(ValueError):
blockpartition.solve([42], partitions=0)
with pytest.raises(ValueError):
blockpartition.solve([42], partitions=-1)
def test_blockpartition_short_sequence():
with pytest.raises(ValueError):
blockpartition.solve([], partitions=1)
with pytest.raises(ValueError):
blockpartition.solve([42], partitions=2)
@pytest.mark.parametrize("device", devices)
@pytest.mark.skip(reason="Flaky due to time.sleep()")
def test_balance_by_time(device):
class Delay(nn.Module):
def __init__(self, seconds):
super().__init__()
self.seconds = seconds
def forward(self, x):
time.sleep(self.seconds)
return x
model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]])
sample = torch.rand(1)
balance = balance_by_time(2, model, sample, device=device)
assert balance == [4, 2]
def test_balance_by_time_loop_resets_input():
# nn.Flatten was introduced in PyTorch 1.2.0.
class Flatten(nn.Module):
def forward(self, x):
return x.flatten(1)
model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10))
sample = torch.rand(10, 3, 8, 8)
balance = balance_by_time(2, model, sample, device="cpu")
assert balance == [1, 2]
@skip_if_no_cuda
def test_balance_by_size_latent():
class Expand(nn.Module):
def __init__(self, times):
super().__init__()
self.times = times
def forward(self, x):
for i in range(self.times):
x = x + torch.rand_like(x, requires_grad=True)
return x
sample = torch.rand(10, 100, 100)
model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]])
balance = balance_by_size(2, model, sample)
assert balance == [4, 2]
model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]])
balance = balance_by_size(2, model, sample)
assert balance == [2, 4]
@skip_if_no_cuda
def test_balance_by_size_param():
model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)])
sample = torch.rand(7, 1)
balance = balance_by_size(2, model, sample, param_scale=100)
assert balance == [4, 2]
model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))])
sample = torch.rand(1, 7)
balance = balance_by_size(2, model, sample, param_scale=100)
assert balance == [2, 4]
@skip_if_no_cuda
def test_balance_by_size_param_scale():
class Tradeoff(nn.Module):
def __init__(self, param_size, latent_size):
super().__init__()
self.fc = nn.Linear(param_size, param_size)
self.latent_size = latent_size
def forward(self, x):
for i in range(self.latent_size):
x = x + torch.rand_like(x, requires_grad=True)
return x
model = nn.Sequential(
Tradeoff(param_size=1, latent_size=6),
Tradeoff(param_size=2, latent_size=5),
Tradeoff(param_size=3, latent_size=4),
Tradeoff(param_size=4, latent_size=3),
Tradeoff(param_size=5, latent_size=2),
Tradeoff(param_size=6, latent_size=1),
)
sample = torch.rand(1, requires_grad=True)
balance = balance_by_size(2, model, sample, param_scale=0)
assert balance == [2, 4]
balance = balance_by_size(2, model, sample, param_scale=100)
assert balance == [4, 2]
@pytest.mark.parametrize("device", devices)
def test_layerwise_sandbox(device):
model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
model.eval()
for layer in layerwise_sandbox(model, torch.device(device)):
assert layer.training
assert all(p.device.type == device for p in layer.parameters())
assert all(not l.training for l in model)
assert all(p.device.type == "cpu" for p in model.parameters())
@pytest.mark.parametrize("device", devices)
def test_sandbox_during_profiling(device):
model = nn.Sequential(nn.BatchNorm2d(3))
before = {k: v.clone() for k, v in model.state_dict().items()}
sample = torch.rand(1, 3, 10, 10)
balance_by_time(1, model, sample, device=device)
after = model.state_dict()
assert before.keys() == after.keys()
for key, value in before.items():
assert torch.allclose(after[key], value), key
def test_not_training():
class AssertTraining(nn.Module):
def forward(self, x):
assert self.training
return x
model = nn.Sequential(AssertTraining())
model.eval()
assert not model.training
sample = torch.rand(1)
balance_by_time(1, model, sample, device="cpu")
assert not model.training
def test_balance_by_time_tuple():
class Twin(nn.Module):
def forward(self, x):
return x, x.detach()
class Add(nn.Module):
def forward(self, a_b):
a, b = a_b
return a + b
model = nn.Sequential(Twin(), Add())
sample = torch.rand(1, requires_grad=True)
balance_by_time(1, model, sample, device="cpu")
@skip_if_no_cuda
def test_balance_by_size_tuple():
class Twin(nn.Module):
def forward(self, x):
return x, x.detach()
class Add(nn.Module):
def forward(self, a_b):
a, b = a_b
return a + b
model = nn.Sequential(Twin(), Add())
sample = torch.rand(1, requires_grad=True)
balance_by_size(1, model, sample)
def test_already_has_grad():
model = nn.Sequential(nn.Conv2d(3, 3, 1))
sample = torch.rand(1, 3, 32, 32)
model(sample).norm().backward()
with pytest.raises(ValueError, match="some parameter already has gradient"):
balance_by_time(1, model, sample, device="cpu")

@@ -0,0 +1,128 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributed._pipeline.sync import Pipe
def test_python_autograd_function():
# A Python autograd function might fail with this error:
#
# RuntimeError: Returning Variables sharing storage with other Variables
# that require grad is not supported in Python functions. Please submit a
# feature request if you hit this error.
#
# It doesn't look like an essential restriction, but it happens on the
# current PyTorch version. To avoid it, identity autograd functions such as
# Wait, Fork, and Join should detach the tensor before returning it.
#
class Identity(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input
@staticmethod
def backward(ctx, grad):
return grad
class M(nn.Module):
def forward(self, input):
return Identity.apply(input)
model = nn.Sequential(M(), M())
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always")
x = torch.rand(42)
y = model(x)
assert torch.allclose(x, y)
def test_exception_no_hang():
# In v0.0.2, once a failed partition received a normal (non-closing)
# message for the next micro-batch, a hang occurred. The reason was that a
# failed partition didn't call in_queue.task_done() on a normal message, so
# the preceding partition stayed blocked at out_queue.join() for the
# micro-batch after next.
class ExpectedException(Exception):
pass
class Pass(nn.Module):
def forward(self, x):
return x
class Raise(nn.Module):
def forward(self, x):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Raise())
model = Pipe(model, [1, 1, 1], devices=["cpu", "cpu", "cpu"], chunks=3)
with pytest.raises(ExpectedException):
model(torch.rand(3))
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
def test_tuple_wait(cuda_sleep):
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
# Under this behavior, if checkpointing was disabled, there's a possibility
# that gradient accumulations on other tensors are not synchronized
# properly to the copy stream.
class Sleep(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.detach()
@staticmethod
def backward(ctx, grad):
with torch.cuda.device(grad.device):
cuda_sleep(0.05)
return grad
class Layer1(nn.Module):
def forward(self, pair):
a, b = pair
return a * 1, b * 2, b * 3
class Layer2(nn.Module):
def forward(self, triple):
a, b, c = triple
b = Sleep.apply(b)
return a + b + c
model = nn.Sequential(Layer1(), Layer2())
model = Pipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint="never")
a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
y = model((a, b))
y.norm().backward()
torch.cuda.synchronize(0)
torch.cuda.synchronize(1)
assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))
def test_parallel_randoms():
class Dropouts(nn.Module):
def forward(self, x):
for _ in range(100):
x = F.dropout(x, p=0.001)
return x
model = nn.Sequential(Dropouts(), Dropouts())
x = torch.rand(10, 10, requires_grad=True)
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=10, checkpoint="always")
y = model(x)
y.norm().backward()
assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist()

@@ -0,0 +1,158 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import pytest
import torch
from torch import nn
import torch.cuda
from torch.distributed._pipeline.sync.checkpoint import Checkpointing, checkpoint, is_checkpointing, is_recomputing
from torch.distributed._pipeline.sync.dependency import fork, join
from torch.distributed._pipeline.sync.microbatch import Batch
devices = ["cpu"]
if torch.cuda.is_available():
devices.append("cuda")
@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
# Copied from https://github.com/pytorch/pytorch/pull/18568.
timeline = []
class Log(torch.autograd.Function):
@staticmethod
def forward(ctx, name, x):
ctx.name = name
timeline.append(f"{name}:forward")
return x.detach()
@staticmethod
def backward(ctx, grad_output):
name = ctx.name
timeline.append(f"{name}:backward")
return None, grad_output
a = torch.rand(1, device=device, requires_grad=True)
b = torch.rand(1, device=device, requires_grad=True)
# Increase the next function sequence number.
_ = a + 1 + 2 + 3 + 4 + 5
a = checkpoint(partial(Log.apply, "a"), a)
a, phony = fork(a)
b = join(b, phony)
b = checkpoint(partial(Log.apply, "b"), b)
c = torch.cat((a, b))
out = c.sum()
# +--> {a} --Checkpoint(Log)--> {a}
# {out} --Sum--> {c} --Cat ^-----------------------------+
# +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
out.backward()
assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
# |----------------------| |-----------------------| |-----------------------|
# forward pass Checkpoint(Log[b]) Checkpoint(Log[a])
def test_not_requires_grad():
x = Batch(torch.rand(1, requires_grad=False))
assert not x[0].requires_grad
def f(x):
return x * 2
chk = Checkpointing(f, x)
x = chk.checkpoint()
assert x[0].requires_grad
chk.recompute(x)
assert x[0].requires_grad
x.tensor.backward()
def test_not_requires_grad_with_parameter():
x = torch.rand(1, requires_grad=False)
a = torch.rand(1, requires_grad=True)
def f(x):
return x * a
y = checkpoint(f, x)
y.backward()
assert a.grad is not None
@pytest.mark.parametrize("device", devices)
def test_random_in_checkpoint(device):
dropout = nn.Dropout(p=0.5)
torch.manual_seed(0)
x = torch.randn(3, 3, device=device, requires_grad=True)
y = dropout(x)
y.norm().backward()
torch.manual_seed(0)
chk_x = torch.randn(3, 3, device=device, requires_grad=True)
chk_y = checkpoint(dropout, chk_x)
chk_y.norm().backward()
assert torch.allclose(x.grad, chk_x.grad)
def test_detect_checkpointing_recomputing():
logs = []
class Detect(nn.Module):
def forward(self, input):
logs.append((is_checkpointing(), is_recomputing()))
return input
model = Detect()
input = torch.rand(1, requires_grad=True)
output = checkpoint(model, input)
output.backward()
assert logs == [(True, False), (False, True)]
def test_detect_checkpointing_recomputing_without_checkpoint():
logs = []
class Detect(nn.Module):
def forward(self, input):
logs.append((is_checkpointing(), is_recomputing()))
return input
model = Detect()
input = torch.rand(1, requires_grad=True)
output = model(input)
output.backward()
assert logs == [(False, False)]
def test_non_grad_output():
class ForkNonGrad(nn.Module):
def forward(self, input):
return (input * 2, torch.rand(1))
model = ForkNonGrad()
input = torch.rand(1, requires_grad=True)
output = checkpoint(model, input)
output[0].backward()

@@ -0,0 +1,68 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch.distributed._pipeline.sync.copy import Copy, Wait
from torch.distributed._pipeline.sync.stream import CPUStream, current_stream, get_device, is_cuda, new_stream, use_stream
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None):
device = get_device(prev_stream)
with use_stream(prev_stream):
if is_cuda(prev_stream):
cuda_sleep(0.5)
x = torch.ones(100, device=device, requires_grad=True)
(y,) = Copy.apply(prev_stream, next_stream, x)
(y,) = Wait.apply(prev_stream, next_stream, x)
with use_stream(next_stream):
assert torch.allclose(y.sum(), torch.tensor(100.0, device=device))
y.norm().backward()
with use_stream(prev_stream):
assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device))
def test_copy_wait_cpu_cpu():
prev_stream = CPUStream
next_stream = CPUStream
_test_copy_wait(prev_stream, next_stream)
@skip_if_no_cuda
def test_copy_wait_cpu_cuda(cuda_sleep):
prev_stream = CPUStream
next_stream = current_stream(torch.device("cuda"))
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
@skip_if_no_cuda
def test_copy_wait_cuda_cpu(cuda_sleep):
prev_stream = current_stream(torch.device("cuda"))
next_stream = CPUStream
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
@skip_if_no_cuda
def test_copy_wait_cuda_cuda(cuda_sleep):
prev_stream = current_stream(torch.device("cuda"))
next_stream = new_stream(torch.device("cuda"))
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
def test_wait_multiple_tensors():
a = torch.rand(1, requires_grad=True)
b = torch.rand(1, requires_grad=True)
a, b = Wait.apply(CPUStream, CPUStream, a, b)
assert a.grad_fn is b.grad_fn
assert a.grad_fn.__class__ is Wait._backward_cls

@@ -0,0 +1,192 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
from itertools import chain
import pytest
import torch
from torch import nn, optim
from torch.distributed._pipeline.sync.batchnorm import DeferredBatchNorm
CHUNKS = 4
def tilt_dist(input):
# Tilt variance by channel.
rgb = input.transpose(0, 1)
rgb[0] *= 1
rgb[1] *= 10
rgb[2] *= 100
# Tilt mean by single batch.
for i, single in enumerate(input):
single += 2 ** i
return input
def chunked_forward(model, input, chunks=CHUNKS):
output_chunks = []
for chunk in input.chunk(chunks):
output_chunks.append(model(chunk))
return torch.cat(output_chunks)
@pytest.mark.parametrize("chunks", [1, 4])
@pytest.mark.parametrize("input_requires_grad", [True, False])
def test_transparency(chunks, input_requires_grad):
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks)
input1 = torch.rand(16, 3, 224, 224)
input1 = tilt_dist(input1)
input2 = input1.clone()
input1.requires_grad = input_requires_grad
input2.requires_grad = input_requires_grad
output1 = chunked_forward(bn, input1, chunks=chunks)
output2 = chunked_forward(dbn, input2, chunks=chunks)
assert torch.allclose(output1, output2, atol=1e-4)
output1.mean().backward()
output2.mean().backward()
assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4)
if input_requires_grad:
assert input1.grad is not None
assert input2.grad is not None
assert torch.allclose(input1.grad, input2.grad, atol=1e-4)
@pytest.mark.parametrize("momentum", [0.1, None])
def test_running_stats(momentum):
bn = nn.BatchNorm2d(3, momentum=momentum)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
bn(input)
chunked_forward(dbn, input)
assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4)
assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4)
def test_convert_deferred_batch_norm():
bn = nn.BatchNorm2d(3, track_running_stats=False)
bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS)
assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False
dbn = DeferredBatchNorm(3, chunks=CHUNKS)
dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS)
assert dbn is dbn_again
dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1)
assert dbn is not dbn_again # because of different chunks
def test_eval():
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
bn(input)
chunked_forward(dbn, input)
bn.eval()
dbn.eval()
assert torch.allclose(bn(input), dbn(input), atol=1e-4)
def test_optimize():
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0)
for i in range(5):
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
# train
y = bn(input)
a = y.sum()
a.backward()
y = chunked_forward(dbn, input)
b = y.sum()
b.backward()
opt.step()
# eval
bn.eval()
dbn.eval()
with torch.no_grad():
assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10 ** i))
def test_conv_bn():
bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1)
# 1st step
a = bn(input)
b = chunked_forward(dbn, input)
# Outputs are different. (per-mini-batch vs. per-micro-batch)
assert not torch.allclose(a, b)
a.sum().backward()
b.sum().backward()
opt.step()
opt.zero_grad()
# Conv layers are also trained differently because of their different outputs.
assert not torch.allclose(bn[0].weight, dbn[0].weight)
# But BNs track identical running stats.
assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3)
# 2nd step
a = bn(input)
b = chunked_forward(dbn, input)
a.sum().backward()
b.sum().backward()
# BNs can't track identical running stats due to the different conv layers.
assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3)
def test_input_requiring_grad():
dbn = DeferredBatchNorm(3, chunks=CHUNKS)
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
input.requires_grad = True
chunked_forward(dbn, input)
assert not dbn.sum.requires_grad
assert dbn.sum.grad_fn is None

@@ -0,0 +1,144 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import weakref
import pytest
import torch
from torch.distributed._pipeline.sync.dependency import Fork, Join, fork, join
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_fork_join():
logs = []
class Log(torch.autograd.Function):
@staticmethod
def forward(ctx, number, tensor):
ctx.number = number
return tensor.detach()
@staticmethod
def backward(ctx, grad):
logs.append(ctx.number)
return None, grad
a = torch.rand(1, device="cpu", requires_grad=True)
b = torch.rand(1, device="cuda", requires_grad=True)
a = Log.apply(1, a)
a, phony = fork(a)
b = join(a, phony)
b = Log.apply(2, b)
b = b.to("cpu")
(a + b).backward()
assert logs == [2, 1]
def test_fork_join_enable_grad():
x = torch.rand(1, requires_grad=True)
with torch.enable_grad():
x2, p = fork(x)
assert p.requires_grad
assert x2 is not x
x = x2
assert x.requires_grad
assert p.requires_grad
assert x.grad_fn.__class__ is Fork._backward_cls
assert p.grad_fn.__class__ is Fork._backward_cls
with torch.enable_grad():
x2 = join(x, p)
assert x2 is not x
x = x2
assert x.requires_grad
assert x.grad_fn.__class__ is Join._backward_cls
def test_fork_join_no_grad(monkeypatch):
def do_not_apply(*args):
raise AssertionError("Function.apply called")
monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply)
x = torch.rand(1, requires_grad=True)
with torch.no_grad():
x2, p = fork(x)
assert not p.requires_grad
assert x2 is x
x = x2
with torch.no_grad():
x2 = join(x, p)
assert x2 is x
x = x2
def test_fork_leak():
leak = None
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return input
@staticmethod
def backward(ctx, grad):
nonlocal leak
leak = weakref.ref(ctx)
return grad
x = torch.rand(1, requires_grad=True)
x = F.apply(x)
x, phony = fork(x)
x = join(x, phony)
x.backward()
del x, phony
assert leak() is None
def test_join_when_fork_not_requires_grad():
x = torch.rand(2, 1)
a, b = x.chunk(2)
assert not a.requires_grad
a, p = fork(a)
assert not a.requires_grad
assert not p.requires_grad
assert not b.requires_grad
b = join(b, p)
assert not b.requires_grad
def test_join_when_fork_requires_grad():
x = torch.rand(2, 1)
a, b = x.chunk(2)
a.requires_grad_()
assert a.requires_grad
a, p = fork(a)
assert a.requires_grad
assert p.requires_grad
assert not b.requires_grad
b = join(b, p)
assert b.requires_grad

@@ -0,0 +1,71 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch import nn
from torch.distributed._pipeline.sync import Pipe
def test_inplace_on_requires_grad():
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always")
x = torch.rand(1)
y = model(x)
message = r"a leaf Variable that requires grad .* used in an in-place operation."
with pytest.raises(RuntimeError, match=message):
y.backward()
@pytest.mark.xfail(strict=True)
def test_inplace_on_not_requires_grad():
# In-place operation on a tensor not requiring grad doesn't cause a
# RuntimeError. Currently, we cannot detect this case.
model = nn.Sequential(nn.ReLU(inplace=True))
model = Pipe(model, [1], devices=["cpu"], checkpoint="always")
x = torch.rand(1)
y = model(x)
del model
message = r"a leaf Variable that requires grad .* used in an in-place operation."
with pytest.raises(RuntimeError, match=message):
y.backward()
@pytest.mark.xfail(strict=True)
def test_inplace_incorrect_grad():
class M(nn.Module):
def forward(self, foo_bar):
# 'foo' requires grad but 'bar' does not. In-place operation on
# 'bar' won't cause a RuntimeError.
foo, bar = foo_bar
# add_(1) is not idempotent, in contrast to relu_(). If it is
# executed multiple times, it will accumulate each difference onto
# 'bar'.
bar.add_(1)
# 'bar' is still captured by checkpointing. 'foo' will get
# incorrect grad.
return foo * bar
model = nn.Sequential(M())
model = Pipe(model, [1], devices=["cpu"], checkpoint="always")
foo = torch.tensor([1.0], requires_grad=True)
bar = torch.tensor([1.0])
output = model((foo, bar))
del model
output.backward()
# The gradient of 'foo' should be 2, but it is actually 3 because
# bar.add_(1) was executed twice due to checkpointing.
assert foo.grad.item() == 2.0

@@ -0,0 +1,138 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
import torch.cuda
from torch.distributed._pipeline.sync.microbatch import Batch, check, gather, scatter
def test_batch_atomic():
x = torch.tensor(42)
b = Batch(x)
assert b.atomic
assert b.tensor is x
with pytest.raises(AttributeError):
b.tensors
assert list(b) == [x]
assert len(b) == 1
assert b[0] is x
def test_batch_non_atomic():
x, y = torch.tensor(42), torch.tensor(21)
b = Batch((x, y))
assert not b.atomic
with pytest.raises(AttributeError):
b.tensor
assert b.tensors == (x, y)
assert list(b) == [x, y]
assert len(b) == 2
assert b[0] is x
assert b[1] is y
def test_batch_call():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
def f(x):
return x
assert a.call(f).atomic
assert not b.call(f).atomic
def test_batch_setitem_by_index():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
a[0] = torch.tensor(0)
b[0] = torch.tensor(0)
assert a.atomic
assert a[0].item() == 0
assert not b.atomic
assert len(b) == 2
assert b[0].item() == 0
assert b[1].item() == 21
def test_batch_setitem_by_slice():
a = Batch(torch.tensor(42))
b = Batch((torch.tensor(42), torch.tensor(21)))
a[:] = (torch.tensor(0),)
b[:] = (torch.tensor(0),)
assert a.atomic
assert a[0].item() == 0
assert not b.atomic
assert len(b) == 1
assert b[0].item() == 0
def test_check():
check(torch.tensor(42))
check((torch.tensor(4), torch.tensor(2)))
with pytest.raises(TypeError):
check(42)
with pytest.raises(TypeError):
check("str")
with pytest.raises(TypeError):
check((torch.tensor(4), 2))
def test_gather_tensors():
a = torch.zeros(1, 1)
b = torch.zeros(1, 1)
ab = gather([Batch(a), Batch(b)])
assert ab.size() == (2, 1)
def test_gather_tuples():
a = (torch.zeros(1, 1), torch.zeros(2, 2))
b = (torch.zeros(1, 1), torch.zeros(2, 2))
ab = gather([Batch(a), Batch(b)])
assert isinstance(ab, tuple)
assert ab[0].size() == (2, 1)
assert ab[1].size() == (4, 2)
def test_scatter_tensor():
ab = torch.zeros(2, 1)
a, b = scatter(ab, chunks=2)
assert a.tensor.size() == (1, 1)
assert b.tensor.size() == (1, 1)
def test_scatter_tuple():
ab = (torch.zeros(2, 1), torch.zeros(4, 2))
a, b = scatter(ab, chunks=2)
assert a.tensors[0].size() == (1, 1)
assert b.tensors[0].size() == (1, 1)
assert a.tensors[1].size() == (2, 2)
assert b.tensors[1].size() == (2, 2)

@@ -0,0 +1,50 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.distributed._pipeline.sync.phony import get_phony
def test_phony_size():
p = get_phony(torch.device("cpu"), requires_grad=False)
assert p.size() == (0,)
def test_phony_requires_grad():
p1 = get_phony(torch.device("cpu"), requires_grad=True)
p2 = get_phony(torch.device("cpu"), requires_grad=False)
assert p1.requires_grad
assert not p2.requires_grad
def test_cached_phony():
p1 = get_phony(torch.device("cpu"), requires_grad=True)
p2 = get_phony(torch.device("cpu"), requires_grad=True)
assert p1 is p2
p3 = get_phony(torch.device("cpu"), requires_grad=False)
p4 = get_phony(torch.device("cpu"), requires_grad=False)
assert p3 is p4
assert p1 is not p3
def test_phony_in_autograd_function():
class Phonify(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
phony = get_phony(input.device, requires_grad=False)
return phony.detach()
x = torch.rand(1, requires_grad=True)
p1 = Phonify.apply(x)
p2 = get_phony(torch.device("cpu"), requires_grad=True)
assert p1 is not p2
assert p1.grad_fn is not None
assert p2.grad_fn is None

View File

@ -0,0 +1,608 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from copy import deepcopy
import time
import pytest
import torch
from torch import nn
from torch.distributed._pipeline.sync import Pipe
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_parameters():
model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(model, balance=[1], devices=["cpu"], chunks=1)
assert list(pipe.parameters()) != []
def test_public_attrs():
class MyString:
def __init__(self, value):
self.value = value
def __str__(self):
return self.value
model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(model, balance=(1,), devices=("cpu",), chunks=42.000, checkpoint=MyString("always"))
assert pipe.balance == [1]
assert pipe.devices == [torch.device("cpu")]
assert pipe.chunks == 42
assert isinstance(pipe.chunks, int)
assert pipe.checkpoint == "always"
assert isinstance(pipe.checkpoint, str)
@pytest.mark.parametrize("balance", [[2], [1, 1]])
def test_sequential_like(balance):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, balance, devices=["cpu", "cpu"])
assert len(model) == 2
assert list(model) == [a, b]
assert model[0] is a
assert model[1] is b
with pytest.raises(IndexError):
_ = model[2]
assert model[-1] is b
assert model[-2] is a
def test_balance_wrong_length():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
with pytest.raises(ValueError):
Pipe(model, balance=[1])
with pytest.raises(ValueError):
Pipe(model, balance=[3])
def test_balance_less_than_1():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
with pytest.raises(ValueError):
Pipe(model, balance=[0, 2])
with pytest.raises(ValueError):
Pipe(model, balance=[-1, 3])
def test_chunks_less_than_1():
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError):
Pipe(model, balance=[1], devices=["cpu"], chunks=0)
with pytest.raises(ValueError):
Pipe(model, balance=[1], devices=["cpu"], chunks=-1)
def test_too_few_devices():
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
with pytest.raises(IndexError):
# len(balance) > len(devices)
model = Pipe(model, balance=[1, 1, 1, 1], devices=["cpu"])
def test_batch_size_indivisible():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"], chunks=4)
with pytest.warns(None) as record:
model(torch.rand(7, 1))
# Indivisible batch size is legal.
assert not record
def test_batch_size_small():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"], chunks=4)
with pytest.warns(None) as record:
model(torch.rand(2, 1))
# Batch size smaller than chunks is legal.
assert not record
def test_checkpoint_mode():
def count_grad_fn(grad_fn, name, visited=None):
if visited is None:
visited = set()
if grad_fn in visited:
return 0
visited.add(grad_fn)
if grad_fn is None:
return 0
if grad_fn.__class__.__name__ == name:
return 1
counter = 0
for next_grad_fn, _ in grad_fn.next_functions:
counter += count_grad_fn(next_grad_fn, name, visited=visited)
return counter
model = nn.Sequential(nn.Linear(1, 1))
input = torch.rand(2, 1)
always = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="always")
except_last = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="except_last")
never = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="never")
always_output = always(input)
except_last_output = except_last(input)
never_output = never(input)
assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2
assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1
assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0
def test_checkpoint_mode_invalid():
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="INVALID_CHECKPOINT")
def test_checkpoint_mode_when_chunks_1():
model = nn.Sequential(nn.Linear(1, 1))
# All checkpoint modes are fine.
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="except_last")
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="always")
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="never")
def test_checkpoint_eval():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
input = torch.rand(2, 1)
def find_grad_fn(grad_fn, name):
if grad_fn is None:
return False
if grad_fn.__class__.__name__ == name:
return True
for next_grad_fn, _ in grad_fn.next_functions:
if find_grad_fn(next_grad_fn, name):
return True
return False
model.train()
train_output = model(input)
assert find_grad_fn(train_output.grad_fn, "CheckpointBackward")
assert find_grad_fn(train_output.grad_fn, "RecomputeBackward")
model.eval()
eval_output = model(input)
assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward")
assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")
def test_checkpoint_non_float_input():
class ForkNonFloat(nn.Module):
def forward(self, input):
return (input * 2, torch.tensor([False]))
class JoinNonFloat(nn.Module):
def forward(self, input):
return input[0] * 2
model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=1, checkpoint="always")
input = torch.rand(1, requires_grad=True)
output = model(input)
output.backward()
def test_no_grad():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
input = torch.rand(2, 1)
latent = None
def hook(module, input, output):
_ = module
_ = input
nonlocal latent
latent = output
partition = model.partitions[0]
partition.register_forward_hook(hook)
with torch.no_grad():
model(input)
assert latent.grad_fn is None
def test_exception():
class ExpectedException(Exception):
pass
class Raise(nn.Module):
def forward(self, *_):
raise ExpectedException()
model = nn.Sequential(Raise())
model = Pipe(model, balance=[1], devices=["cpu"], chunks=1)
with pytest.raises(ExpectedException):
model(torch.rand(1))
def test_exception_early_stop_asap():
"""Even the first partitions have finished to process, the partition before
the failed partition should be killed as soon as possible.
"""
class ExpectedException(Exception):
pass
class Pass(nn.Module):
def forward(self, x):
return x
counter = 0
class Counter(nn.Module):
def forward(self, x):
time.sleep(0.1)
nonlocal counter
counter += 1
return x
class Raise(nn.Module):
def forward(self, x):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
model = Pipe(model, [1, 1, 1, 1], devices=["cpu", "cpu", "cpu", "cpu"], chunks=3)
with pytest.raises(ExpectedException):
model(torch.rand(3))
# If the early stop doesn't work, it would be 3 instead.
assert counter == 2
def test_input_pair():
class Two(nn.Module):
def __init__(self):
super().__init__()
self.fc_a = nn.Linear(1, 1)
self.fc_b = nn.Linear(1, 1)
def forward(self, a_and_b):
a, b = a_and_b
return (self.fc_a(a), self.fc_b(b))
model = nn.Sequential(Two())
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
a = torch.rand(10, 1, requires_grad=True)
b = torch.rand(10, 1, requires_grad=True)
a_out, b_out = model((a, b))
loss = (a_out + b_out).mean()
loss.backward()
assert a.grad is not None
assert b.grad is not None
def test_input_singleton():
class One(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, only_a):
(a,) = only_a
return (self.fc(a),)
model = nn.Sequential(One())
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
a = torch.rand(10, 1, requires_grad=True)
(a_out,) = model((a,))
loss = a_out.mean()
loss.backward()
assert all(p.grad is not None for p in model.parameters())
assert a.grad is not None
def test_input_varargs():
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, balance=[1], devices=["cpu"])
a = torch.rand(1)
b = torch.rand(1)
# TypeError: forward() takes 2 positional arguments but 3 were given
with pytest.raises(TypeError):
model(a, b)
def test_non_tensor():
class NonTensor(nn.Module):
def forward(self, _):
return "hello"
model = nn.Sequential(NonTensor())
model = Pipe(model, balance=[1], devices=["cpu"])
x = torch.rand(1)
# TypeError: expected Tensor as element 0 in argument 0, but got str
with pytest.raises(TypeError):
model(x)
# TypeError: expected Tensor to scatter, but got str
with pytest.raises(TypeError):
model("hello")
def test_non_tensor_tuple():
class NonTensorTuple(nn.Module):
def forward(self, x):
return (x, "hello")
model = nn.Sequential(NonTensorTuple())
model = Pipe(model, balance=[1], devices=["cpu"])
x = torch.rand(1)
# TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
with pytest.raises(TypeError):
model(x)
# TypeError: expected Tensor to scatter, but got str
with pytest.raises(TypeError):
model((x, "hello"))
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_deferred_batch_norm(checkpoint):
bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn)
pipe = Pipe(
nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=2, checkpoint=checkpoint, deferred_batch_norm=True
)
x = torch.rand(4, 3, 10, 10)
pipe(x).mean().backward()
bn(x).mean().backward()
assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4)
assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4)
@pytest.mark.parametrize("checkpoint", ["never", "always"])
def test_deferred_batch_norm_params(checkpoint):
bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn)
pipe = Pipe(
nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=1, checkpoint=checkpoint, deferred_batch_norm=True
)
x = torch.rand(4, 3, 10, 10)
pipe(x).mean().backward()
bn(x).mean().backward()
assert pipe[0].weight.grad is not None
assert pipe[0].bias.grad is not None
assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4)
assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4)
def test_devices():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
c = nn.Linear(1, 1)
# There are two extra devices.
devices = ["cpu", "cpu", "cpu", "cpu", "cpu"]
model = nn.Sequential(a, b, c)
model = Pipe(model, [1, 1, 1], devices=devices)
cpu = torch.device("cpu")
# Extra devices must be discarded.
assert model.devices == [cpu, cpu, cpu]
def test_partitions():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
assert isinstance(model.partitions, nn.ModuleList)
assert isinstance(model.partitions[0], nn.Sequential)
assert isinstance(model.partitions[1], nn.Sequential)
assert "partitions.0.0.weight" in model.state_dict()
def test_deny_moving():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
# Moving is denied.
with pytest.raises(TypeError):
model.cuda()
with pytest.raises(TypeError):
model.cpu()
with pytest.raises(TypeError):
model.to(torch.device("cuda"))
with pytest.raises(TypeError):
model.to(0)
with pytest.raises(TypeError):
model.to("cuda")
with pytest.raises(TypeError):
model.to(device=0)
with pytest.raises(TypeError):
model.to(torch.rand(1))
with pytest.raises(TypeError):
model.to(tensor=torch.rand(1))
# Casting is allowed.
model.half()
model.to(torch.double)
model.to(dtype=torch.float)
def test_empty_module():
# Empty sequential module is not illegal.
model = nn.Sequential()
model = Pipe(model, [])
assert model(torch.tensor(42)) == torch.tensor(42)
assert model((torch.tensor(42),)) == (torch.tensor(42),)
# But only a tensor or a tuple of tensors is legal in Pipe.
with pytest.raises(TypeError):
model(42)
def test_named_children():
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
names = set(n for n, _ in model.named_modules())
assert "partitions.0.a" in names
assert "partitions.1.b" in names
# Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
# several methods in its namespace.
with pytest.raises(AttributeError):
model.a
def test_recommend_auto_balance():
with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"):
# balance is required
Pipe(nn.Sequential())
with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"):
# module and sum of balance have different length (module: 0, sum of balance: 1)
Pipe(nn.Sequential(), [1])
with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"):
# module and sum of balance have different length (module: 2, sum of balance: 1)
Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
def test_verify_module_non_sequential():
with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"):
Pipe(nn.Module(), [1])
def test_verify_module_duplicate_children():
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(conv, conv)
with pytest.raises(ValueError, match="module with duplicate children is not supported"):
Pipe(model, [1, 1])
@skip_if_no_cuda
def test_verify_module_duplicate_parameters_on_distinct_devices():
class Surrogate(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(Surrogate(conv), Surrogate(conv))
with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
Pipe(model, [1, 1], devices=["cpu", "cuda"])
def test_verify_module_duplicate_parameters_on_same_device():
class Surrogate(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(Surrogate(conv), Surrogate(conv))
Pipe(model, [1, 1], devices=["cpu", "cpu"])
def test_forward_lockstep():
timeline = []
class DelayedLog(nn.Module):
def __init__(self, j, seconds):
super().__init__()
self.i = 0
self.j = j
self.seconds = seconds
def forward(self, x):
time.sleep(self.seconds)
timeline.append((self.i, self.j))
self.i += 1
return x
model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1))
model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=3)
model(torch.rand(3, 1))
# Expected timeline: (Logs are recorded at !)
#
# Partition #0: 0! 1! 2!
# Partition #1: 000! 111! 222!
#
assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]

View File

@ -0,0 +1,29 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from torch.distributed._pipeline.sync.pipeline import clock_cycles
def test_clock_cycles():
assert list(clock_cycles(1, 1)) == [[(0, 0)]]
assert list(clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]]
assert list(clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]]
assert list(clock_cycles(3, 3)) == [ # noqa
[(0, 0)],
[(1, 0), (0, 1)],
[(2, 0), (1, 1), (0, 2)],
[(2, 1), (1, 2)],
[(2, 2)],
]
assert list(clock_cycles(4, 2)) == [ # noqa
[(0, 0)],
[(1, 0), (0, 1)],
[(2, 0), (1, 1)],
[(3, 0), (2, 1)],
[(3, 1)],
]

View File

@ -0,0 +1,188 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from torch.distributed._pipeline.sync.stream import (
CPUStream,
current_stream,
default_stream,
get_device,
is_cuda,
new_stream,
record_stream,
use_device,
use_stream,
wait_stream,
)
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
class TestNewStream:
def test_new_stream_cpu(self):
stream = new_stream(torch.device("cpu"))
assert stream is CPUStream
@skip_if_no_cuda
def test_new_stream_cuda(self):
stream = new_stream(torch.device("cuda"))
assert isinstance(stream, torch.cuda.Stream)
assert stream != torch.cuda.default_stream()
class TestCurrentStream:
def test_current_stream_cpu(self):
stream = current_stream(torch.device("cpu"))
assert stream is CPUStream
@skip_if_no_cuda
def test_current_stream_cuda(self):
stream = current_stream(torch.device("cuda"))
assert isinstance(stream, torch.cuda.Stream)
assert stream == torch.cuda.current_stream()
class TestDefaultStream:
def test_default_stream_cpu(self):
stream = default_stream(torch.device("cpu"))
assert stream is CPUStream
@skip_if_no_cuda
def test_default_stream_cuda(self):
stream = default_stream(torch.device("cuda"))
assert isinstance(stream, torch.cuda.Stream)
assert stream == torch.cuda.default_stream()
class TestUseDevice:
def test_use_device_cpu(self):
with use_device(torch.device("cpu")):
pass
@skip_if_no_cuda
def test_use_device_cuda(self):
with use_device(torch.device("cuda")):
pass
class TestUseStream:
def test_use_stream_cpu(self):
with use_stream(CPUStream):
pass
@skip_if_no_cuda
def test_use_stream_cuda(self):
stream = new_stream(torch.device("cuda"))
with use_stream(stream):
assert current_stream(torch.device("cuda")) == stream
class TestGetDevice:
def test_get_device_cpu(self):
assert get_device(CPUStream).type == "cpu"
@skip_if_no_cuda
def test_get_device_cuda(self):
stream = current_stream(torch.device("cuda"))
assert get_device(stream).type == "cuda"
class TestWaitStream:
def _test_wait_stream(self, source, target, cuda_sleep=None):
with use_stream(target):
if is_cuda(target):
cuda_sleep(0.5)
x = torch.ones(100, 100, device=get_device(target))
wait_stream(source, target)
with use_stream(source):
assert x.sum().item() == 10000
def test_wait_stream_cpu_cpu(self):
source = CPUStream
target = CPUStream
self._test_wait_stream(source, target)
@skip_if_no_cuda
def test_wait_stream_cpu_cuda(self, cuda_sleep):
source = CPUStream
target = new_stream(torch.device("cuda"))
self._test_wait_stream(source, target, cuda_sleep)
@skip_if_no_cuda
def test_wait_stream_cuda_cpu(self, cuda_sleep):
source = new_stream(torch.device("cuda"))
target = CPUStream
self._test_wait_stream(source, target, cuda_sleep)
@skip_if_no_cuda
def test_wait_stream_cuda_cuda(self, cuda_sleep):
source = current_stream(torch.device("cuda"))
target = new_stream(torch.device("cuda"))
self._test_wait_stream(source, target, cuda_sleep)
class TestRecordStream:
def test_record_stream_cpu(self):
# It should silently ignore CPU tensors.
x = torch.rand(1, device=torch.device("cpu"))
record_stream(x, CPUStream)
@skip_if_no_cuda
def test_record_stream_cuda(self, cuda_sleep):
# This test detects unexpected block reallocation. For a reliable test,
# the stream used to allocate tensors is isolated. The allocator will not
# reuse free blocks that were allocated on another stream.
stream_alloc = new_stream(torch.device("cuda"))
with torch.cuda.stream(stream_alloc):
x = torch.rand(1, device=torch.device("cuda"))
stream = new_stream(torch.device("cuda"))
record_stream(x, stream)
with use_stream(stream):
cuda_sleep(0.5)
# 'x' is deleted from Python's perspective, but the block of 'x' is still
# required by 'stream'. 'y' shouldn't be allocated into that block.
data_ptr = x.data_ptr()
del x
stream_alloc.synchronize()
with torch.cuda.stream(stream_alloc):
y = torch.rand(1, device=torch.device("cuda"))
assert y.data_ptr() != data_ptr
# Pause Python until 'stream' finishes its queued tasks. Now the block of
# 'x' is free to be reallocated.
wait_stream(CPUStream, stream)
with torch.cuda.stream(stream_alloc):
z = torch.rand(1, device=torch.device("cuda"))
assert z.data_ptr() == data_ptr
@skip_if_no_cuda
def test_record_stream_shifted_view(self, cuda_sleep):
# Issue: https://github.com/pytorch/pytorch/issues/27366
stream_alloc = new_stream(torch.device("cuda"))
with torch.cuda.stream(stream_alloc):
x = torch.rand(2, device=torch.device("cuda"))
y = x[1:]
assert y.data_ptr() > x.data_ptr()
stream = new_stream(torch.device("cuda"))
with use_stream(stream):
cuda_sleep(0.5)
record_stream(y, stream)
data_ptr = x.data_ptr()
del x, y
stream_alloc.synchronize()
with torch.cuda.stream(stream_alloc):
z = torch.rand(2, device=torch.device("cuda"))
assert z.data_ptr() != data_ptr

View File

@ -0,0 +1,43 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.distributed._pipeline.sync import Pipe
def test_simple_linears():
def sum_grad(parameters):
return sum([p.grad.sum() for p in parameters if p.grad is not None])
def zero_grad(parameters):
for p in parameters:
p.grad = None
inputs = torch.rand(8, 1)
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)
# Without Pipe
outputs = model(inputs)
loss = outputs.mean()
loss.backward()
grad_without_pipe = sum_grad(model.parameters())
zero_grad(model.parameters())
# With Pipe
model = Pipe(model, [2, 2], devices=["cpu", "cpu"], chunks=4)
outputs = model(inputs)
loss = outputs.mean()
loss.backward()
grad_with_pipe = sum_grad(model.parameters())
# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe)

View File

@ -0,0 +1,163 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import threading
import time
import pytest
import torch
from torch.distributed._pipeline.sync.microbatch import Batch
from torch.distributed._pipeline.sync.stream import CPUStream
from torch.distributed._pipeline.sync.worker import Task, spawn_workers
class fake_device:
"""A test double for :class:`torch.device`. Every fake device is different
with each other.
"""
type = "fake"
index = None
def test_join_running_workers():
count = 0
def counter():
nonlocal count
time.sleep(0.1)
count += 1
return Batch(())
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
def call_in_worker(i, f):
task = Task(CPUStream, compute=f, finalize=None)
in_queues[i].put(task)
for i in range(10):
call_in_worker(i, counter)
# There's no nondeterminism because 'spawn_workers' joins all running
# workers.
assert count == 10
def test_join_running_workers_with_exception():
class ExpectedException(Exception):
pass
count = 0
def counter():
nonlocal count
time.sleep(0.1)
count += 1
return Batch(())
with pytest.raises(ExpectedException):
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
def call_in_worker(i, f):
task = Task(CPUStream, compute=f, finalize=None)
in_queues[i].put(task)
for i in range(10):
call_in_worker(i, counter)
raise ExpectedException
# There's no nondeterminism because only 1 task can be placed in input
# queues.
assert count == 10
def test_compute_multithreading():
"""Task.compute should be executed on multiple threads."""
thread_ids = set()
def log_thread_id():
thread_id = threading.current_thread().ident
thread_ids.add(thread_id)
return Batch(())
with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues):
for i in range(2):
t = Task(CPUStream, compute=log_thread_id, finalize=None)
in_queues[i].put(t)
for i in range(2):
out_queues[i].get()
assert len(thread_ids) == 2
def test_compute_success():
"""Task.compute returns (True, (task, batch)) on success."""
def _42():
return Batch(torch.tensor(42))
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
t = Task(CPUStream, compute=_42, finalize=None)
in_queues[0].put(t)
ok, (task, batch) = out_queues[0].get()
assert ok
assert task is t
assert isinstance(batch, Batch)
assert batch[0].item() == 42
def test_compute_exception():
"""Task.compute returns (False, exc_info) on failure."""
def zero_div():
0 / 0
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
t = Task(CPUStream, compute=zero_div, finalize=None)
in_queues[0].put(t)
ok, exc_info = out_queues[0].get()
assert not ok
assert isinstance(exc_info, tuple)
assert issubclass(exc_info[0], ZeroDivisionError)
@pytest.mark.parametrize("grad_mode", [True, False])
def test_grad_mode(grad_mode):
def detect_grad_enabled():
x = torch.rand(1, requires_grad=torch.is_grad_enabled())
return Batch(x)
with torch.set_grad_enabled(grad_mode):
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
task = Task(CPUStream, compute=detect_grad_enabled, finalize=None)
in_queues[0].put(task)
ok, (_, batch) = out_queues[0].get()
assert ok
assert batch[0].requires_grad == grad_mode
def test_worker_per_device():
cpu = torch.device("cpu")
cpu0 = torch.device("cpu", index=0)
fake1 = fake_device()
fake2 = fake_device()
with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues):
assert len(in_queues) == len(out_queues) == 5
# 0: cpu, 1: cpu, 2: cpu0
assert in_queues[0] is in_queues[1] is in_queues[2]
assert out_queues[0] is out_queues[1] is out_queues[2]
# 3: fake1, 4: fake2
assert in_queues[3] is not in_queues[4]
assert out_queues[3] is not out_queues[4]

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python
import argparse
import copy
from datetime import datetime
import importlib
import modulefinder
@ -95,6 +96,54 @@ TESTS = [
'test_fx_experimental',
'test_functional_autograd_benchmark',
'test_package',
'distributed/_pipeline/sync/skip/test_api',
'distributed/_pipeline/sync/skip/test_gpipe',
'distributed/_pipeline/sync/skip/test_inspect_skip_layout',
'distributed/_pipeline/sync/skip/test_leak',
'distributed/_pipeline/sync/skip/test_portal',
'distributed/_pipeline/sync/skip/test_stash_pop',
'distributed/_pipeline/sync/skip/test_tracker',
'distributed/_pipeline/sync/skip/test_verify_skippables',
'distributed/_pipeline/sync/test_balance',
'distributed/_pipeline/sync/test_bugs',
'distributed/_pipeline/sync/test_checkpoint',
'distributed/_pipeline/sync/test_copy',
'distributed/_pipeline/sync/test_deferred_batch_norm',
'distributed/_pipeline/sync/test_dependency',
'distributed/_pipeline/sync/test_inplace',
'distributed/_pipeline/sync/test_microbatch',
'distributed/_pipeline/sync/test_phony',
'distributed/_pipeline/sync/test_pipe',
'distributed/_pipeline/sync/test_pipeline',
'distributed/_pipeline/sync/test_stream',
'distributed/_pipeline/sync/test_transparency',
'distributed/_pipeline/sync/test_worker',
]
# Tests need to be run with pytest.
USE_PYTEST_LIST = [
'distributed/_pipeline/sync/skip/test_api',
'distributed/_pipeline/sync/skip/test_gpipe',
'distributed/_pipeline/sync/skip/test_inspect_skip_layout',
'distributed/_pipeline/sync/skip/test_leak',
'distributed/_pipeline/sync/skip/test_portal',
'distributed/_pipeline/sync/skip/test_stash_pop',
'distributed/_pipeline/sync/skip/test_tracker',
'distributed/_pipeline/sync/skip/test_verify_skippables',
'distributed/_pipeline/sync/test_balance',
'distributed/_pipeline/sync/test_bugs',
'distributed/_pipeline/sync/test_checkpoint',
'distributed/_pipeline/sync/test_copy',
'distributed/_pipeline/sync/test_deferred_batch_norm',
'distributed/_pipeline/sync/test_dependency',
'distributed/_pipeline/sync/test_inplace',
'distributed/_pipeline/sync/test_microbatch',
'distributed/_pipeline/sync/test_phony',
'distributed/_pipeline/sync/test_pipe',
'distributed/_pipeline/sync/test_pipeline',
'distributed/_pipeline/sync/test_stream',
'distributed/_pipeline/sync/test_transparency',
'distributed/_pipeline/sync/test_worker',
]
WINDOWS_BLOCKLIST = [
@ -170,6 +219,28 @@ SLOW_TESTS = [
'test_quantization',
'test_determination',
'test_futures',
'distributed/_pipeline/sync/skip/test_api',
'distributed/_pipeline/sync/skip/test_gpipe',
'distributed/_pipeline/sync/skip/test_inspect_skip_layout',
'distributed/_pipeline/sync/skip/test_leak',
'distributed/_pipeline/sync/skip/test_portal',
'distributed/_pipeline/sync/skip/test_stash_pop',
'distributed/_pipeline/sync/skip/test_tracker',
'distributed/_pipeline/sync/skip/test_verify_skippables',
'distributed/_pipeline/sync/test_balance',
'distributed/_pipeline/sync/test_bugs',
'distributed/_pipeline/sync/test_checkpoint',
'distributed/_pipeline/sync/test_copy',
'distributed/_pipeline/sync/test_deferred_batch_norm',
'distributed/_pipeline/sync/test_dependency',
'distributed/_pipeline/sync/test_inplace',
'distributed/_pipeline/sync/test_microbatch',
'distributed/_pipeline/sync/test_phony',
'distributed/_pipeline/sync/test_pipe',
'distributed/_pipeline/sync/test_pipeline',
'distributed/_pipeline/sync/test_stream',
'distributed/_pipeline/sync/test_transparency',
'distributed/_pipeline/sync/test_worker',
]
_DEP_MODULES_CACHE: Dict[str, set] = {}
@ -762,12 +833,15 @@ def main():
failure_messages = []
try:
for test in selected_tests:
options_clone = copy.deepcopy(options)
if test in USE_PYTEST_LIST:
options_clone.pytest = True
err_message = run_test_module(test, test_directory, options_clone)
if err_message is None:
continue
has_failed = True
failure_messages.append(err_message)
if not options_clone.continue_through_error:
raise RuntimeError(err_message)
print_to_stderr(err_message)
finally:

View File

View File

@ -0,0 +1,27 @@
Copyright 2019-2020 Kakao Brain
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,11 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""A Pipe implementation in PyTorch."""
from .checkpoint import is_checkpointing, is_recomputing
from .pipe import Pipe
__all__ = ["Pipe", "is_checkpointing", "is_recomputing"]

View File

@ -0,0 +1,164 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""A helper to roughly balance a sequential module.
Usage::
import torch
from torch.distributed._pipeline.sync import Pipe
from torch.distributed._pipeline.sync.balance import balance_by_time
sample = torch.empty(128, 3, 224, 224)
balance = balance_by_time(torch.cuda.device_count(), model, sample)
pipe = Pipe(model, balance, chunks=8)
"""
from typing import List, Tuple, Union
import torch
from torch import Tensor
import torch.nn as nn
from . import blockpartition
from .profile import profile_sizes, profile_times
__all__ = ["balance_by_time", "balance_by_size"]
Device = Union[torch.device, int, str]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
def balance_cost(cost: List[int], partitions: int) -> List[int]:
partitioned = blockpartition.solve(cost, partitions)
return [len(p) for p in partitioned]
def balance_by_time(
partitions: int,
module: nn.Sequential,
sample: TensorOrTensors,
*,
timeout: float = 1.0,
device: Device = torch.device("cuda"),
) -> List[int]:
"""Naive automatic balancing by elapsed time per layer.
::
sample = torch.empty(128, 3, 224, 224)
balance = balance_by_time(torch.cuda.device_count(), model, sample)
pipe = Pipe(model, balance, chunks=8)
Args:
partitions (int):
intended number of partitions
module (torch.nn.Sequential):
sequential module to be partitioned
sample (torch.Tensor):
example input with arbitrary batch size
Keyword Args:
timeout (float):
profiling iterates again if the timeout (in seconds) has not been exceeded
(default: ``1.0``)
device ('cpu' or 'cuda' device):
CPU or CUDA device where each layer is profiled (default: the
current CUDA device)
Returns:
A list of the number of layers in each partition. Use it for the `balance`
parameter of :class:`~torchpipe.Pipe`.
.. note::
`module` and `sample` must be placed on the same device.
"""
times = profile_times(module, sample, timeout, torch.device(device))
return balance_cost(times, partitions)
def balance_by_size(
partitions: int,
module: nn.Sequential,
input: TensorOrTensors,
*,
chunks: int = 1,
param_scale: float = 2.0,
device: Device = torch.device("cuda"),
) -> List[int]:
"""Naive automatic balancing by CUDA memory usage per layer.
During training, the memory required for parameters depends on which optimizer
is used. Optimizers may use per-parameter buffers to track optimization
statistics internally, such as the momentum buffer in SGD.
To get a more reliable size-based balance, you should specify `param_scale`
according to your optimizer. The default `param_scale` is 2 instead of 1
because gradient accumulation is required by every optimizer.
Follow this guide to choose the correct `param_scale` for typical optimizers:
========= ============= =========================================
Optimizer `param_scale` Internal State
========= ============= =========================================
SGD 2--3 (momentum_buffer)
Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq)
Adadelta 4 square_avg, acc_delta
Adagrad 3 sum
RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg)
========= ============= =========================================
Here's a simple example with the Adam optimizer::
balance = balance_by_size(
torch.cuda.device_count(),
model,
# Same size as the mini-batch to train
torch.empty(1024, 3, 224, 224),
# Number of micro-batches to train with Pipe
chunks=8,
# 4 for Adam
param_scale=4.0,
)
pipe = Pipe(model, balance, chunks=8)
adam = Adam(pipe.parameters())
Args:
partitions (int):
intended number of partitions
module (torch.nn.Sequential):
sequential module to be partitioned
input (torch.Tensor):
example mini-batch with the same size as used for training
Keyword Args:
chunks (int):
number of micro-batches that will be used to train (default: ``1``)
param_scale (float):
how many copies of parameters will be allocated for training. It
depends on the optimizer. See the guide above. (default: ``2.0``)
device ('cuda' device):
CUDA device where each layer is profiled (default: the current CUDA
device)
Returns:
A list of the number of layers in each partition. Use it for the `balance`
parameter of :class:`~torchpipe.Pipe`.
.. note::
`module` and `input` must be placed on the same CUDA device.
"""
sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device))
return balance_cost(sizes, partitions)

View File

@ -0,0 +1,95 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Implements "Block Partitions of Sequences" by Imre Bárány et al.
Paper: https://arxiv.org/pdf/1308.2452.pdf
"""
from typing import Iterator, List, Tuple
__all__ = ["solve"]
def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]:
"""Splits a sequence into several partitions to minimize variance for each
partition.
The result might not be optimal. However, it can be computed in only O(kn³)
time, where k is the number of partitions and n is the length of the sequence.
"""
if partitions < 1:
raise ValueError(f"partitions must be a positive integer ({partitions} < 1)")
n = len(sequence)
if n < partitions:
raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})")
# Normalize the sequence in [0, 1].
minimum = min(sequence)
maximum = max(sequence) - minimum
normal_sequence: List[float]
if maximum == 0:
normal_sequence = [0 for _ in sequence]
else:
normal_sequence = [(x - minimum) / maximum for x in sequence]
splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n]
def block_size(i: int) -> float:
start = splits[i - 1] if i > 0 else 0
stop = splits[i]
return sum(normal_sequence[start:stop])
def leaderboard() -> Iterator[Tuple[float, int]]:
return ((block_size(i), i) for i in range(partitions))
while True:
"""
(1) Fix p ∈ [k] with M(P) = bp. So Bp is a maximal block of P.
"""
# max_size: M(P)
max_size, p = max(leaderboard())
while True:
"""
(2) If M(P) ≤ m(P) + 1, then stop.
"""
# min_size: m(P)
min_size, q = min(leaderboard())
if max_size <= min_size + 1:
return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)]
"""
(3) If M(P) > m(P) + 1, then let m(P) = bq for the q ∈ [k] which is
closest to p (ties broken arbitrarily). Thus Bq is a minimal block
of P. Let Bh be the block next to Bq between Bp and Bq. (Note that
Bh is a non-empty block: if it were empty, then m(P) = 0 and we should
have chosen Bh instead of Bq.)
"""
if p < q:
"""
So either p < q and then h = q - 1 and we define P' by moving
the last element from Bh = Bq-1 to Bq,
"""
h = q - 1
splits[h] -= 1
else:
"""
or q < p, and then h = q + 1 and P' is obtained by moving the
first element of Bh = Bq+1 to Bq.
"""
h = q + 1
splits[q] += 1
"""
Set P = P'. If p = h, then go to (1), else go to (2).
"""
if p == h:
break

View File

@ -0,0 +1,114 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Per-layer profilers."""
import copy
import time
from typing import Generator, List, Tuple, Union
import torch
from torch import Tensor
import torch.nn as nn
from ..microbatch import Batch
__all__: List[str] = []
Device = Union[torch.device, int, str]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]:
"""Copies layers for ease to profile. It doesn't modify the given
module.
"""
for layer in module:
layer_copy = copy.deepcopy(layer)
layer_copy.to(device)
layer_copy.train()
yield layer_copy
def detach(batch: Batch) -> None:
"""Detaches from autograd graph."""
for i, x in enumerate(batch):
batch[i] = x.detach().requires_grad_(x.requires_grad)
def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]:
"""Profiles elapsed times per layer."""
if any(p.grad is not None for p in module.parameters()):
raise ValueError("some parameter already has gradient")
_batch = Batch(sample)
for i, x in enumerate(_batch):
_batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
time_bufs: List[List[float]] = [[] for _ in module]
begun_at = time.time()
while time.time() - begun_at < timeout:
batch = _batch
for i, layer in enumerate(layerwise_sandbox(module, device)):
detach(batch)
if device.type == "cuda":
torch.cuda.synchronize(device)
tick = time.time()
# Forward
batch = batch.call(layer)
# Backward
backward_tensors = tuple(y for y in batch if y.requires_grad)
if backward_tensors:
torch.autograd.backward(backward_tensors, backward_tensors)
if device.type == "cuda":
torch.cuda.synchronize(device)
tock = time.time()
time_bufs[i].append(tock - tick)
us = 1_000_000
return [sum(int(t * us) for t in buf) for buf in time_bufs]
def profile_sizes(
module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device,
) -> List[int]:
"""Profiles CUDA memory usage per layer."""
if device.type != "cuda":
raise ValueError("size profiler supports only CUDA device")
batch = Batch(input)
sizes: List[int] = []
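# Only a single-sample slice (x[:1]) of the batch is run through each layer
# below, so 'latent_scale' (mini-batch size / chunks) rescales the measured
# activation memory back to the size of one micro-batch.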
latent_scale = batch[0].size(0) / chunks
for i, x in enumerate(batch):
batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad)
for layer in layerwise_sandbox(module, device):
detach(batch)
# Detect memory usage at forward.
memory_before = torch.cuda.memory_allocated(device)
batch = batch.call(layer)
memory_after = torch.cuda.memory_allocated(device)
latent_size = memory_after - memory_before
# Analyze size of parameters.
param_size = sum(p.storage().size() * p.storage().element_size() for p in layer.parameters())
# Combine the sizes of parameters and activations with normalized scales.
size = latent_size * latent_scale + param_size * param_scale
sizes.append(int(size))
return sizes

View File

@ -0,0 +1,6 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,159 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Tracks the running statistics per mini-batch instead of micro-batch."""
from typing import Optional, TypeVar, cast
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from .checkpoint import is_recomputing
__all__ = ["DeferredBatchNorm"]
TModule = TypeVar("TModule", bound=nn.Module)
class DeferredBatchNorm(_BatchNorm):
"""A BatchNorm layer tracks multiple micro-batches to update running
statistics per mini-batch.
"""
sum: Tensor
sum_squares: Tensor
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: Optional[float] = 0.1,
affine: bool = True,
chunks: int = 1,
) -> None:
super().__init__(num_features, eps, momentum, affine, track_running_stats=True)
self.register_buffer("sum", torch.zeros_like(self.running_mean))
self.register_buffer("sum_squares", torch.zeros_like(self.running_var))
self.counter = 0
self.tracked = 0
self.chunks = chunks
def _check_input_dim(self, input: Tensor) -> None:
# It's the typical _check_input_dim() implementation in PyTorch.
if input.dim() <= 2:
raise ValueError("expected at least 3D input (got %dD input)" % input.dim())
def _track(self, input: Tensor) -> bool:
"""Tracks statistics of a micro-batch."""
# Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
dim = [0]
dim.extend(range(2, input.dim()))
with torch.no_grad():
self.sum += input.sum(dim)
self.sum_squares += (input ** 2).sum(dim)
size = input.size().numel() // input.size(1)
self.counter += size
self.tracked += 1
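# True once every micro-batch of the current mini-batch has been tracked.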
return self.tracked == self.chunks
def _commit(self) -> None:
"""Updates the running statistics of a mini-batch."""
exponential_average_factor = 0.0
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
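# Mean and (biased) variance over all tracked micro-batches:
# E[x] = sum / counter and Var[x] = E[x^2] - E[x]^2.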
mean = self.sum / self.counter
var = self.sum_squares / self.counter - mean ** 2
# Calculate the exponential moving average here.
m = exponential_average_factor
self.running_mean *= 1 - m
self.running_mean += mean * m
self.running_var *= 1 - m
self.running_var += var * m
self.sum.zero_()
self.sum_squares.zero_()
self.counter = 0
self.tracked = 0
def forward(self, input: Tensor) -> Tensor: # type: ignore
if not self.training:
# Don't train parameters in evaluation mode.
return F.batch_norm(
input,
running_mean=self.running_mean,
running_var=self.running_var,
weight=self.weight,
bias=self.bias,
training=False,
momentum=0.0,
eps=self.eps,
)
if not is_recomputing():
# Track a micro-batch in training mode,
# but not during recomputation.
tracked_enough = self._track(input)
# Update the running statistics for a mini-batch
# if it has tracked enough micro-batches.
if tracked_enough:
self._commit()
# Normalize a micro-batch and train the parameters.
return F.batch_norm(
input,
running_mean=None,
running_var=None,
weight=self.weight,
bias=self.bias,
training=True,
momentum=0.0,
eps=self.eps,
)
@classmethod
def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
"""Converts a :class:`nn.BatchNorm` or underlying
:class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`::
from torchvision.models.resnet import resnet101
from torchpipe.batchnorm import DeferredBatchNorm
model = resnet101()
model = DeferredBatchNorm.convert_deferred_batch_norm(model)
"""
if isinstance(module, DeferredBatchNorm) and module.chunks == chunks:
return cast(TModule, module)
module_output: nn.Module = module
if isinstance(module, _BatchNorm) and module.track_running_stats:
module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks)
if module.affine:
module_output.register_parameter("weight", module.weight)
module_output.register_parameter("bias", module.bias)
module_output.register_buffer("running_mean", module.running_mean)
module_output.register_buffer("running_var", module.running_var)
module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
for name, child in module.named_children():
module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks))
return cast(TModule, module_output)

View File

@ -0,0 +1,317 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Checkpointing with preceding recomputation.
PyTorch already provides the official checkpointing utilities in
:mod:`torch.utils.checkpoint`. The official checkpointing combines
recomputation and recursive backpropagation into one autograd function named
``CheckpointFunction``. Hence, the recomputation can be started only when the
gradients arrive to the function. In Pipe, the recomputation needs to precede
the gradient arrival to minimize the GPU idle time.
We solve this problem by introducing separate autograd functions named
:class:`Recompute` and :class:`Checkpoint`. Each function represents
recomputation and recursive backpropagation, respectively. With this pair of
functions, we can manipulate the control flow of both the autograd engine and
CUDA.
Specifically, we place CUDA stream synchronization between :class:`Recompute`
and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is
copied entirely.
"""
from collections import deque
from contextlib import contextmanager
import threading
from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
import torch
from torch import ByteTensor, Tensor
import torch.autograd
from .dependency import fork, join
from .microbatch import Batch
from .phony import get_phony
__all__ = ["is_checkpointing", "is_recomputing"]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
# Types for shared memory between Checkpoint and Recompute.
Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf)
RNGStates = Tuple[ByteTensor, Optional[ByteTensor]] # (cpu_rng_state, gpu_rng_state)
if TYPE_CHECKING:
from typing_extensions import Protocol
else:
Protocol = object
# Protocol with __call__ instead of Callable can be used as an attribute type.
# See: https://github.com/python/mypy/issues/708#issuecomment-561735949
class Function(Protocol):
def __call__(self, input: TensorOrTensors) -> TensorOrTensors:
...
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
"""Makes a checkpoint with a simple interface like
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
"""
batch = Batch(input)
chk = Checkpointing(function, batch)
batch = chk.checkpoint()
chk.recompute(batch)
return batch.tensor_or_tensors
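# A minimal usage sketch (assumed for illustration only; 'layer' is a
# hypothetical module, not part of this file):
#
#     layer = torch.nn.Linear(3, 3)
#     x = torch.rand(4, 3, requires_grad=True)
#     y = checkpoint(layer, x)   # forward pass runs under torch.no_grad()
#     y.sum().backward()         # 'layer' is recomputed before backpropagation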
class Checkpointing:
"""Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
def __init__(self, function: Function, batch: Batch) -> None:
self.function = function
self.batch = batch
# Shared memory between Checkpoint and Recompute. A deque of length 1 is
# used for mutability and to limit the length.
self.recomputed: Deque[Recomputed] = deque(maxlen=1)
self.rng_states: Deque[RNGStates] = deque(maxlen=1)
def checkpoint(self) -> Batch:
"""Returns a batch applied by :class:`Checkpoint`."""
input_atomic = self.batch.atomic
input = tuple(self.batch)
# Use a phony which requires grad to ensure that Checkpoint can be
# tracked by the autograd engine even when none of the input tensors
# require grad.
phony = get_phony(self.batch[0].device, requires_grad=True)
output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input)
# Gradients are only supported for float Tensors.
if isinstance(output, tuple):
output = tuple([x if x.is_floating_point() else x.detach() for x in output])
return Batch(output)
def recompute(self, batch: Batch) -> None:
"""Applies :class:`Recompute` to the batch in place."""
input_atomic = self.batch.atomic
input = tuple(self.batch)
# batch[0] always requires grad because it has been passed through
# Checkpoint with a phony that requires grad.
batch[0], phony = fork(batch[0])
phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input)
batch[0] = join(batch[0], phony)
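# fork()/join() splice a Recompute node into the autograd graph of batch[0],
# so recomputation is triggered during backward before the gradient reaches
# the corresponding Checkpoint node.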
class ThreadLocal(threading.local):
def __init__(self) -> None:
self.is_checkpointing = False
self.is_recomputing = False
thread_local = ThreadLocal()
@contextmanager
def enable_checkpointing() -> Generator[None, None, None]:
"""Makes :func:`is_checkpointing` return :data:`True` within a context."""
orig = thread_local.is_checkpointing
thread_local.is_checkpointing = True
try:
yield
finally:
thread_local.is_checkpointing = orig
@contextmanager
def enable_recomputing() -> Generator[None, None, None]:
"""Makes :func:`is_recomputing` return :data:`True` within a context."""
orig = thread_local.is_recomputing
thread_local.is_recomputing = True
try:
yield
finally:
thread_local.is_recomputing = orig
def is_checkpointing() -> bool:
"""Whether the current forward propagation is under checkpointing.
Returns:
bool: :data:`True` if it's under checkpointing.
"""
return thread_local.is_checkpointing
def is_recomputing() -> bool:
"""Whether the current forward propagation is under checkpoint
recomputation. Use this to prevent duplicated side-effects at forward
propagation::
class Counter(nn.Module):
def __init__(self):
super().__init__()
self.counter = 0
def forward(self, input):
if not is_recomputing():
self.counter += 1
return input
Returns:
bool: :data:`True` if it's under checkpoint recomputation.
.. seealso:: :ref:`Detecting Recomputation`
"""
return thread_local.is_recomputing
class Context:
"""The common interface between the :class:`Checkpoint` and
:class:`Recompute` context.
"""
recomputed: Deque[Recomputed]
rng_states: Deque[RNGStates]
function: Function
input_atomic: bool
saved_tensors: Tuple[Tensor, ...]
def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover
pass
def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
""":meth:`Checkpoint.forward` captures the current PyTorch's random number
generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state = torch.get_rng_state()
gpu_rng_state: Optional[ByteTensor]
if device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state(device)
else:
gpu_rng_state = None
rng_states.append((cpu_rng_state, gpu_rng_state))
@contextmanager
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
""":meth:`Recompute.backward` restores the random number generator states
captured by :func:`save_rng_states` within its context.
.. seealso:: :ref:`Referential Transparency`
"""
cpu_rng_state, gpu_rng_state = rng_states.pop()
gpu_devices: List[torch.device] = []
if device.type == "cuda":
gpu_devices.append(device)
with torch.random.fork_rng(gpu_devices):
torch.set_rng_state(cpu_rng_state)
if gpu_rng_state is not None:
torch.cuda.set_rng_state(gpu_rng_state, device)
yield
class Checkpoint(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(
ctx: Context,
phony: Tensor,
recomputed: Deque[Recomputed],
rng_states: Deque[RNGStates],
function: Function,
input_atomic: bool,
*input: Tensor,
) -> TensorOrTensors:
ctx.recomputed = recomputed
ctx.rng_states = rng_states
save_rng_states(input[0].device, ctx.rng_states)
ctx.function = function
ctx.input_atomic = input_atomic
ctx.save_for_backward(*input)
with torch.no_grad(), enable_checkpointing():
output = function(input[0] if input_atomic else input)
return output
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
output, input_leaf = ctx.recomputed.pop()
if isinstance(output, tuple):
tensors = output
else:
tensors = (output,)
if any(y.requires_grad for y in tensors):
tensors = tuple([x for x in tensors if x.requires_grad])
torch.autograd.backward(tensors, grad_output)
grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
grad_input.extend(x.grad for x in input_leaf)
return tuple(grad_input)
class Recompute(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(
ctx: Context,
phony: Tensor,
recomputed: Deque[Recomputed],
rng_states: Deque[RNGStates],
function: Function,
input_atomic: bool,
*input: Tensor,
) -> Tensor:
ctx.recomputed = recomputed
ctx.rng_states = rng_states
ctx.function = function
ctx.input_atomic = input_atomic
ctx.save_for_backward(*input)
return phony
@staticmethod
def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover
input = ctx.saved_tensors
input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)
with restore_rng_states(input[0].device, ctx.rng_states):
with torch.enable_grad(), enable_recomputing():
output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)
ctx.recomputed.append((output, input_leaf))
grad_input: List[None] = [None, None, None, None, None]
grad_input.extend(None for _ in ctx.saved_tensors)
return tuple(grad_input)

View File

@ -0,0 +1,104 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Autograd functions for stream-aware CUDA copy. It is used to overlap copy
and computation on the same GPU.
"""
from collections import deque
from typing import Deque, List, Optional, Tuple
import torch
from torch import Tensor
from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream
__all__: List[str] = []
Tensors = Tuple[Tensor, ...]
# Common interface between :class:`Copy` and :class:`Wait`.
class Context:
prev_stream: AbstractStream
next_stream: AbstractStream
class Copy(torch.autograd.Function):
"""Copies tensors on specific streams."""
@staticmethod
# type: ignore
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
ctx.prev_stream = prev_stream
ctx.next_stream = next_stream
output = []
output_stream = current_stream(get_device(next_stream))
with use_stream(prev_stream), use_stream(next_stream):
for x in input:
y = x.to(get_device(next_stream), non_blocking=True)
output.append(y)
# 'prev_stream' is not where 'x' has been allocated.
record_stream(x, prev_stream)
# 'y' has been allocated on 'next_stream'.
# It might be used on the current stream captured as 'output_stream'.
record_stream(y, output_stream)
return tuple(output)
@staticmethod
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
prev_stream = ctx.prev_stream
next_stream = ctx.next_stream
grad_input: Deque[Tensor] = deque(maxlen=len(grad_output))
input_stream = current_stream(get_device(prev_stream))
with use_stream(prev_stream), use_stream(next_stream):
for x in reversed(grad_output):
y = x.to(get_device(prev_stream), non_blocking=True)
grad_input.appendleft(y)
# 'next_stream' is not where 'x' has been allocated.
record_stream(x, next_stream)
# 'y' has been allocated on 'prev_stream'.
# It might be used on the current stream captured as 'input_stream'.
record_stream(y, input_stream)
grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
return grad_streams + tuple(grad_input)
class Wait(torch.autograd.Function):
"""Synchronizes a stream to another stream.
Place it just before you want to start an operation on the next stream,
provided that all operations on the previous stream are done.
"""
@staticmethod
# type: ignore
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
ctx.prev_stream = prev_stream
ctx.next_stream = next_stream
wait_stream(next_stream, prev_stream)
return tuple(x.detach() for x in input)
@staticmethod
def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]:
prev_stream = ctx.prev_stream
next_stream = ctx.next_stream
wait_stream(prev_stream, next_stream)
grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
return grad_streams + grad_input

View File

@ -0,0 +1,54 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Arbitrary dependency between two autograd lanes."""
from typing import List, Tuple
import torch
from torch import Tensor
from .phony import get_phony
__all__: List[str] = []
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
"""Branches out from an autograd lane of the given tensor."""
if torch.is_grad_enabled() and input.requires_grad:
input, phony = Fork.apply(input)
else:
phony = get_phony(input.device, requires_grad=False)
return input, phony
class Fork(torch.autograd.Function):
@staticmethod
def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
phony = get_phony(input.device, requires_grad=False)
return input.detach(), phony.detach()
@staticmethod
def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore
return grad_input
def join(input: Tensor, phony: Tensor) -> Tensor:
"""Merges two autograd lanes."""
if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
input = Join.apply(input, phony)
return input
class Join(torch.autograd.Function):
@staticmethod
def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore
return input.detach()
@staticmethod
def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore
return grad_input, None
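A hedged sketch of fork()/join() in action, assuming the module is importable as torch.distributed._pipeline.sync.dependency: the pair adds a pure ordering edge between two otherwise independent autograd lanes without moving any data.

import torch

from torch.distributed._pipeline.sync.dependency import fork, join

a = torch.rand(2, requires_grad=True)
b = torch.rand(2, requires_grad=True)

a_out = a * 2
b_out = b * 3

# Tie the lanes together: during backpropagation the Fork node on a_out's lane
# waits for the gradient of the phony, which only arrives once the backward
# pass has reached the Join node on b_out's lane.
a_out, phony = fork(a_out)
b_out = join(b_out, phony)

(a_out.sum() + b_out.sum()).backward()
assert a.grad is not None and b.grad is not None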

View File

@ -0,0 +1,185 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast
import torch
from torch import Tensor
import torch.cuda.comm
__all__: List[str] = []
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], TensorOrTensors]
class Batch:
"""An abstraction of an atomic tensor or a tuple of tensors. This
eliminates every boilerplate code to classify an atomic tensor or a tuple
of tensors.
::
x = generate_tensor_or_tensors()
x = Batch(x)
# in-place update
x[0] = F.apply(x[0])
x[:] = F.apply(*x)
# f(x) if x is a tensor.
# f(*x) if x is a tuple of tensors.
# y is also a batch.
y = x.call(f)
"""
def __init__(self, value: TensorOrTensors) -> None:
self.value = value
self.atomic = torch.is_tensor(value)
@property
def tensor(self) -> Tensor:
"""Retrieves the underlying tensor."""
if not self.atomic:
raise AttributeError("not atomic batch")
return cast(Tensor, self.value)
@property
def tensors(self) -> Tensors:
"""Retrieves the underlying tensors."""
if self.atomic:
raise AttributeError("batch is atomic")
return cast(Tensors, self.value)
@property
def tensor_or_tensors(self) -> TensorOrTensors:
"""Retrieves the underlying tensor or tensors regardless of type."""
return self.value
def call(self, function: Function) -> "Batch":
"""Calls a function by the underlying tensor or tensors. It also wraps
the output with :class:`Batch`.
"""
return Batch(function(self.value))
def __repr__(self) -> str:
return f"Batch[atomic={self.atomic!r}]({self.value!r})"
def __iter__(self) -> Iterator[Tensor]:
if self.atomic:
yield self.tensor
else:
yield from self.tensors
def __len__(self) -> int:
return 1 if self.atomic else len(self.tensors)
def __getitem__(self, index: int) -> Tensor:
if not self.atomic:
return self.tensors[index]
if index != 0:
raise IndexError("atomic batch allows index 0 only")
return self.tensor
# NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
@typing.overload
def __setitem__(self, index: int, value: Tensor) -> None:
...
@typing.overload
def __setitem__(self, index: slice, value: Tensors) -> None:
...
def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None:
if isinstance(index, int):
value = cast(Tensor, value)
self._setitem_by_index(index, value)
else:
value = cast(Tensors, value)
self._setitem_by_slice(index, value)
def _setitem_by_index(self, index: int, value: Tensor) -> None:
if not self.atomic:
i = index
self.value = self.value[:i] + (value,) + self.value[i + 1 :]
return
if index != 0:
raise IndexError("atomic batch allows index 0 only")
self.value = value
def _setitem_by_slice(self, index: slice, value: Tensors) -> None:
if not (index.start is index.stop is index.step is None):
raise NotImplementedError("only slice [:] supported")
if not self.atomic:
self.value = value
return
if len(value) != 1:
raise IndexError("atomic batch cannot be replaced with multiple tensors")
self.value = value[0]
def check(input: TensorOrTensors) -> None:
"""Checks whether the input is a tensor or tensors.
Raises:
TypeError: input is not a tensor or tensors.
"""
if isinstance(input, tuple):
for x in input:
check(x)
return
if not isinstance(input, Tensor):
raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")
def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
"""Splits an input mini-batch into multiple micro-batches."""
inputs: Iterable[TensorOrTensors]
if isinstance(input, Tensor):
inputs = input.chunk(chunks)
else:
rotated: List[Tensors] = []
for tensor in input:
tensors = tensor.chunk(chunks)
rotated.append(cast(Tensors, tensors))
inputs = zip(*rotated)
return [Batch(x) for x in inputs]
def gather(outputs: List[Batch]) -> TensorOrTensors:
"""Concatenates output micro-batches into a mini-batch."""
output: TensorOrTensors
if outputs[0].atomic:
tensors = tuple(b.tensor for b in outputs)
output = torch.cat(tensors)
else:
rotated = [b.tensors for b in outputs]
output_buf = []
for tensors in zip(*rotated):
output_buf.append(torch.cat(tensors))
output = tuple(output_buf)
return output
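A short sketch of the scatter()/gather() round trip with Batch, assuming the module is importable as torch.distributed._pipeline.sync.microbatch.

import torch

from torch.distributed._pipeline.sync import microbatch

x = torch.arange(8.0).view(8, 1)           # a mini-batch of 8 samples
batches = microbatch.scatter(x, chunks=4)  # 4 micro-batches of 2 samples each
assert len(batches) == 4 and batches[0].atomic

# Apply a function micro-batch by micro-batch, then merge back into a mini-batch.
doubled = [b.call(lambda t: t * 2) for b in batches]
assert torch.equal(microbatch.gather(doubled), x * 2)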

View File

@ -0,0 +1,49 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Provides phony for arbitrary dependency in a autograd graph."""
from typing import Dict, List, Tuple
import torch
from torch import Tensor
from .stream import default_stream, use_stream
__all__: List[str] = []
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
"""Gets a phony. Phony is tensor without space. It is useful to make
arbitrary dependency in a autograd graph because it doesn't require any
gradient accumulation.
.. note::
Phonies for each device are cached. If an autograd function gets a phony
internally, the phony must be detached to be returned. Otherwise, the
autograd engine will mutate the cached phony in-place::
class Phonify(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
phony = get_phony(input.device, requires_grad=False)
return phony.detach() # detach() is necessary.
"""
key = (device, requires_grad)
try:
phony = _phonies[key]
except KeyError:
with use_stream(default_stream(device)):
phony = torch.empty(0, device=device, requires_grad=requires_grad)
_phonies[key] = phony
return phony
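A tiny illustration of the caching behaviour described above, assuming the module is importable as torch.distributed._pipeline.sync.phony.

import torch

from torch.distributed._pipeline.sync.phony import get_phony

p1 = get_phony(torch.device("cpu"), requires_grad=False)
p2 = get_phony(torch.device("cpu"), requires_grad=False)

assert p1 is p2         # cached per (device, requires_grad) key
assert p1.numel() == 0  # a phony carries no data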

View File

@ -0,0 +1,394 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The Pipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda
from . import microbatch
from .batchnorm import DeferredBatchNorm
from .pipeline import Pipeline
from .skip.layout import inspect_skip_layout
from .skip.skippable import verify_skippables
from .stream import AbstractStream, new_stream
__all__ = ["Pipe"]
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
if TYPE_CHECKING:
Module = nn.Module[TensorOrTensors]
NamedModules = OrderedDict[str, Module]
else:
Module = nn.Module
NamedModules = OrderedDict
def recommend_auto_balance(message: str) -> str:
"""Expands a message with recommendation to :mod:`torchpipe.balance`."""
return f"""{message}
If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend 'torch.distributed._pipeline.sync.balance' for
naive automatic balancing:
from torch.distributed._pipeline.sync import Pipe
from torch.distributed._pipeline.sync.balance import balance_by_time
partitions = torch.cuda.device_count()
sample = torch.empty(...)
balance = balance_by_time(partitions, model, sample)
model = Pipe(model, balance, ...)
"""
def verify_module(module: nn.Sequential) -> None:
if not isinstance(module, nn.Sequential):
raise TypeError("module must be nn.Sequential to be partitioned")
named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError("module with duplicate children is not supported")
def verify_splitting(
module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device]
) -> None:
num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
if num_parameters == num_child_parameters:
return
for i in range(len(partitions)):
for j in range(i + 1, len(partitions)):
parti = partitions[i]
partj = partitions[j]
if devices[i] == devices[j]:
continue
for p in parti.parameters():
for q in partj.parameters():
if p is q:
raise ValueError("module with duplicate parameters on distinct devices is not supported")
class BalanceError(ValueError):
pass
def split_module(
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
"""Splits a module into multiple partitions.
Returns:
A tuple of (partitions, balance, devices).
Partitions are represented as a :class:`~torch.nn.ModuleList`, each
item of which is a partition. All layers in a partition are placed on
the same device.
Raises:
BalanceError:
wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
balance = list(balance)
if len(module) != sum(balance):
raise BalanceError(
"module and sum of balance have different length "
f"(module: {len(module)}, sum of balance: {sum(balance)})"
)
if any(x <= 0 for x in balance):
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
if len(balance) > len(devices):
raise IndexError(
"too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})"
)
j = 0
partitions = []
layers: NamedModules = OrderedDict()
for name, layer in module.named_children():
layers[name] = layer
if len(layers) == balance[j]:
# Group buffered layers as a partition.
partition = nn.Sequential(layers)
device = devices[j]
partition.to(device)
partitions.append(partition)
# Prepare for the next partition.
layers.clear()
j += 1
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
del devices[j:]
return partitions, balance, devices
MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement")
class Pipe(Module):
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
to train on Pipe_. Pipe is most beneficial when the module is too large to
fit into the memory of a single device.
::
model = nn.Sequential(a, b, c, d)
model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
output = model(input)
.. _Pipe: https://arxiv.org/abs/1811.06965
Pipe combines pipeline parallelism with checkpointing to reduce peak
memory required to train while minimizing device under-utilization.
You should determine the balance when defining a :class:`Pipe` module, as
balancing will not be done automatically. The module will be partitioned
across multiple devices according to the given balance. You may rely on
heuristics to find your own optimal configuration.
Args:
module (torch.nn.Sequential):
sequential module to be parallelized
balance (ints):
list of number of layers in each partition
Keyword Args:
devices (iterable of devices):
devices to use (default: all CUDA devices)
chunks (int):
number of micro-batches (default: ``1``)
checkpoint (str):
when to enable checkpointing, one of ``'always'``,
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
deferred_batch_norm (bool):
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :ref:`Deferred Batch Normalization` for more
details)
Raises:
TypeError:
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
ValueError:
invalid arguments, or wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
#: The number of layers in each partition.
balance: List[int] = []
# ^^
# The default value [] required for Sphinx's autoattribute.
#: The devices mapped to each partition.
#:
#: ``devices[-1]`` refers to the device of the last partition, which means
#: it is the output device. You will probably want to transfer the target
#: to this device when computing the loss, to avoid a device-mismatch
#: :exc:`RuntimeError`. For example::
#:
#: out_device = pipe.devices[-1]
#:
#: for input, target in loader:
#: target = target.to(out_device, non_blocking=True)
#: output = pipe(input)
#: loss = F.cross_entropy(output, target)
#:
devices: List[torch.device] = []
#: The number of micro-batches.
chunks: int = 1
#: The checkpoint mode to determine when to enable checkpointing. It is one
#: of ``'always'``, ``'except_last'``, or ``'never'``.
checkpoint: str = "except_last"
def __init__(
self,
module: nn.Sequential,
balance: Optional[Iterable[int]] = None,
*,
devices: Optional[Devices] = None,
chunks: int = chunks,
checkpoint: str = checkpoint,
deferred_batch_norm: bool = False,
) -> None:
super().__init__()
chunks = int(chunks)
checkpoint = str(checkpoint)
if balance is None:
raise ValueError(recommend_auto_balance("balance is required"))
if chunks <= 0:
raise ValueError("number of chunks must be positive integer")
if checkpoint not in ["always", "except_last", "never"]:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
verify_module(module)
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
verify_skippables(module)
self.chunks = chunks
self.checkpoint = checkpoint
if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
if devices is None:
devices = range(torch.cuda.device_count())
devices = [torch.device(d) for d in devices]
devices = cast(List[torch.device], devices)
try:
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_auto_balance(str(exc)))
verify_splitting(module, self.partitions, self.balance, self.devices)
self._copy_streams: List[List[AbstractStream]] = []
self._skip_layout = inspect_skip_layout(self.partitions)
# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
# The micro-batch index where the checkpointing stops.
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
return sum(len(p) for p in self.partitions)
def __getitem__(self, index: int) -> nn.Module:
"""Gets a layer in the underlying sequential module."""
partitions = self.partitions
if index < 0:
partitions = partitions[::-1]
for partition in partitions:
try:
return partition[index]
except IndexError:
pass
shift = len(partition)
if index < 0:
index += shift
else:
index -= shift
raise IndexError
def __iter__(self) -> Iterable[nn.Module]:
"""Iterates over children of the underlying sequential module."""
for partition in self.partitions:
yield from partition
# Pipe should manage the device of each partition.
# Deny cuda(), cpu(), and to() with device, by TypeError.
def cuda(self, device: Optional[Device] = None) -> "Pipe":
raise MOVING_DENIED
def cpu(self) -> "Pipe":
raise MOVING_DENIED
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
# Deny these usages:
#
# - to(device[, dtype, non_blocking])
# - to(tensor[, non_blocking])
#
# But allow this:
#
# - to(dtype[, non_blocking])
#
if "device" in kwargs or "tensor" in kwargs:
raise MOVING_DENIED
if args:
if isinstance(args[0], (torch.device, int, str)):
raise MOVING_DENIED
if torch.is_tensor(args[0]):
raise MOVING_DENIED
return super().to(*args, **kwargs)
def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
"""Ensures that :class:`Pipe` caches CUDA streams for copy.
Caching CUDA streams is worthwhile even though PyTorch already manages a
pool of pre-allocated CUDA streams, because it may reduce GPU memory
fragmentation when the number of micro-batches is small.
"""
if not self._copy_streams:
for device in self.devices:
self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
return self._copy_streams
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But there
is a type restriction: the input and output have to be a
:class:`~torch.Tensor` or a tuple of tensors. This restriction is
applied at partition boundaries too.
Args:
input (torch.Tensor or tensors): input mini-batch
Returns:
tensor or tensors: output mini-batch
Raises:
TypeError: input is not a tensor or tensors.
"""
microbatch.check(input)
if not self.devices:
# Empty sequential module is not illegal.
return input
# Divide a mini-batch into micro-batches.
batches = microbatch.scatter(input, self.chunks)
# Run pipeline parallelism.
self.pipeline.run(batches)
# Merge the micro-batches into one mini-batch.
output = microbatch.gather(batches)
return output
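A hedged end-to-end sketch of the documented call pattern. The model, balance, chunks, and device list are illustrative; a machine with two CUDA devices is the intended target, and the two-"cpu" fallback is only an assumption for experimentation.

import torch
from torch import nn
import torch.nn.functional as F

from torch.distributed._pipeline.sync import Pipe

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Three layers split 2 + 1 across two devices, 4 micro-batches per mini-batch.
devices = [0, 1] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
pipe = Pipe(model, balance=[2, 1], devices=devices, chunks=4)

input = torch.rand(8, 16, device=pipe.devices[0])
target = torch.zeros(8, 4, device=pipe.devices[-1])  # output device, see 'devices' above

output = pipe(input)
loss = F.mse_loss(output, target)
loss.backward()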

View File

@ -0,0 +1,257 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The pipeline parallelism of Pipe."""
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function
from .checkpoint import Checkpointing
from .copy import Copy, Wait
from .dependency import fork, join
from .microbatch import Batch
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device
from .worker import Task, create_workers, join_workers
__all__: List[str] = []
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
InQueue = Queue[Optional["Task"]]
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
InQueue = Queue
OutQueue = Queue
def depend(fork_from: Batch, join_to: Batch) -> None:
fork_from[0], phony = fork(fork_from[0])
join_to[0] = join(join_to[0], phony)
def copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
batch[:] = Copy.apply(prev_stream, next_stream, *batch)
# Gradients are only supported for float Tensors.
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
batch[:] = Wait.apply(prev_stream, next_stream, *batch)
# Gradients are only supported for float Tensors.
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
"""Generates schedules for each clock cycle."""
# m: number of micro-batches
# n: number of partitions
# i: index of micro-batch
# j: index of partition
# k: clock number
#
# k (i,j) (i,j) (i,j)
# - ----- ----- -----
# 0 (0,0)
# 1 (1,0) (0,1)
# 2 (2,0) (1,1) (0,2)
# 3 (2,1) (1,2)
# 4 (2,2)
for k in range(m + n - 1):
yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
class Pipeline:
"""The pipeline parallelism for Pipe."""
def __init__(
self,
partitions: List[nn.Sequential],
devices: List[torch.device],
copy_streams: List[List[AbstractStream]],
skip_layout: SkipLayout,
checkpoint_stop: int,
) -> None:
self.partitions = partitions
self.devices = devices
self.copy_streams = copy_streams
self.skip_layout = skip_layout
self.checkpoint_stop = checkpoint_stop
(self.in_queues, self.out_queues) = create_workers(devices)
def __del__(self) -> None:
join_workers(self.in_queues, self.out_queues)
def run(self, batches: List[Batch]) -> None:
"""Runs pipeline parallelism.
It modifies the given batches in place.
"""
partitions = self.partitions
devices = self.devices
skip_layout = self.skip_layout
m = len(batches)
n = len(partitions)
skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]
for schedule in clock_cycles(m, n):
self.fence(batches, schedule, skip_trackers)
self.compute(batches, schedule, skip_trackers)
def fence(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Copies micro-batches after computation for the previous
micro-batches.
"""
copy_streams = self.copy_streams
skip_layout = self.skip_layout
for i, j in schedule:
# Ensure that batches[i-1] is executed after batches[i] in
# backpropagation by an explicit dependency.
if i != 0 and j != 0:
depend(batches[i - 1], batches[i])
next_stream = copy_streams[j][i]
for prev_j, ns, name in skip_layout.copy_policy(j):
prev_stream = copy_streams[prev_j][i]
skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)
if j != 0:
prev_stream = copy_streams[j - 1][i]
copy(batches[i], prev_stream, next_stream)
def compute(
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
) -> None:
"""Runs tasks with synchronization to copy streams."""
partitions = self.partitions
devices = self.devices
copy_streams = self.copy_streams
checkpoint_stop = self.checkpoint_stop
# Disable checkpointing if in eval mode.
if not self.partitions[0].training:
checkpoint_stop = 0
n = len(partitions)
streams = [current_stream(d) for d in devices]
exc_info: Optional[ExcInfo] = None
# With checkpointing, the autograd graph looks like this diagram:
# ┌─────┸──────┐
# │ Copy │
# └─────┰──────┘ (fence)
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┃ (compute)
# ┌─────┸──────┐
# │ Wait │ [1] Synchronize the current stream with the copy stream.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Checkpoint │ [2] Compute a partition within checkpointing.
# └─────┰──────┘
# ┌─────┸──────┐
# │ Wait │ [3] Synchronize the copy stream with the current stream.
# └─────┰──────┘
# ┠ ─ ─ ─ ┐
# ┃ ┌─────┴─────┐
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
# ┃ └─────┬─────┘
# ┠ ─ ─ ─ ┘
# ┃
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
# ┌─────┸──────┐ (fence)
# │ Copy │
# └─────┰──────┘
for i, j in schedule:
batch = batches[i]
partition = partitions[j]
# Synchronize with the copied input. ([1] in the diagram)
if j != 0:
wait(batch, copy_streams[j][i], streams[j])
# Determine whether to checkpoint this micro-batch or not.
checkpoint = i < checkpoint_stop
if checkpoint:
def function(
input: TensorOrTensors,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> TensorOrTensors:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return partition(input)
chk = Checkpointing(function, batch)
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
del function, chk
else:
def compute(
batch: Batch = batch,
partition: nn.Sequential = partition,
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
chunk_id: int = i,
part_id: int = j,
) -> Batch:
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
return batch.call(partition)
task = Task(streams[j], compute=compute, finalize=None)
del compute
# Compute tasks in parallel. ([2] in the diagram)
self.in_queues[j].put(task)
for i, j in schedule:
ok, payload = self.out_queues[j].get()
# Hold the first exception.
if exc_info is not None:
continue
elif not ok:
exc_info = cast(ExcInfo, payload)
continue
task, batch = cast(Tuple[Task, Batch], payload)
# The copy stream synchronizes to copy the output. ([3] in the
# diagram)
if j != n - 1:
wait(batch, streams[j], copy_streams[j][i])
# Finalize tasks. If checkpointing is enabled, here the
# recomputation is scheduled at backpropagation. ([4] in the
# diagram)
with use_device(devices[j]):
task.finalize(batch)
batches[i] = batch
# Fail at the first exception.
if exc_info is not None:
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
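A quick check of the clock_cycles() schedule defined earlier in this file, assuming it is importable as torch.distributed._pipeline.sync.pipeline: for 3 micro-batches and 3 partitions, the generated clock cycles reproduce the (i, j) table in the comment above.

from torch.distributed._pipeline.sync.pipeline import clock_cycles

assert list(clock_cycles(3, 3)) == [
    [(0, 0)],
    [(1, 0), (0, 1)],
    [(2, 0), (1, 1), (0, 2)],
    [(2, 1), (1, 2)],
    [(2, 2)],
]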

View File

@ -0,0 +1,6 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,11 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Supports efficiency with skip connections."""
from .namespace import Namespace
from .skippable import pop, skippable, stash, verify_skippables
__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"]

View File

@ -0,0 +1,86 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Static skip connection layout of ``@skippable`` modules."""
from typing import Dict, Iterable, List, Tuple
from torch import nn
from .namespace import Namespace
__all__: List[str] = []
class SkipLayout:
"""Represents a skip connection layout across partitions."""
# Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...}
by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]]
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
by_partition: List[List[Tuple[int, Namespace, str]]]
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
# The skip routes are already indexed by 'ns, name'.
self.by_ns_name = skip_routes
# Index skip routes by partition number 'j'.
self.by_partition = [[] for _ in range(num_partitions)]
for (ns, name), (prev_j, next_j) in skip_routes.items():
self.by_partition[next_j].append((prev_j, ns, name))
for p in self.by_partition:
p.sort()
def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
"""Generates skip routes for the given destination partition number.
The skip routes are sorted by source partition number in ascending
order.
Yields:
Each tuple of (source partition number, namespace, name).
"""
for prev_j, ns, name in self.by_partition[next_j]:
if prev_j == next_j:
# This skip tensor will be popped at the same partition where
# it is stashed. In this case, copy is not required.
continue
yield (prev_j, ns, name)
def requires_copy(self, ns: Namespace, name: str) -> bool:
"""Whether the given namespace and name requires partition-to-partition
copy or not.
"""
prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1))
return prev_j != next_j
def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout:
"""Inspects the skip connection layout in the given partitions."""
# NOTE(sublee): Hide circular import inside this subroutine. Circular
# import is not ideal but placing this logic near to SkipLayout may
# increase cohesion of code.
from .skippable import Skippable
skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {}
stashed_at: Dict[Tuple[Namespace, str], int] = {}
for j, partition in enumerate(partitions):
for layer in partition:
if not isinstance(layer, Skippable):
continue
for ns, name in layer.stashable():
stashed_at[(ns, name)] = j
for ns, name in layer.poppable():
prev_j = stashed_at.pop((ns, name))
skip_routes[(ns, name)] = (prev_j, j)
return SkipLayout(len(partitions), skip_routes)
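A hedged sketch of the SkipLayout bookkeeping, assuming the module is importable as torch.distributed._pipeline.sync.skip.layout. None stands in for the default namespace (see namespace.py below).

from torch.distributed._pipeline.sync.skip.layout import SkipLayout

# One skip tensor "1to3" stashed in partition 0 and popped in partition 2.
layout = SkipLayout(num_partitions=3, skip_routes={(None, "1to3"): (0, 2)})

assert layout.requires_copy(None, "1to3")
assert list(layout.copy_policy(2)) == [(0, None, "1to3")]  # copy from 0 into 2
assert list(layout.copy_policy(1)) == []                   # nothing arrives at 1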

View File

@ -0,0 +1,50 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Provides isolated namespace of skip tensors."""
import abc
from functools import total_ordering
from typing import Any
import uuid
__all__ = ["Namespace"]
@total_ordering
class Namespace(metaclass=abc.ABCMeta):
"""Namespace for isolating skip tensors used by :meth:`isolate()
<torchpipe.skip.skippable.Skippable.isolate>`.
"""
__slots__ = ("id",)
def __init__(self) -> None:
self.id = uuid.uuid4()
def __repr__(self) -> str:
return f"<Namespace '{self.id}'>"
def __hash__(self) -> int:
return hash(self.id)
# Namespaces should support ordering, since SkipLayout will sort tuples
# including a namespace. But the actual order between namespaces is not
# important. That's why they are ordered by version 4 UUIDs, which are
# randomly generated.
def __lt__(self, other: Any) -> bool:
if isinstance(other, Namespace):
return self.id < other.id
return False
def __eq__(self, other: Any) -> bool:
if isinstance(other, Namespace):
return self.id == other.id
return False
# 'None' is the default namespace,
# which means that 'isinstance(None, Namespace)' is 'True'.
Namespace.register(type(None))
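A tiny illustration of the Namespace semantics above, assuming the module is importable as torch.distributed._pipeline.sync.skip.namespace.

from torch.distributed._pipeline.sync.skip.namespace import Namespace

ns1, ns2 = Namespace(), Namespace()

assert ns1 != ns2                   # every namespace is unique
assert (ns1 < ns2) or (ns2 < ns1)   # still totally ordered, by random UUID
assert isinstance(None, Namespace)  # 'None' acts as the default namespace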

View File

@ -0,0 +1,231 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the
autograd engine. The shared context of three functions (:class:`PortalBlue`,
:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is
one of the most important feature of :mod:`torchpipe.skip`.
The metaphor is inspired by Portal™ from Valve.
"""
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from ..copy import Context as CopyContext
from ..copy import Copy
from ..phony import get_phony
from ..stream import AbstractStream, get_device
__all__: List[str] = []
class Portal:
"""A portal for a tensor."""
def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
self.put_tensor(tensor, tensor_life)
self.grad: Optional[Tensor] = None
def blue(self) -> Tensor:
"""Creates a :class:`PortalBlue` which hides the underlying tensor from
the autograd engine.
Join the returned phony to the main lane of the autograd graph to
ensure correct backpropagation::
PortalBlue --+
|
---------- Join --
"""
tensor = self.use_tensor()
if tensor is None:
return get_phony(torch.device("cpu"), requires_grad=False)
return PortalBlue.apply(self, tensor)
def orange(self, phony: Tensor) -> Optional[Tensor]:
"""Creates a :class:`PortalOrange` which retrieves the hidden tensor
without losing the ability to backpropagate.
Give a phony forked from the main lane of an autograd graph::
+-- PortalOrange --+
| |
-- Fork --------- f(a, b) --
"""
self.check_tensor_life()
if self.tensor is None:
return self.use_tensor()
return PortalOrange.apply(self, phony)
def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor:
"""Copies the hidden tensor by a :class:`PortalCopy`.
Pass a phony in and use the returned phony to keep the backpropagation path connected::
+-- PortalCopy --+
| |
-- Fork ---------- Join --
"""
if self.tensor is None:
return get_phony(torch.device("cpu"), requires_grad=False)
return PortalCopy.apply(self, prev_stream, next_stream, phony)
def check_tensor_life(self) -> None:
if self.tensor_life <= 0:
raise RuntimeError("tensor in portal has been removed")
def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None:
"""Stores a tensor into this portal."""
# [Life of Tensor through Portal]
#
# The tensor can be retrieved by use_tensor() up to 'tensor_life'
# times. When the life becomes 0, the tensor will be deleted for
# deallocation in CUDA memory.
#
# The below events participate in a tensor through a portal.
# Note that [x] denotes the events which call use_tensor():
#
# 1. [x] blue()
# 2. [ ] PortalBlue.forward
# 3. [ ] copy()
# 4. [ ] PortalCopy.forward
# 5. [ ] orange()
# 6. [x] PortalOrange.forward
# - - - - - - - - - - - - - - - - - - - - - - - - - - -
# 7. [ ] orange() (recomputed)
# 8. [x] PortalOrange.forward (recomputed)
# 9. [ ] PortalOrange.backward
# 10. [ ] PortalCopy.backward
# 11. [x] blue() (recomputed)
# 12. [ ] PortalBlue.forward (recomputed)
# 13. [ ] PortalBlue.backward
#
self.tensor_life = tensor_life
if tensor_life > 0:
self.tensor = tensor
else:
self.tensor = None
def use_tensor(self) -> Optional[Tensor]:
"""Retrieves the underlying tensor and decreases the tensor life. When
the life becomes 0, the tensor will be removed.
"""
self.check_tensor_life()
tensor = self.tensor
self.tensor_life -= 1
if self.tensor_life <= 0:
self.tensor = None
return tensor
def put_grad(self, grad: Tensor) -> None:
"""Stores a gradient into this portal."""
self.grad = grad
def use_grad(self) -> Tensor:
"""Retrieves and removes the underlying gradient. The gradient is
always ephemeral.
"""
if self.grad is None:
raise RuntimeError("grad in portal has been removed or never set")
grad = self.grad
self.grad = None
return grad
# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and
# :class:`PortalCopy`.
class Context(CopyContext):
portal: Portal
class PortalBlue(torch.autograd.Function):
"""Hides a tensor from the autograd engine by a :class:`Portal`."""
@staticmethod
# type: ignore
def forward(
ctx: Context,
portal: Portal,
# This tensor must be retrieved by portal.use_tensor().
tensor: Tensor,
) -> Tensor:
ctx.portal = portal
phony = get_phony(tensor.device, requires_grad=False)
return phony.detach()
@staticmethod
# type: ignore
def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]:
# The paired PortalOrange should keep the gradient.
grad = ctx.portal.use_grad()
return None, grad
class PortalOrange(torch.autograd.Function):
"""Retrieves the hidden tensor from a :class:`Portal`."""
@staticmethod
# type: ignore
def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor:
ctx.portal = portal
tensor = portal.use_tensor()
assert tensor is not None
return tensor.detach()
@staticmethod
def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore
# The paired PortalBlue will use the gradient.
ctx.portal.put_grad(grad)
return None, None
class PortalCopy(torch.autograd.Function):
"""Copies the hidden tensor in a :class:`Portal`. It replaces the hidden
tensor with the copied one.
"""
@staticmethod
# type: ignore
def forward(
ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
) -> Tensor:
ctx.portal = portal
assert portal.tensor is not None
(portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor)
phony = get_phony(get_device(next_stream), requires_grad=False)
return phony.detach()
@staticmethod
# type: ignore
def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
portal = ctx.portal
assert portal.grad is not None
_, _, portal.grad = Copy.backward(ctx, portal.grad)
return None, None, None, None
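A hedged sketch of the tensor-life countdown documented in put_tensor() above, assuming the module is importable as torch.distributed._pipeline.sync.skip.portal.

import torch

from torch.distributed._pipeline.sync.skip.portal import Portal

portal = Portal(torch.ones(2), tensor_life=2)

assert portal.use_tensor() is not None  # life 2 -> 1
assert portal.use_tensor() is not None  # life 1 -> 0, the tensor is dropped

try:
    portal.use_tensor()                 # life exhausted
except RuntimeError:
    pass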

View File

@ -0,0 +1,439 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The user interface to define skip connections."""
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from torch import Tensor, nn
from ..microbatch import Batch
from .namespace import Namespace
from .tracker import current_skip_tracker
__all__ = ["skippable", "stash", "pop", "verify_skippables"]
Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]
StashPop = Union["stash", "pop"]
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
if TYPE_CHECKING:
SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]
else:
SkippableModule = nn.Module
T = TypeVar("T", bound="Skippable")
class Skippable(nn.Module):
"""The base class for skippable modules.
Do not use this class directly. Define a subclass by :func:`skippable`
instead.
"""
module_cls: ClassVar[Type[SkippableModule]]
stashable_names: ClassVar[FrozenSet[str]]
poppable_names: ClassVar[FrozenSet[str]]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
self.module = self.module_cls(*args, **kwargs) # type: ignore
self.namespaces: Dict[str, Namespace] = {}
def __repr__(self) -> str:
return f"@skippable({self.module})"
def namespaced(self, name: str) -> Tuple[Namespace, str]:
"""Prepends namespace for the given skip name."""
ns = self.namespaces.get(name)
ns = cast(Namespace, ns)
return (ns, name)
def stashable(self) -> Iterable[Tuple[Namespace, str]]:
"""Iterates over namespaced skip names to be stashed."""
for name in self.stashable_names:
yield self.namespaced(name)
def poppable(self) -> Iterable[Tuple[Namespace, str]]:
"""Iterates over namespaced skip names to be popped."""
for name in self.poppable_names:
yield self.namespaced(name)
def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T:
r"""Isolates a specified subset or the whole set of skip tensors into a
namespace. In a single sequential module, skip tensors with the same
name are not allowed unless they are isolated by different namespaces.
Here's an example using the same name for skip tensors twice. Each pair
of ``Layer1`` and ``Layer3`` is isolated with its own namespace ``ns1``
and ``ns2``. There is no conflict anymore::
ns1 = Namespace()
ns2 = Namespace()
model = nn.Sequential(
Layer1().isolate(ns1),
Layer1().isolate(ns2),
Layer2(),
Layer3().isolate(ns2),
Layer3().isolate(ns1),
)
When the `only` parameter is omitted, all skip tensors are isolated. You
can isolate a subset of skip tensors by passing the `only` parameter::
ns_alice = Namespace()
ns_bob = Namespace()
model = nn.Sequential(
...
StashStashPop().isolate(ns_alice, only=['alice']) \
.isolate(ns_bob, only=['bob']),
...
)
Args:
ns (Namespace):
namespace for isolation
Keyword Args:
only (iterable of strs):
names of specific skip tensors to be isolated (omit this option
to isolate all skip tensors declared in this module)
Returns:
this module itself
"""
names: Iterable[str]
if only is None:
names = self.stashable_names | self.poppable_names
else:
names = set(only)
for name in names:
self.namespaces[name] = ns
return self
def dispatch(
self,
input: TensorOrTensors,
handle_stash: Callable[[str, Optional[Tensor]], None],
handle_pop: Callable[[str], Optional[Tensor]],
) -> TensorOrTensors:
"""Dispatches :class:`stash` or :class:`pop` commands generated by the
module's ``forward()``.
"""
generator = self.module(input)
if not isinstance(generator, Generator):
# The underlying module returned output without any yield.
output = generator
return output
try:
op = next(generator)
while True:
if isinstance(op, stash):
handle_stash(op.name, op.tensor)
op = next(generator)
continue
if isinstance(op, pop):
tensor = handle_pop(op.name)
op = generator.send(tensor)
continue
raise TypeError("%r is not a command from @skippable" % op)
except StopIteration as stop:
output = stop.args[0]
return output
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
"""Performs the forward propagation. :class:`stash` or :class:`pop`
commands will be handled by portals silently. The portals won't be
exposed to users.
Raises:
RuntimeError:
illegal 'stash' or 'pop' is found.
"""
skip_tracker = current_skip_tracker()
stashed_tensors: Dict[str, Optional[Tensor]] = {}
# Load skip tensors that might be popped.
poppable_tensors = {}
batch = Batch(input)
for ns, name in self.poppable():
try:
poppable_tensors[name] = skip_tracker.load(batch, ns, name)
except KeyError:
raise RuntimeError(f"'{name}' has not been stashed")
input = batch.tensor_or_tensors
# Handle skip commands.
def handle_stash(name: str, tensor: Optional[Tensor]) -> None:
if name not in self.stashable_names:
raise RuntimeError(f"'{name}' has not been declared as stashable")
stashed_tensors[name] = tensor
def handle_pop(name: str) -> Optional[Tensor]:
if name not in self.poppable_names:
raise RuntimeError(f"'{name}' has not been declared as poppable")
return poppable_tensors.pop(name)
output = self.dispatch(input, handle_stash, handle_pop)
# All declared skips must be stashed or popped.
not_stashed = self.stashable_names - stashed_tensors.keys()
if not_stashed:
comma_names = ", ".join("'%s'" % n for n in not_stashed)
raise RuntimeError(f"{comma_names} must be stashed but have not")
not_popped = poppable_tensors.keys()
if not_popped:
comma_names = ", ".join("'%s'" % n for n in not_popped)
raise RuntimeError(f"{comma_names} must be popped but have not")
# Save stashed skip tensors.
batch = Batch(output)
for ns, name in self.stashable():
tensor = stashed_tensors[name]
skip_tracker.save(batch, ns, name, tensor)
output = batch.tensor_or_tensors
return output
# TODO(sublee): Move to above of Skippable class for better read flow.
def skippable(
stash: Iterable[str] = (), pop: Iterable[str] = (),
) -> Callable[[Type[SkippableModule]], Type[Skippable]]:
"""The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
connections. Decorated modules are called "skippable". This functionality
works perfectly fine even when the module is not wrapped by
:class:`~torchpipe.Pipe`.
Each skip tensor is managed by its name. Before manipulating skip tensors,
a skippable module must statically declare the names for skip tensors by
`stash` and/or `pop` parameters. Skip tensors with a pre-declared name can be
stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield
pop(name)``.
Here is an example with three layers. A skip tensor named "1to3" is stashed
and popped at the first and last layer, respectively::
@skippable(stash=['1to3'])
class Layer1(nn.Module):
def forward(self, input):
yield stash('1to3', input)
return f1(input)
class Layer2(nn.Module):
def forward(self, input):
return f2(input)
@skippable(pop=['1to3'])
class Layer3(nn.Module):
def forward(self, input):
skip_1to3 = yield pop('1to3')
return f3(input) + skip_1to3
model = nn.Sequential(Layer1(), Layer2(), Layer3())
One skippable module can stash or pop multiple skip tensors::
@skippable(stash=['alice', 'bob'], pop=['carol'])
class StashStashPop(nn.Module):
def forward(self, input):
yield stash('alice', f_alice(input))
yield stash('bob', f_bob(input))
carol = yield pop('carol')
return input + carol
Every skip tensor must be associated with exactly one pair of `stash` and
`pop`. :class:`~torchpipe.Pipe` checks this restriction automatically
when wrapping a module. You can also check the restriction by
:func:`~torchpipe.skip.verify_skippables` without
:class:`~torchpipe.Pipe`.
.. note::
:func:`@skippable <skippable>` changes the type of the wrapped class.
But currently (mypy v0.740), mypy could not understand class decorators
yet (`#3135 <https://github.com/python/mypy/issues/3135>`_).
There are two workarounds:
1. Naively ignore type errors by ``# type: ignore``.
2. Use ``skippable()()`` as a function instead of a decorator.
.. seealso:: :ref:`Long Skip Connections`
"""
stashable_names = frozenset(stash)
poppable_names = frozenset(pop)
def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]:
name = module_cls.__name__
bases = (Skippable,)
attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names}
return type(name, bases, attrs)
return extend_skippable
class stash:
"""The command to stash a skip tensor.
::
def forward(self, input):
yield stash('name', input)
return f(input)
Args:
name (str): name of skip tensor
input (torch.Tensor or None): tensor to pass to the skip connection
"""
__slots__ = ("name", "tensor")
def __init__(self, name: str, tensor: Optional[Tensor]) -> None:
self.name = name
self.tensor = tensor
class pop:
"""The command to pop a skip tensor.
::
def forward(self, input):
skip = yield pop('name')
return f(input) + skip
Args:
name (str): name of skip tensor
Returns:
the skip tensor previously stashed by another layer under the same name
"""
__slots__ = ("name",)
def __init__(self, name: str) -> None:
self.name = name
def verify_skippables(module: nn.Sequential) -> None:
"""Verifies if the underlying skippable modules satisfy integrity.
Every skip tensor must have only one pair of `stash` and `pop`. If there
are one or more unmatched pairs, it will raise :exc:`TypeError` with the
detailed messages.
Here are a few failure cases. :func:`verify_skippables` will report failure
for these cases::
# Layer1 stashes "1to3".
# Layer3 pops "1to3".
nn.Sequential(Layer1(), Layer2())
# └──── ?
nn.Sequential(Layer2(), Layer3())
# ? ────┘
nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3())
# └───────────────────┘ ^^^^^^
nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3())
# ^^^^^^ └───────────────────┘
To use the same name for multiple skip tensors, they must be isolated by
different namespaces. See :meth:`isolate()
<torchpipe.skip.skippable.Skippable.isolate>`.
Raises:
TypeError:
one or more pairs of `stash` and `pop` are not matched.
"""
stashed: Set[Tuple[Namespace, str]] = set()
popped: Set[Tuple[Namespace, str]] = set()
msgs: List[str] = []
for layer_name, layer in module.named_children():
if not isinstance(layer, Skippable):
continue
for name in layer.stashable_names & layer.poppable_names:
msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable"
msgs.append(msg)
for ns, name in layer.stashable():
if name in layer.poppable_names:
continue
if (ns, name) in stashed:
msg = f"'{layer_name}' redeclared '{name}' as stashable " "but not isolated by namespace"
msgs.append(msg)
continue
stashed.add((ns, name))
for ns, name in layer.poppable():
if name in layer.stashable_names:
continue
if (ns, name) in popped:
msg = f"'{layer_name}' redeclared '{name}' as poppable " "but not isolated by namespace"
msgs.append(msg)
continue
if (ns, name) not in stashed:
msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed"
msgs.append(msg)
continue
popped.add((ns, name))
for (_, name) in stashed - popped:
msg = f"no module declared '{name}' as poppable but stashed"
msgs.append(msg)
if msgs:
raise TypeError(
"one or more pairs of stash and pop do not match:\n\n%s" "" % "\n".join("* %s" % x for x in msgs)
)
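A hedged sketch of verify_skippables() at work, reusing the Layer1/Layer3 pattern from the skippable() docstring and assuming the package exposes skippable, stash, pop, and verify_skippables as in skip/__init__.py above.

from torch import nn

from torch.distributed._pipeline.sync.skip import pop, skippable, stash, verify_skippables


@skippable(stash=["1to3"])
class Layer1(nn.Module):
    def forward(self, input):
        yield stash("1to3", input)
        return input


@skippable(pop=["1to3"])
class Layer3(nn.Module):
    def forward(self, input):
        skip = yield pop("1to3")
        return input + skip


verify_skippables(nn.Sequential(Layer1(), Layer3()))      # matched pair: OK

try:
    verify_skippables(nn.Sequential(Layer1(), Layer1()))  # '1to3' is never popped
except TypeError as exc:
    print(exc)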

View File

@ -0,0 +1,177 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Tracks skip tensors on a thread."""
from contextlib import contextmanager
import threading
from typing import Dict, Generator, List, Optional, Tuple
from torch import Tensor
from ..checkpoint import is_checkpointing
from ..dependency import fork, join
from ..microbatch import Batch
from ..stream import AbstractStream
from .layout import SkipLayout
from .namespace import Namespace
from .portal import Portal
__all__: List[str] = []
class SkipTracker:
"""Tracks saved skip tensors.
It will update the given micro-batch in place. This is because when it
manipulates the underlying skip tensors, the current micro-batch also has
to be connected with the skip tensors.
One thread has one skip tracker. Call :func:`current_skip_tracker` to get
the skip tracker on the current thread.
"""
def __init__(self) -> None:
self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {}
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
self.tensors[(ns, name)] = tensor
def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
return self.tensors.pop((ns, name))
def copy(
self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
) -> None:
raise TypeError("copy is not supported for non-portal skip tensors")
class SkipTrackerThroughPotals(SkipTracker):
"""Tracks saved skip tensors through portals. The skip tensors will be
hidden in portals so that the autograd engine does not need to track them.
This tracker is only used when the training or evaluating module is wrapped
with :class:`torchpipe.Pipe`.
"""
def __init__(self, skip_layout: SkipLayout) -> None:
super().__init__()
self.skip_layout = skip_layout
self.portals: Dict[Tuple[Namespace, str], Portal] = {}
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
"""Saves the stashed skip tensor in a portal. The portal is then
connected to the given micro-batch with :class:`Join`.
"""
if not self.skip_layout.requires_copy(ns, name):
super().save(batch, ns, name, tensor)
return
# See [Life of Tensor through Portal] at Portal.put_tensor() to understand the
# below tensor_life values. Here are the selected events which retrieve
# the tensor in portal:
#
# 1. [x] blue()
# ...
# 6. [x] PortalOrange.forward
# ...
# 8. [x] PortalOrange.forward (recomputed)
# ...
# 11. [x] blue() (recomputed)
#
if (ns, name) not in self.portals:
if is_checkpointing():
# Under checkpointing, the tensor used by the first
# PortalOrange should be alive in the portal. This tensor will
# be used again by the second PortalOrange during the
# recomputation.
tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)]
else:
tensor_life = 2 # Delete at [6. PortalOrange.forward]
portal = Portal(tensor, tensor_life)
self.portals[(ns, name)] = portal
else:
# Under recomputation, the portal already exists.
portal = self.portals[(ns, name)]
# The existing tensor life already became 0. It should be reset to
# 1 to delete the tensor after the second PortalBlue immediately.
tensor_life = 1 # Delete at [11. blue() (recomputed)]
portal.put_tensor(tensor, tensor_life)
phony = portal.blue()
batch[0] = join(batch[0], phony)
def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
"""Loads a skip tensor from the corresponding portal to pop. The given
micro-batch is connected to the portal with :class:`Fork`.
"""
if not self.skip_layout.requires_copy(ns, name):
tensor = super().load(batch, ns, name)
return tensor
portal = self.portals[(ns, name)]
batch[0], phony = fork(batch[0])
tensor = portal.orange(phony)
return tensor
def copy(
self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
) -> None:
"""Copies the skip tensor in the corresponding portal. The given
micro-batch and the portal will be tied with :class:`Fork` and
:class:`Join`.
"""
assert self.skip_layout.requires_copy(ns, name)
batch[0], phony = fork(batch[0])
portal = self.portals[(ns, name)]
phony = portal.copy(prev_stream, next_stream, phony)
batch[0] = join(batch[0], phony)
class ThreadLocal(threading.local):
def __init__(self) -> None:
self.skip_tracker: Optional[SkipTracker] = None
thread_local = ThreadLocal()
@contextmanager
def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]:
"""Registers the given skip tracker on the current thread within a
context::
with use_skip_tracker(my_skip_tracker):
...
"""
orig = thread_local.skip_tracker
thread_local.skip_tracker = skip_tracker
try:
yield
finally:
thread_local.skip_tracker = orig
def current_skip_tracker() -> SkipTracker:
"""Gets the skip tracker on the current thread."""
skip_tracker = thread_local.skip_tracker
if skip_tracker is None:
skip_tracker = SkipTracker()
thread_local.skip_tracker = skip_tracker
return skip_tracker
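A small sketch of the thread-local tracker helpers above, assuming the modules are importable as torch.distributed._pipeline.sync.microbatch and torch.distributed._pipeline.sync.skip.tracker. None again stands in for the default namespace.

import torch

from torch.distributed._pipeline.sync.microbatch import Batch
from torch.distributed._pipeline.sync.skip.tracker import (
    SkipTracker,
    current_skip_tracker,
    use_skip_tracker,
)

batch = Batch(torch.zeros(1))
tracker = SkipTracker()

with use_skip_tracker(tracker):
    assert current_skip_tracker() is tracker
    current_skip_tracker().save(batch, None, "x", torch.ones(1))
    assert torch.equal(tracker.load(batch, None, "x"), torch.ones(1))

assert current_skip_tracker() is not tracker  # a fresh default tracker afterwards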

View File

@ -0,0 +1,117 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for eliminating boilerplate code to handle abstract streams with
the CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Union, cast
import torch
__all__: List[str] = []
class CPUStreamType:
pass
# The placeholder used in place of CUDA streams for the CPU device.
CPUStream = CPUStreamType()
# It represents both CUDA streams and the CPU stream.
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]
def new_stream(device: torch.device) -> AbstractStream:
"""Creates a new stream for either CPU or CUDA device."""
if device.type != "cuda":
return CPUStream
return torch.cuda.Stream(device)
def current_stream(device: torch.device) -> AbstractStream:
""":func:`torch.cuda.current_stream` for either CPU or CUDA device."""
if device.type != "cuda":
return CPUStream
return torch.cuda.current_stream(device)
def default_stream(device: torch.device) -> AbstractStream:
""":func:`torch.cuda.default_stream` for either CPU or CUDA device."""
if device.type != "cuda":
return CPUStream
return torch.cuda.default_stream(device)
@contextmanager
def use_device(device: torch.device) -> Generator[None, None, None]:
""":func:`torch.cuda.device` for either CPU or CUDA device."""
if device.type != "cuda":
yield
return
with torch.cuda.device(device):
yield
@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
""":func:`torch.cuda.stream` for either CPU or CUDA stream."""
if not is_cuda(stream):
yield
return
with torch.cuda.stream(as_cuda(stream)):
yield
def get_device(stream: AbstractStream) -> torch.device:
"""Gets the device from CPU or CUDA stream."""
if is_cuda(stream):
return as_cuda(stream).device
return torch.device("cpu")
def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
""":meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
makes the source stream wait until the target stream completes work queued.
"""
if is_cuda(target):
if is_cuda(source):
# A CUDA stream waits on another CUDA stream.
as_cuda(source).wait_stream(as_cuda(target))
else:
# The CPU waits on a CUDA stream.
as_cuda(target).synchronize()
# If the target is CPU, synchronization is not required.
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
""":meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
if is_cuda(stream):
# NOTE(sublee): record_stream() on a shifted view tensor throws
# RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
# protect the tensor against unexpected reallocation, here we use a
# temporary tensor associated with the same storage, without the shift,
# as a workaround.
#
# Issue: https://github.com/pytorch/pytorch/issues/27366
#
tensor = tensor.new_empty([0]).set_(tensor.storage())
tensor.record_stream(as_cuda(stream))
def is_cuda(stream: AbstractStream) -> bool:
"""Returns ``True`` if the given stream is a valid CUDA stream."""
return stream is not CPUStream
def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
"""Casts the given stream as :class:`torch.cuda.Stream`."""
return cast(torch.cuda.Stream, stream)
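A short sketch of how these wrappers let the same code path run on both CPU and CUDA; the import path is an assumption based on the package this PR introduces.

# Hypothetical sketch: the import path below is an assumption.
import torch

from torch.distributed._pipeline.sync.stream import (
    current_stream,
    new_stream,
    record_stream,
    use_stream,
    wait_stream,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.ones(4, device=device)      # allocated on the current stream
copy_stream = new_stream(device)      # CPUStream on CPU, a real stream on CUDA

with use_stream(copy_stream):         # no-op on CPU
    y = x * 2                         # queued on copy_stream when on CUDA
    record_stream(x, copy_stream)     # x is consumed by a non-allocating stream

# Make the current stream wait for the copy stream before consuming `y`.
wait_stream(current_stream(device), copy_stream)
print(y.sum().item())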

View File

@ -0,0 +1,151 @@
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Multithreading in pipeline parallelism."""
from contextlib import contextmanager
from queue import Queue
import sys
from threading import Thread
from types import TracebackType
from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast
import torch
from .microbatch import Batch
from .stream import AbstractStream, use_device, use_stream
__all__: List[str] = []
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
InQueue = Queue[Optional["Task"]]
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
InQueue = Queue
OutQueue = Queue
class Task:
"""A task represents how to compute a micro-batch on a partition.
It consists of two parts: :meth:`compute` and :meth:`finalize`.
:meth:`compute` should be executed by worker threads concurrently.
:meth:`finalize` should be executed after the worker threads have finished
executing :meth:`compute`.
Running :meth:`compute` in worker threads helps throughput because user
code issues several CUDA API calls, and in PyTorch concurrent CUDA API
calls are not serialized by the GIL, so more than one call can be in
flight at the same time.
"""
def __init__(
self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
) -> None:
self.stream = stream
self._compute = compute
self._finalize = finalize
self._grad_enabled = torch.is_grad_enabled()
def compute(self) -> Batch:
with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
return self._compute()
def finalize(self, batch: Batch) -> None:
if self._finalize is None:
return
with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
self._finalize(batch)
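A hypothetical sketch of the :class:`Task` contract in isolation; the import paths and the ``Batch(tensor)`` construction are assumptions based on the neighboring ``microbatch`` and ``stream`` modules.

# Hypothetical sketch: import paths and Batch(tensor) are assumptions.
import torch

from torch.distributed._pipeline.sync.microbatch import Batch
from torch.distributed._pipeline.sync.stream import CPUStream
from torch.distributed._pipeline.sync.worker import Task


def compute() -> Batch:
    # The forward computation for one micro-batch on one partition.
    return Batch(torch.ones(2, 2) * 3)


def finalize(batch: Batch) -> None:
    # Runs after compute(), e.g. to hand the result to the next stage.
    print("computed:", batch[0].sum().item())


task = Task(CPUStream, compute=compute, finalize=finalize)
batch = task.compute()    # normally executed inside a worker thread
task.finalize(batch)      # normally executed after the worker returns it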
def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None:
"""The main loop of a worker thread."""
with use_device(device):
while True:
task = in_queue.get()
if task is None:
break
try:
batch = task.compute()
except Exception:
exc_info = cast(ExcInfo, sys.exc_info())
out_queue.put((False, exc_info))
continue
out_queue.put((True, (task, batch)))
done = (False, None)
out_queue.put(done)
def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]:
"""Spawns worker threads. A worker thread is bound to a device."""
in_queues: List[InQueue] = []
out_queues: List[OutQueue] = []
# Spawn workers.
workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {}
def normalize_device(device: torch.device) -> torch.device:
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
if device.type == "cpu" and device.index is not None:
return torch.device("cpu")
return device
for device in devices:
device = normalize_device(device)
try:
in_queue, out_queue = workers[device]
except KeyError:
in_queue = Queue()
out_queue = Queue()
workers[device] = (in_queue, out_queue)
t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,)
t.start()
in_queues.append(in_queue)
out_queues.append(out_queue)
return (in_queues, out_queues)
def join_workers(in_queues: List[InQueue], out_queues: List[OutQueue]) -> None:
# Close workers.
for in_queue in set(in_queues):
in_queue.put(None)
# Join running workers.
running = set(out_queues)
while running:
out_queue = running.pop()
ok, payload = out_queue.get()
done = (False, None)
if (ok, payload) == done:
continue
running.add(out_queue)
@contextmanager
def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
try:
(in_queues, out_queues) = create_workers(devices)
yield (in_queues, out_queues)
finally:
join_workers(in_queues, out_queues)
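A hypothetical end-to-end sketch of driving a worker through the queues returned by ``spawn_workers``; the import paths and the ``Batch(tensor)`` construction are assumptions, and in the real pipeline these queues are fed by the scheduler rather than by hand.

# Hypothetical sketch: import paths and Batch(tensor) are assumptions.
import torch

from torch.distributed._pipeline.sync.microbatch import Batch
from torch.distributed._pipeline.sync.stream import current_stream
from torch.distributed._pipeline.sync.worker import Task, spawn_workers

device = torch.device("cpu")

with spawn_workers([device]) as (in_queues, out_queues):
    task = Task(
        current_stream(device),
        compute=lambda: Batch(torch.ones(2, 2)),
        finalize=None,
    )
    in_queues[0].put(task)

    ok, payload = out_queues[0].get()
    assert ok, payload        # payload is exc_info if compute() raised
    task, batch = payload
    print("worker produced", batch[0].sum().item())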