mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull in fairscale.nn.Pipe into PyTorch. (#44090)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44090 This is an initial commit pulling in the torchgpipe fork at https://github.com/facebookresearch/fairscale. The purpose of this commit is to just pull in the code and ensure all tests and builds work fine. We will slowly modify this to match our intended API mentioned in https://fb.quip.com/txurAV3zIFox#RPZACAfAKMq. Follow up PRs would address further changes needed on top of the initial commit.. We're pulling the code into the `torch.distributed._pipeline.sync` package. The package is private on purpose since there is a lot of work (ex: docs, API changes etc.) that needs to go in before we can actually officially support this. ghstack-source-id: 114864254 Test Plan: 1) waitforbuildbot 2) Ran all tests on my devgpu Reviewed By: mrshenli Differential Revision: D23493316 fbshipit-source-id: fe3c8b7dadeeb86abdc00e8a8652491b0b16743a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b63ddd6f57
commit
06d50b5eb0
13
LICENSE
13
LICENSE
@ -16,23 +16,26 @@ Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||
|
||||
All contributions by Facebook:
|
||||
Copyright (c) 2016 Facebook Inc.
|
||||
|
||||
|
||||
All contributions by Google:
|
||||
Copyright (c) 2015 Google Inc.
|
||||
All rights reserved.
|
||||
|
||||
|
||||
All contributions by Yangqing Jia:
|
||||
Copyright (c) 2015 Yangqing Jia
|
||||
All rights reserved.
|
||||
|
||||
|
||||
All contributions by Kakao Brain:
|
||||
Copyright 2019-2020 Kakao Brain
|
||||
|
||||
All contributions from Caffe:
|
||||
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
|
||||
All other contributions:
|
||||
Copyright(c) 2015, 2016 the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
|
||||
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||
copyright over their contributions to Caffe2. The project versioning records
|
||||
all such contribution and copyright details. If a contributor wants to further
|
||||
|
3
NOTICE
3
NOTICE
@ -22,6 +22,9 @@ All contributions by Yangqing Jia:
|
||||
Copyright (c) 2015 Yangqing Jia
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Kakao Brain:
|
||||
Copyright 2019-2020 Kakao Brain
|
||||
|
||||
All other contributions:
|
||||
Copyright(c) 2015, 2016 the respective contributors
|
||||
All rights reserved.
|
||||
|
27
test/distributed/_pipeline/sync/LICENSE
Normal file
27
test/distributed/_pipeline/sync/LICENSE
Normal file
@ -0,0 +1,27 @@
|
||||
Copyright 2019-2020 Kakao Brain
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from this
|
||||
software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
8
test/distributed/_pipeline/sync/__init__.py
Normal file
8
test/distributed/_pipeline/sync/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# tests/__init__.py makes pytest can import the application without custom sys.path or PYTHONPATH.
|
||||
# See also: https://docs.pytest.org/en/latest/goodpractices.html
|
37
test/distributed/_pipeline/sync/conftest.py
Normal file
37
test/distributed/_pipeline/sync/conftest.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def manual_seed_zero():
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def cuda_sleep():
|
||||
# Warm-up CUDA.
|
||||
torch.empty(1, device="cuda")
|
||||
|
||||
# From test/test_cuda.py in PyTorch.
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start.record()
|
||||
torch.cuda._sleep(1000000)
|
||||
end.record()
|
||||
end.synchronize()
|
||||
cycles_per_ms = 1000000 / start.elapsed_time(end)
|
||||
|
||||
def cuda_sleep(seconds):
|
||||
torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
|
||||
|
||||
return cuda_sleep
|
||||
|
||||
|
||||
def pytest_report_header():
|
||||
return f"torch: {torch.__version__}"
|
6
test/distributed/_pipeline/sync/skip/__init__.py
Normal file
6
test/distributed/_pipeline/sync/skip/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
45
test/distributed/_pipeline/sync/skip/test_api.py
Normal file
45
test/distributed/_pipeline/sync/skip/test_api.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import copy
|
||||
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync.skip import Namespace, skippable, stash
|
||||
|
||||
|
||||
def test_namespace_difference():
|
||||
ns1 = Namespace()
|
||||
ns2 = Namespace()
|
||||
assert ns1 != ns2
|
||||
|
||||
|
||||
def test_namespace_copy():
|
||||
ns = Namespace()
|
||||
assert copy.copy(ns) == ns
|
||||
assert copy.copy(ns) is not ns
|
||||
|
||||
|
||||
def test_skippable_repr():
|
||||
@skippable(stash=["hello"])
|
||||
class Hello(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
yield stash("hello", x)
|
||||
return self.conv(x) # noqa
|
||||
|
||||
m = Hello()
|
||||
assert (
|
||||
repr(m)
|
||||
== """
|
||||
@skippable(Hello(
|
||||
(conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
|
||||
))
|
||||
""".strip()
|
||||
)
|
106
test/distributed/_pipeline/sync/skip/test_gpipe.py
Normal file
106
test/distributed/_pipeline/sync/skip/test_gpipe.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync import Pipe
|
||||
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
|
||||
from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
|
||||
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
|
||||
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
|
||||
def test_1to3(balance, checkpoint):
|
||||
if torch.cuda.device_count() < len(balance):
|
||||
pytest.skip("at least %d cuda devices required" % len(balance))
|
||||
|
||||
@skippable(stash=["1to3"])
|
||||
class Layer1(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 1)
|
||||
|
||||
def forward(self, input):
|
||||
yield stash("1to3", input)
|
||||
output = self.conv(input)
|
||||
return output # noqa
|
||||
|
||||
class Layer2(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 1)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.conv(input)
|
||||
return output
|
||||
|
||||
@skippable(pop=["1to3"])
|
||||
class Layer3(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 3, 1)
|
||||
|
||||
def forward(self, input):
|
||||
skip_1to3 = yield pop("1to3")
|
||||
output = self.conv(input) + skip_1to3
|
||||
return output
|
||||
|
||||
model = nn.Sequential(Layer1(), Layer2(), Layer3())
|
||||
model = Pipe(model, balance, chunks=3, checkpoint=checkpoint)
|
||||
|
||||
in_device = model.devices[0]
|
||||
out_device = model.devices[-1]
|
||||
|
||||
input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
|
||||
output = model(input)
|
||||
loss = output.mean()
|
||||
loss.backward()
|
||||
|
||||
assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1)
|
||||
assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device))
|
||||
|
||||
|
||||
def test_none_skip():
|
||||
@skippable(stash=["none"])
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("none", None)
|
||||
return input # noqa
|
||||
|
||||
@skippable(pop=["none"])
|
||||
class Pop(nn.Module):
|
||||
def forward(self, input):
|
||||
none = yield pop("none")
|
||||
assert none is None
|
||||
return input
|
||||
|
||||
model = nn.Sequential(Stash(), Pop())
|
||||
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=5)
|
||||
|
||||
input = torch.rand(10, requires_grad=True)
|
||||
output = model(input)
|
||||
|
||||
def assert_grad_fn_is_not_portal(grad_fn, visited=None):
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if grad_fn in visited or grad_fn is None:
|
||||
return
|
||||
|
||||
assert not isinstance(grad_fn, PortalBlue._backward_cls)
|
||||
assert not isinstance(grad_fn, PortalCopy._backward_cls)
|
||||
assert not isinstance(grad_fn, PortalOrange._backward_cls)
|
||||
|
||||
visited.add(grad_fn)
|
||||
for next_grad_fn, _ in grad_fn.next_functions:
|
||||
assert_grad_fn_is_not_portal(next_grad_fn, visited)
|
||||
|
||||
assert_grad_fn_is_not_portal(output.grad_fn)
|
||||
|
||||
output.sum().backward()
|
||||
assert input.grad.mean().item() == 1
|
111
test/distributed/_pipeline/sync/skip/test_inspect_skip_layout.py
Normal file
111
test/distributed/_pipeline/sync/skip/test_inspect_skip_layout.py
Normal file
@ -0,0 +1,111 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync.skip import Namespace, pop, skippable, stash
|
||||
from torch.distributed._pipeline.sync.skip.layout import inspect_skip_layout
|
||||
|
||||
|
||||
class Pass(nn.Module):
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
@skippable(stash=["foo"])
|
||||
class StashFoo(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("foo", input)
|
||||
return input # noqa
|
||||
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class PopFoo(nn.Module):
|
||||
def forward(self, input):
|
||||
foo = yield stash("foo")
|
||||
return input + foo
|
||||
|
||||
|
||||
@skippable(stash=["bar"])
|
||||
class StashBar(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("bar", input)
|
||||
return input # noqa
|
||||
|
||||
|
||||
@skippable(pop=["bar"])
|
||||
class PopBar(nn.Module):
|
||||
def forward(self, input):
|
||||
bar = yield pop("bar")
|
||||
return input + bar
|
||||
|
||||
|
||||
def test_no_skippables():
|
||||
p1 = nn.Sequential(Pass())
|
||||
p2 = nn.Sequential(Pass())
|
||||
|
||||
layout = inspect_skip_layout([p1, p2])
|
||||
policy = [list(layout.copy_policy(i)) for i in range(2)]
|
||||
|
||||
assert policy == [[], []]
|
||||
|
||||
|
||||
def test_inner_partition():
|
||||
p1 = nn.Sequential(StashFoo(), PopFoo())
|
||||
p2 = nn.Sequential(Pass())
|
||||
|
||||
layout = inspect_skip_layout([p1, p2])
|
||||
policy = [list(layout.copy_policy(i)) for i in range(2)]
|
||||
|
||||
assert policy == [[], []]
|
||||
|
||||
|
||||
def test_adjoining_partitions():
|
||||
p1 = nn.Sequential(StashFoo())
|
||||
p2 = nn.Sequential(PopFoo())
|
||||
|
||||
layout = inspect_skip_layout([p1, p2])
|
||||
policy = [list(layout.copy_policy(i)) for i in range(2)]
|
||||
|
||||
assert policy == [[], [(0, None, "foo")]]
|
||||
|
||||
|
||||
def test_far_partitions():
|
||||
p1 = nn.Sequential(StashFoo())
|
||||
p2 = nn.Sequential(Pass())
|
||||
p3 = nn.Sequential(PopFoo())
|
||||
|
||||
layout = inspect_skip_layout([p1, p2, p3])
|
||||
policy = [list(layout.copy_policy(i)) for i in range(3)]
|
||||
|
||||
assert policy == [[], [], [(0, None, "foo")]]
|
||||
|
||||
|
||||
def test_pop_2_from_different_partitions():
|
||||
p1 = nn.Sequential(StashFoo())
|
||||
p2 = nn.Sequential(StashBar())
|
||||
p3 = nn.Sequential(PopBar(), PopFoo())
|
||||
|
||||
layout = inspect_skip_layout([p1, p2, p3])
|
||||
policy = [list(layout.copy_policy(i)) for i in range(3)]
|
||||
|
||||
# p3 pops 'bar' before 'foo', but the plan is sorted by source partition index.
|
||||
assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]]
|
||||
|
||||
|
||||
def test_namespace():
|
||||
ns1 = Namespace()
|
||||
ns2 = Namespace()
|
||||
|
||||
p1 = nn.Sequential(StashFoo().isolate(ns1))
|
||||
p2 = nn.Sequential(StashFoo().isolate(ns2))
|
||||
p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))
|
||||
|
||||
layout = inspect_skip_layout([p1, p2, p3])
|
||||
policy = [list(layout.copy_policy(i)) for i in range(3)]
|
||||
|
||||
# p3 pops 'bar' before 'foo', but the plan is sorted by source partition index.
|
||||
assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]]
|
126
test/distributed/_pipeline/sync/skip/test_leak.py
Normal file
126
test/distributed/_pipeline/sync/skip/test_leak.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync import Pipe, is_checkpointing, is_recomputing
|
||||
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
|
||||
from torch.distributed._pipeline.sync.skip.tracker import current_skip_tracker
|
||||
|
||||
|
||||
@skippable(stash=["skip"])
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("skip", input)
|
||||
return input # noqa
|
||||
|
||||
|
||||
@skippable(pop=["skip"])
|
||||
class Pop(nn.Module):
|
||||
def forward(self, input):
|
||||
skip = yield pop("skip")
|
||||
return input + skip
|
||||
|
||||
|
||||
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
|
||||
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
|
||||
def test_delete_portal_tensor(train, checkpoint):
|
||||
# Without checkpointing:
|
||||
# +- Stash --+ +--- Pop ----+ - - - layers
|
||||
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
|
||||
# +----------+ +------------+
|
||||
#
|
||||
# With checkpointing:
|
||||
# +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+
|
||||
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
|
||||
# +----------+ +------------+ +------------+ +----------+
|
||||
|
||||
def portal_tensor_life_is(tensor_life, skip_tracker=None):
|
||||
if skip_tracker is None:
|
||||
skip_tracker = current_skip_tracker()
|
||||
|
||||
# Get the current portal.
|
||||
portal = list(skip_tracker.portals.values())[0]
|
||||
|
||||
if tensor_life == 0:
|
||||
return portal.tensor_life == 0 and portal.tensor is None
|
||||
else:
|
||||
return portal.tensor_life == tensor_life and portal.tensor is not None
|
||||
|
||||
# Check the portal tensor after 'Stash'.
|
||||
stash_ = Stash()
|
||||
|
||||
@stash_.register_forward_hook
|
||||
def check_portal_tensor_after_stash(*_):
|
||||
if is_checkpointing():
|
||||
assert portal_tensor_life_is(2)
|
||||
elif is_recomputing():
|
||||
assert portal_tensor_life_is(0)
|
||||
else:
|
||||
assert portal_tensor_life_is(1)
|
||||
|
||||
pop_ = Pop()
|
||||
|
||||
@pop_.register_forward_hook
|
||||
def check_portal_tensor_after_pop(*_):
|
||||
if is_checkpointing():
|
||||
assert portal_tensor_life_is(1)
|
||||
elif is_recomputing():
|
||||
assert portal_tensor_life_is(0)
|
||||
else:
|
||||
assert portal_tensor_life_is(0)
|
||||
|
||||
class NoPortalTensorAtBackward(nn.Module):
|
||||
class F(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
ctx.skip_tracker = current_skip_tracker()
|
||||
return input.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
|
||||
return grad
|
||||
|
||||
def forward(self, input):
|
||||
return self.F.apply(input)
|
||||
|
||||
model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
|
||||
model = Pipe(model, balance=[2, 1], devices=["cpu", "cpu"], chunks=2, checkpoint=checkpoint)
|
||||
|
||||
input = torch.rand(10, requires_grad=True)
|
||||
|
||||
if train:
|
||||
model.train()
|
||||
output = model(input)
|
||||
output.norm().backward()
|
||||
else:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
model(input)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
|
||||
def test_no_portal_without_pipe(train, monkeypatch):
|
||||
def deny(*args, **kwargs):
|
||||
raise AssertionError("tried to create Portal without Pipe")
|
||||
|
||||
monkeypatch.setattr("torch.distributed._pipeline.sync.skip.portal.Portal.__init__", deny)
|
||||
|
||||
model = nn.Sequential(Stash(), Pop())
|
||||
|
||||
input = torch.rand(10, requires_grad=True)
|
||||
|
||||
if train:
|
||||
model.train()
|
||||
output = model(input)
|
||||
output.norm().backward()
|
||||
else:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
model(input)
|
155
test/distributed/_pipeline/sync/skip/test_portal.py
Normal file
155
test/distributed/_pipeline/sync/skip/test_portal.py
Normal file
@ -0,0 +1,155 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from torch.distributed._pipeline.sync.dependency import fork, join
|
||||
from torch.distributed._pipeline.sync.skip.portal import Portal
|
||||
from torch.distributed._pipeline.sync.stream import default_stream
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
|
||||
def test_copy_returns_on_next_device():
|
||||
portal = Portal(torch.rand(1), tensor_life=1)
|
||||
|
||||
prev_stream = default_stream(torch.device("cpu"))
|
||||
next_stream = default_stream(torch.device("cuda"))
|
||||
|
||||
phony = torch.zeros(0, requires_grad=True)
|
||||
assert phony.device.type == "cpu"
|
||||
|
||||
phony = portal.copy(prev_stream, next_stream, phony)
|
||||
assert phony.device.type == "cuda"
|
||||
|
||||
|
||||
def test_blue_orange():
|
||||
tensor1 = torch.rand(1, requires_grad=True)
|
||||
tensor2 = torch.rand(1, requires_grad=True)
|
||||
|
||||
# Same with: output = tensor1*2 + tensor2
|
||||
#
|
||||
# +----------------------+
|
||||
# | |
|
||||
# tensor2 -- PortalBlue -+ +- PortalOrange -+
|
||||
# | | |
|
||||
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output
|
||||
#
|
||||
main = tensor1
|
||||
portal = Portal(tensor2, tensor_life=2)
|
||||
phony = portal.blue()
|
||||
main = join(main, phony)
|
||||
main, phony = fork(main)
|
||||
sub = portal.orange(phony)
|
||||
output = main * 2 + sub
|
||||
|
||||
output.backward()
|
||||
|
||||
assert torch.allclose(tensor1.grad, torch.tensor([2.0]))
|
||||
assert torch.allclose(tensor2.grad, torch.tensor([1.0]))
|
||||
|
||||
|
||||
def test_blue_orange_not_requires_grad():
|
||||
tensor1 = torch.rand(1, requires_grad=True)
|
||||
tensor2 = torch.rand(1)
|
||||
|
||||
# Same with: output = tensor1*2 + tensor2
|
||||
#
|
||||
# +----------------------+
|
||||
# | |
|
||||
# tensor2 -- PortalBlue -+ +- PortalOrange -+
|
||||
# | | |
|
||||
# tensor1 ------------ Join -- Fork --- Mul --- Add -- output
|
||||
#
|
||||
main = tensor1
|
||||
portal = Portal(tensor2, tensor_life=2)
|
||||
phony = portal.blue()
|
||||
main = join(main, phony)
|
||||
main, phony = fork(main)
|
||||
sub = portal.orange(phony)
|
||||
output = main * 2 + sub
|
||||
|
||||
output.backward()
|
||||
|
||||
assert torch.allclose(tensor1.grad, torch.tensor([2.0]))
|
||||
assert tensor2.grad is None
|
||||
|
||||
|
||||
def test_use_grad():
|
||||
tensor = torch.rand(1, requires_grad=True)
|
||||
portal = Portal(tensor, tensor_life=1)
|
||||
|
||||
portal.put_grad(tensor)
|
||||
assert portal.use_grad() is tensor
|
||||
|
||||
# Gradient in a portal is ephemeral.
|
||||
with pytest.raises(RuntimeError):
|
||||
portal.use_grad()
|
||||
|
||||
|
||||
class TestTensorLife:
|
||||
@pytest.fixture
|
||||
def new_portal(self):
|
||||
portal = None
|
||||
|
||||
def new_portal(tensor_life):
|
||||
nonlocal portal
|
||||
tensor = torch.rand(1, requires_grad=True)
|
||||
portal = Portal(tensor, tensor_life)
|
||||
return portal, tensor
|
||||
|
||||
yield new_portal
|
||||
|
||||
# A test using this fixture must exhaust the tensor in the portal.
|
||||
with pytest.raises(RuntimeError):
|
||||
portal.check_tensor_life()
|
||||
assert portal.tensor is None
|
||||
|
||||
def test_tensor_life_0(self, new_portal):
|
||||
portal, tensor = new_portal(0)
|
||||
assert portal.tensor is None
|
||||
|
||||
def test_tensor_life_1(self, new_portal):
|
||||
portal, tensor = new_portal(1)
|
||||
assert portal.tensor is tensor
|
||||
|
||||
portal.blue()
|
||||
|
||||
def test_tensor_life_2(self, new_portal):
|
||||
portal, tensor = new_portal(2)
|
||||
assert portal.tensor is tensor
|
||||
|
||||
phony = portal.blue()
|
||||
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
|
||||
|
||||
def test_tensor_life_3(self, new_portal):
|
||||
portal, tensor = new_portal(3)
|
||||
assert portal.tensor is tensor
|
||||
|
||||
phony = portal.blue()
|
||||
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
|
||||
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
|
||||
|
||||
def test_tensor_life_4(self, new_portal):
|
||||
portal, tensor = new_portal(4)
|
||||
assert portal.tensor is tensor
|
||||
|
||||
phony = portal.blue()
|
||||
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
|
||||
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
|
||||
portal.blue()
|
||||
|
||||
def test_tensor_life_3_plus_1(self, new_portal):
|
||||
portal, tensor = new_portal(3)
|
||||
assert portal.tensor is tensor
|
||||
|
||||
phony = portal.blue()
|
||||
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
|
||||
assert portal.orange(phony).data_ptr() == tensor.data_ptr()
|
||||
|
||||
another_tensor = torch.rand(1, requires_grad=True)
|
||||
portal.put_tensor(another_tensor, tensor_life=1)
|
||||
portal.blue()
|
136
test/distributed/_pipeline/sync/skip/test_stash_pop.py
Normal file
136
test/distributed/_pipeline/sync/skip/test_stash_pop.py
Normal file
@ -0,0 +1,136 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
|
||||
from torch.distributed._pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_tracker():
|
||||
skip_tracker = SkipTracker()
|
||||
with use_skip_tracker(skip_tracker):
|
||||
yield skip_tracker
|
||||
|
||||
|
||||
def test_stash(skip_tracker):
|
||||
@skippable(stash=["foo"])
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("foo", input)
|
||||
return input * 2 # noqa
|
||||
|
||||
l1 = Stash()
|
||||
|
||||
assert len(skip_tracker.tensors) == 0
|
||||
|
||||
with use_skip_tracker(skip_tracker):
|
||||
l1(torch.tensor(42))
|
||||
|
||||
assert len(skip_tracker.tensors) == 1
|
||||
|
||||
|
||||
def test_pop():
|
||||
@skippable(stash=["foo"])
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("foo", input)
|
||||
return input * 2 # noqa
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Pop(nn.Module):
|
||||
def forward(self, input):
|
||||
foo = yield pop("foo")
|
||||
return foo # noqa
|
||||
|
||||
l1 = Stash()
|
||||
l2 = Pop()
|
||||
|
||||
output = l2(l1(torch.tensor(42)))
|
||||
|
||||
assert output.item() == 42
|
||||
|
||||
|
||||
def test_declare_but_not_use():
|
||||
@skippable(stash=["foo"])
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
return input * 2
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Pop(nn.Module):
|
||||
def forward(self, input):
|
||||
return input * 3
|
||||
|
||||
l1 = Stash()
|
||||
l2 = Pop()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
l1(torch.tensor(42))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
l2(torch.tensor(42))
|
||||
|
||||
|
||||
def test_stash_not_declared():
|
||||
@skippable()
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("foo", input)
|
||||
return input * 2 # noqa
|
||||
|
||||
l1 = Stash()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
l1(torch.tensor(42))
|
||||
|
||||
|
||||
def test_pop_not_declared():
|
||||
@skippable(stash=["foo"])
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("foo", input)
|
||||
return input * 2 # noqa
|
||||
|
||||
@skippable()
|
||||
class Pop(nn.Module):
|
||||
def forward(self, input):
|
||||
foo = yield pop("foo")
|
||||
return foo # noqa
|
||||
|
||||
l1 = Stash()
|
||||
l2 = Pop()
|
||||
|
||||
latent = l1(torch.tensor(42))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
l2(latent)
|
||||
|
||||
|
||||
def test_pop_not_stashed():
|
||||
@skippable(pop=["foo"])
|
||||
class Pop(nn.Module):
|
||||
def forward(self, input):
|
||||
yield pop("foo")
|
||||
|
||||
l1 = Pop()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
l1(torch.tensor(42))
|
||||
|
||||
|
||||
def test_stash_none():
|
||||
@skippable(stash=["foo"])
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("foo", None)
|
||||
return input * 2 # noqa
|
||||
|
||||
l1 = Stash()
|
||||
l1(torch.tensor(42))
|
127
test/distributed/_pipeline/sync/skip/test_tracker.py
Normal file
127
test/distributed/_pipeline/sync/skip/test_tracker.py
Normal file
@ -0,0 +1,127 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from queue import Queue
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync.checkpoint import enable_checkpointing, enable_recomputing
|
||||
from torch.distributed._pipeline.sync.microbatch import Batch
|
||||
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
|
||||
from torch.distributed._pipeline.sync.skip.layout import SkipLayout
|
||||
from torch.distributed._pipeline.sync.skip.tracker import SkipTracker, SkipTrackerThroughPotals, current_skip_tracker
|
||||
|
||||
|
||||
def test_default_skip_tracker():
|
||||
q = Queue()
|
||||
|
||||
def f():
|
||||
q.put(current_skip_tracker())
|
||||
|
||||
t = threading.Thread(target=f)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
skip_tracker = q.get()
|
||||
|
||||
assert type(skip_tracker) is SkipTracker
|
||||
assert type(skip_tracker) is not SkipTrackerThroughPotals
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
|
||||
def test_default_skip_tracker_by_data_parallel():
|
||||
@skippable(stash=["foo"])
|
||||
class Stash(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash("foo", input)
|
||||
return input * 2 # noqa
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Pop(nn.Module):
|
||||
def forward(self, input):
|
||||
foo = yield pop("foo")
|
||||
return foo
|
||||
|
||||
model = nn.Sequential(Stash(), Pop())
|
||||
model = nn.DataParallel(model, device_ids=[0, 0], output_device=0)
|
||||
|
||||
input = torch.rand(10, device=0)
|
||||
output = model(input)
|
||||
|
||||
assert torch.allclose(output, input)
|
||||
|
||||
|
||||
def test_reuse_portal():
|
||||
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
|
||||
skip_tracker = SkipTrackerThroughPotals(skip_layout)
|
||||
|
||||
batch = Batch(torch.tensor([1.0]))
|
||||
a = torch.tensor([2.0])
|
||||
b = torch.tensor([2.0])
|
||||
|
||||
skip_tracker.save(batch, None, "test", a)
|
||||
portal = skip_tracker.portals[(None, "test")]
|
||||
|
||||
skip_tracker.save(batch, None, "test", b)
|
||||
assert portal is skip_tracker.portals[(None, "test")]
|
||||
|
||||
|
||||
def test_no_copy_no_portal():
|
||||
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)})
|
||||
skip_tracker = SkipTrackerThroughPotals(skip_layout)
|
||||
|
||||
batch = Batch(torch.tensor([1.0]))
|
||||
a = torch.tensor([2.0])
|
||||
b = torch.tensor([2.0])
|
||||
|
||||
skip_tracker.save(batch, None, "copy", a)
|
||||
skip_tracker.save(batch, None, "not_copy", b)
|
||||
|
||||
assert (None, "copy") in skip_tracker.portals
|
||||
assert (None, "copy") not in skip_tracker.tensors
|
||||
assert (None, "not_copy") in skip_tracker.tensors
|
||||
assert (None, "not_copy") not in skip_tracker.portals
|
||||
|
||||
|
||||
def test_tensor_life_without_checkpointing():
|
||||
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
|
||||
skip_tracker = SkipTrackerThroughPotals(skip_layout)
|
||||
|
||||
batch = Batch(torch.tensor([1.0]))
|
||||
tensor = torch.tensor([2.0])
|
||||
|
||||
skip_tracker.save(batch, None, "test", tensor)
|
||||
assert skip_tracker.portals[(None, "test")].tensor_life == 1
|
||||
|
||||
skip_tracker.load(batch, None, "test")
|
||||
assert skip_tracker.portals[(None, "test")].tensor_life == 0
|
||||
|
||||
|
||||
def test_tensor_life_with_checkpointing():
|
||||
skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)})
|
||||
skip_tracker = SkipTrackerThroughPotals(skip_layout)
|
||||
|
||||
batch = Batch(torch.tensor([1.0]))
|
||||
tensor = torch.tensor([2.0])
|
||||
|
||||
with enable_checkpointing():
|
||||
skip_tracker.save(batch, None, "test", tensor)
|
||||
assert skip_tracker.portals[(None, "test")].tensor_life == 2
|
||||
|
||||
with enable_checkpointing():
|
||||
skip_tracker.load(batch, None, "test")
|
||||
assert skip_tracker.portals[(None, "test")].tensor_life == 1
|
||||
|
||||
with enable_recomputing():
|
||||
skip_tracker.load(batch, None, "test")
|
||||
assert skip_tracker.portals[(None, "test")].tensor_life == 0
|
||||
|
||||
with enable_recomputing():
|
||||
skip_tracker.save(batch, None, "test", tensor)
|
||||
assert skip_tracker.portals[(None, "test")].tensor_life == 0
|
152
test/distributed/_pipeline/sync/skip/test_verify_skippables.py
Normal file
152
test/distributed/_pipeline/sync/skip/test_verify_skippables.py
Normal file
@ -0,0 +1,152 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync.skip import Namespace, skippable, verify_skippables
|
||||
|
||||
|
||||
def test_matching():
|
||||
@skippable(stash=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Layer2(nn.Module):
|
||||
pass
|
||||
|
||||
verify_skippables(nn.Sequential(Layer1(), Layer2()))
|
||||
|
||||
|
||||
def test_stash_not_pop():
|
||||
@skippable(stash=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
verify_skippables(nn.Sequential(Layer1()))
|
||||
assert "no module declared 'foo' as poppable but stashed" in str(e.value)
|
||||
|
||||
|
||||
def test_pop_unknown():
|
||||
@skippable(pop=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
verify_skippables(nn.Sequential(Layer1()))
|
||||
assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value)
|
||||
|
||||
|
||||
def test_stash_again():
|
||||
@skippable(stash=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(stash=["foo"])
|
||||
class Layer2(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Layer3(nn.Module):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
|
||||
assert "'1' redeclared 'foo' as stashable" in str(e.value)
|
||||
|
||||
|
||||
def test_pop_again():
|
||||
@skippable(stash=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Layer2(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Layer3(nn.Module):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
|
||||
assert "'2' redeclared 'foo' as poppable" in str(e.value)
|
||||
|
||||
|
||||
def test_stash_pop_together_different_names():
|
||||
@skippable(stash=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"], stash=["bar"])
|
||||
class Layer2(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["bar"])
|
||||
class Layer3(nn.Module):
|
||||
pass
|
||||
|
||||
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3()))
|
||||
|
||||
|
||||
def test_stash_pop_together_same_name():
|
||||
@skippable(stash=["foo"], pop=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
verify_skippables(nn.Sequential(Layer1()))
|
||||
assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value)
|
||||
|
||||
|
||||
def test_double_stash_pop():
|
||||
@skippable(stash=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Layer2(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(stash=["foo"])
|
||||
class Layer3(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Layer4(nn.Module):
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4()))
|
||||
assert "'2' redeclared 'foo' as stashable" in str(e.value)
|
||||
assert "'3' redeclared 'foo' as poppable" in str(e.value)
|
||||
|
||||
|
||||
def test_double_stash_pop_but_isolated():
|
||||
@skippable(stash=["foo"])
|
||||
class Layer1(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Layer2(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(stash=["foo"])
|
||||
class Layer3(nn.Module):
|
||||
pass
|
||||
|
||||
@skippable(pop=["foo"])
|
||||
class Layer4(nn.Module):
|
||||
pass
|
||||
|
||||
ns1 = Namespace()
|
||||
ns2 = Namespace()
|
||||
|
||||
verify_skippables(
|
||||
nn.Sequential(Layer1().isolate(ns1), Layer2().isolate(ns1), Layer3().isolate(ns2), Layer4().isolate(ns2),)
|
||||
)
|
225
test/distributed/_pipeline/sync/test_balance.py
Normal file
225
test/distributed/_pipeline/sync/test_balance.py
Normal file
@ -0,0 +1,225 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync.balance import balance_by_size, balance_by_time, blockpartition
|
||||
from torch.distributed._pipeline.sync.balance.profile import layerwise_sandbox
|
||||
|
||||
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
|
||||
|
||||
devices = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
devices.append("cuda")
|
||||
|
||||
|
||||
def test_blockpartition():
|
||||
assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [[1, 2, 3, 4], [5, 6]]
|
||||
|
||||
|
||||
def test_blockpartition_zeros():
|
||||
assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]]
|
||||
|
||||
|
||||
def test_blockpartition_non_positive_partitions():
|
||||
with pytest.raises(ValueError):
|
||||
blockpartition.solve([42], partitions=0)
|
||||
with pytest.raises(ValueError):
|
||||
blockpartition.solve([42], partitions=-1)
|
||||
|
||||
|
||||
def test_blockpartition_short_sequence():
|
||||
with pytest.raises(ValueError):
|
||||
blockpartition.solve([], partitions=1)
|
||||
with pytest.raises(ValueError):
|
||||
blockpartition.solve([42], partitions=2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", devices)
|
||||
@pytest.mark.skip(reason="Flaky due to time.sleep()")
|
||||
def test_balance_by_time(device):
|
||||
class Delay(nn.Module):
|
||||
def __init__(self, seconds):
|
||||
super().__init__()
|
||||
self.seconds = seconds
|
||||
|
||||
def forward(self, x):
|
||||
time.sleep(self.seconds)
|
||||
return x
|
||||
|
||||
model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]])
|
||||
sample = torch.rand(1)
|
||||
balance = balance_by_time(2, model, sample, device=device)
|
||||
assert balance == [4, 2]
|
||||
|
||||
|
||||
def test_balance_by_time_loop_resets_input():
|
||||
# nn.Flatten was introduced at PyTorch 1.2.0.
|
||||
class Flatten(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.flatten(1)
|
||||
|
||||
model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10))
|
||||
sample = torch.rand(10, 3, 8, 8)
|
||||
balance = balance_by_time(2, model, sample, device="cpu")
|
||||
assert balance == [1, 2]
|
||||
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_balance_by_size_latent():
|
||||
class Expand(nn.Module):
|
||||
def __init__(self, times):
|
||||
super().__init__()
|
||||
self.times = times
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.times):
|
||||
x = x + torch.rand_like(x, requires_grad=True)
|
||||
return x
|
||||
|
||||
sample = torch.rand(10, 100, 100)
|
||||
|
||||
model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]])
|
||||
balance = balance_by_size(2, model, sample)
|
||||
assert balance == [4, 2]
|
||||
|
||||
model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]])
|
||||
balance = balance_by_size(2, model, sample)
|
||||
assert balance == [2, 4]
|
||||
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_balance_by_size_param():
|
||||
model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)])
|
||||
sample = torch.rand(7, 1)
|
||||
balance = balance_by_size(2, model, sample, param_scale=100)
|
||||
assert balance == [4, 2]
|
||||
|
||||
model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))])
|
||||
sample = torch.rand(1, 7)
|
||||
balance = balance_by_size(2, model, sample, param_scale=100)
|
||||
assert balance == [2, 4]
|
||||
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_balance_by_size_param_scale():
|
||||
class Tradeoff(nn.Module):
|
||||
def __init__(self, param_size, latent_size):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(param_size, param_size)
|
||||
self.latent_size = latent_size
|
||||
|
||||
def forward(self, x):
|
||||
for i in range(self.latent_size):
|
||||
x = x + torch.rand_like(x, requires_grad=True)
|
||||
return x
|
||||
|
||||
model = nn.Sequential(
|
||||
Tradeoff(param_size=1, latent_size=6),
|
||||
Tradeoff(param_size=2, latent_size=5),
|
||||
Tradeoff(param_size=3, latent_size=4),
|
||||
Tradeoff(param_size=4, latent_size=3),
|
||||
Tradeoff(param_size=5, latent_size=2),
|
||||
Tradeoff(param_size=6, latent_size=1),
|
||||
)
|
||||
|
||||
sample = torch.rand(1, requires_grad=True)
|
||||
|
||||
balance = balance_by_size(2, model, sample, param_scale=0)
|
||||
assert balance == [2, 4]
|
||||
|
||||
balance = balance_by_size(2, model, sample, param_scale=100)
|
||||
assert balance == [4, 2]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", devices)
|
||||
def test_layerwise_sandbox(device):
|
||||
model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
|
||||
model.eval()
|
||||
|
||||
for layer in layerwise_sandbox(model, torch.device(device)):
|
||||
assert layer.training
|
||||
assert all(p.device.type == device for p in layer.parameters())
|
||||
|
||||
assert all(not l.training for l in model)
|
||||
assert all(p.device.type == "cpu" for p in model.parameters())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", devices)
|
||||
def test_sandbox_during_profiling(device):
|
||||
model = nn.Sequential(nn.BatchNorm2d(3))
|
||||
|
||||
before = {k: v.clone() for k, v in model.state_dict().items()}
|
||||
|
||||
sample = torch.rand(1, 3, 10, 10)
|
||||
balance_by_time(1, model, sample, device=device)
|
||||
|
||||
after = model.state_dict()
|
||||
|
||||
assert before.keys() == after.keys()
|
||||
for key, value in before.items():
|
||||
assert torch.allclose(after[key], value), key
|
||||
|
||||
|
||||
def test_not_training():
|
||||
class AssertTraining(nn.Module):
|
||||
def forward(self, x):
|
||||
assert self.training
|
||||
return x
|
||||
|
||||
model = nn.Sequential(AssertTraining())
|
||||
|
||||
model.eval()
|
||||
assert not model.training
|
||||
|
||||
sample = torch.rand(1)
|
||||
balance_by_time(1, model, sample, device="cpu")
|
||||
|
||||
assert not model.training
|
||||
|
||||
|
||||
def test_balance_by_time_tuple():
|
||||
class Twin(nn.Module):
|
||||
def forward(self, x):
|
||||
return x, x.detach()
|
||||
|
||||
class Add(nn.Module):
|
||||
def forward(self, a_b):
|
||||
a, b = a_b
|
||||
return a + b
|
||||
|
||||
model = nn.Sequential(Twin(), Add())
|
||||
sample = torch.rand(1, requires_grad=True)
|
||||
balance_by_time(1, model, sample, device="cpu")
|
||||
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_balance_by_size_tuple():
|
||||
class Twin(nn.Module):
|
||||
def forward(self, x):
|
||||
return x, x.detach()
|
||||
|
||||
class Add(nn.Module):
|
||||
def forward(self, a_b):
|
||||
a, b = a_b
|
||||
return a + b
|
||||
|
||||
model = nn.Sequential(Twin(), Add())
|
||||
sample = torch.rand(1, requires_grad=True)
|
||||
balance_by_size(1, model, sample)
|
||||
|
||||
|
||||
def test_already_has_grad():
|
||||
model = nn.Sequential(nn.Conv2d(3, 3, 1))
|
||||
sample = torch.rand(1, 3, 32, 32)
|
||||
model(sample).norm().backward()
|
||||
|
||||
with pytest.raises(ValueError, match="some parameter already has gradient"):
|
||||
balance_by_time(1, model, sample, device="cpu")
|
128
test/distributed/_pipeline/sync/test_bugs.py
Normal file
128
test/distributed/_pipeline/sync/test_bugs.py
Normal file
@ -0,0 +1,128 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.distributed._pipeline.sync import Pipe
|
||||
|
||||
|
||||
def test_python_autograd_function():
|
||||
# A Python autograd function might fail with this error:
|
||||
#
|
||||
# RuntimeError: Returning Variables sharing storage with other Variables
|
||||
# that require grad is not supported in Python functions. Please submit a
|
||||
# feature request if you hit this error.
|
||||
#
|
||||
# It doesn't look like an essential restriction. But it happens on the
|
||||
# current PyTorch version. To avoid it, we should detach the tensor before
|
||||
# returning by identity autograd functions, such as Wait, Fork, and Join.
|
||||
#
|
||||
class Identity(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
return grad
|
||||
|
||||
class M(nn.Module):
|
||||
def forward(self, input):
|
||||
return Identity.apply(input)
|
||||
|
||||
model = nn.Sequential(M(), M())
|
||||
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always")
|
||||
|
||||
x = torch.rand(42)
|
||||
y = model(x)
|
||||
assert torch.allclose(x, y)
|
||||
|
||||
|
||||
def test_exception_no_hang():
|
||||
# In v0.0.2, once a failed partition receives a normal message
|
||||
# (non-closing) for the next micro-batch, a hang occured. The reason was
|
||||
# that a failed partition didn't call in_queue.task_done() on a normal
|
||||
# message. So the former partition was blocked at out_queue.join() for the
|
||||
# next of next micro-batch.
|
||||
class ExpectedException(Exception):
|
||||
pass
|
||||
|
||||
class Pass(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
class Raise(nn.Module):
|
||||
def forward(self, x):
|
||||
raise ExpectedException()
|
||||
|
||||
model = nn.Sequential(Pass(), Pass(), Raise())
|
||||
model = Pipe(model, [1, 1, 1], devices=["cpu", "cpu", "cpu"], chunks=3)
|
||||
|
||||
with pytest.raises(ExpectedException):
|
||||
model(torch.rand(3))
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
|
||||
def test_tuple_wait(cuda_sleep):
|
||||
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
|
||||
# Under this behavior, if checkpointing was disabled, there's a possibility
|
||||
# that gradient accumulations on other tensors are not synchronized
|
||||
# properly to the copy stream.
|
||||
class Sleep(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
return x.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
with torch.cuda.device(grad.device):
|
||||
cuda_sleep(0.05)
|
||||
return grad
|
||||
|
||||
class Layer1(nn.Module):
|
||||
def forward(self, pair):
|
||||
a, b = pair
|
||||
return a * 1, b * 2, b * 3
|
||||
|
||||
class Layer2(nn.Module):
|
||||
def forward(self, triple):
|
||||
a, b, c = triple
|
||||
b = Sleep.apply(b)
|
||||
return a + b + c
|
||||
|
||||
model = nn.Sequential(Layer1(), Layer2())
|
||||
model = Pipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint="never")
|
||||
|
||||
a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
|
||||
b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
|
||||
|
||||
y = model((a, b))
|
||||
y.norm().backward()
|
||||
|
||||
torch.cuda.synchronize(0)
|
||||
torch.cuda.synchronize(1)
|
||||
|
||||
assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))
|
||||
|
||||
|
||||
def test_parallel_randoms():
|
||||
class Dropouts(nn.Module):
|
||||
def forward(self, x):
|
||||
for _ in range(100):
|
||||
x = F.dropout(x, p=0.001)
|
||||
return x
|
||||
|
||||
model = nn.Sequential(Dropouts(), Dropouts())
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=True)
|
||||
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=10, checkpoint="always")
|
||||
y = model(x)
|
||||
y.norm().backward()
|
||||
|
||||
assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist()
|
158
test/distributed/_pipeline/sync/test_checkpoint.py
Normal file
158
test/distributed/_pipeline/sync/test_checkpoint.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.cuda
|
||||
|
||||
from torch.distributed._pipeline.sync.checkpoint import Checkpointing, checkpoint, is_checkpointing, is_recomputing
|
||||
from torch.distributed._pipeline.sync.dependency import fork, join
|
||||
from torch.distributed._pipeline.sync.microbatch import Batch
|
||||
|
||||
devices = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
devices.append("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", devices)
|
||||
def test_serial_checkpoints(device):
|
||||
# Copied from https://github.com/pytorch/pytorch/pull/18568.
|
||||
timeline = []
|
||||
|
||||
class Log(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, name, x):
|
||||
ctx.name = name
|
||||
timeline.append(f"{name}:forward")
|
||||
return x.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
name = ctx.name
|
||||
timeline.append(f"{name}:backward")
|
||||
return None, grad_output
|
||||
|
||||
a = torch.rand(1, device=device, requires_grad=True)
|
||||
b = torch.rand(1, device=device, requires_grad=True)
|
||||
|
||||
# Increase the next function sequence number.
|
||||
_ = a + 1 + 2 + 3 + 4 + 5
|
||||
|
||||
a = checkpoint(partial(Log.apply, "a"), a)
|
||||
|
||||
a, phony = fork(a)
|
||||
b = join(b, phony)
|
||||
|
||||
b = checkpoint(partial(Log.apply, "b"), b)
|
||||
|
||||
c = torch.cat((a, b))
|
||||
|
||||
out = c.sum()
|
||||
|
||||
# +--> {a} --Checkpoint(Log)--> {a}
|
||||
# {out} --Sum--> {c} --Cat ^-----------------------------+
|
||||
# +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
|
||||
out.backward()
|
||||
|
||||
assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
|
||||
# |----------------------| |-----------------------| |-----------------------|
|
||||
# forward pass Checkpoint(Log[b]) Checkpoint(Log[a])
|
||||
|
||||
|
||||
def test_not_requires_grad():
|
||||
x = Batch(torch.rand(1, requires_grad=False))
|
||||
assert not x[0].requires_grad
|
||||
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
||||
chk = Checkpointing(f, x)
|
||||
x = chk.checkpoint()
|
||||
assert x[0].requires_grad
|
||||
|
||||
chk.recompute(x)
|
||||
assert x[0].requires_grad
|
||||
|
||||
x.tensor.backward()
|
||||
|
||||
|
||||
def test_not_requires_grad_with_parameter():
|
||||
x = torch.rand(1, requires_grad=False)
|
||||
a = torch.rand(1, requires_grad=True)
|
||||
|
||||
def f(x):
|
||||
return x * a
|
||||
|
||||
y = checkpoint(f, x)
|
||||
y.backward()
|
||||
|
||||
assert a.grad is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", devices)
|
||||
def test_random_in_checkpoint(device):
|
||||
dropout = nn.Dropout(p=0.5)
|
||||
|
||||
torch.manual_seed(0)
|
||||
x = torch.randn(3, 3, device=device, requires_grad=True)
|
||||
y = dropout(x)
|
||||
y.norm().backward()
|
||||
|
||||
torch.manual_seed(0)
|
||||
chk_x = torch.randn(3, 3, device=device, requires_grad=True)
|
||||
chk_y = checkpoint(dropout, chk_x)
|
||||
chk_y.norm().backward()
|
||||
|
||||
assert torch.allclose(x.grad, chk_x.grad)
|
||||
|
||||
|
||||
def test_detect_checkpointing_recomputing():
|
||||
logs = []
|
||||
|
||||
class Detect(nn.Module):
|
||||
def forward(self, input):
|
||||
logs.append((is_checkpointing(), is_recomputing()))
|
||||
return input
|
||||
|
||||
model = Detect()
|
||||
input = torch.rand(1, requires_grad=True)
|
||||
|
||||
output = checkpoint(model, input)
|
||||
output.backward()
|
||||
|
||||
assert logs == [(True, False), (False, True)]
|
||||
|
||||
|
||||
def test_detect_checkpointing_recomputing_without_checkpoint():
|
||||
logs = []
|
||||
|
||||
class Detect(nn.Module):
|
||||
def forward(self, input):
|
||||
logs.append((is_checkpointing(), is_recomputing()))
|
||||
return input
|
||||
|
||||
model = Detect()
|
||||
input = torch.rand(1, requires_grad=True)
|
||||
|
||||
output = model(input)
|
||||
output.backward()
|
||||
|
||||
assert logs == [(False, False)]
|
||||
|
||||
|
||||
def test_non_grad_output():
|
||||
class ForkNonGrad(nn.Module):
|
||||
def forward(self, input):
|
||||
return (input * 2, torch.rand(1))
|
||||
|
||||
model = ForkNonGrad()
|
||||
input = torch.rand(1, requires_grad=True)
|
||||
|
||||
output = checkpoint(model, input)
|
||||
output[0].backward()
|
68
test/distributed/_pipeline/sync/test_copy.py
Normal file
68
test/distributed/_pipeline/sync/test_copy.py
Normal file
@ -0,0 +1,68 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from torch.distributed._pipeline.sync.copy import Copy, Wait
|
||||
from torch.distributed._pipeline.sync.stream import CPUStream, current_stream, get_device, is_cuda, new_stream, use_stream
|
||||
|
||||
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
|
||||
|
||||
|
||||
def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None):
|
||||
device = get_device(prev_stream)
|
||||
|
||||
with use_stream(prev_stream):
|
||||
if is_cuda(prev_stream):
|
||||
cuda_sleep(0.5)
|
||||
x = torch.ones(100, device=device, requires_grad=True)
|
||||
|
||||
(y,) = Copy.apply(prev_stream, next_stream, x)
|
||||
(y,) = Wait.apply(prev_stream, next_stream, x)
|
||||
|
||||
with use_stream(next_stream):
|
||||
assert torch.allclose(y.sum(), torch.tensor(100.0, device=device))
|
||||
y.norm().backward()
|
||||
with use_stream(prev_stream):
|
||||
assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device))
|
||||
|
||||
|
||||
def test_copy_wait_cpu_cpu():
|
||||
prev_stream = CPUStream
|
||||
next_stream = CPUStream
|
||||
_test_copy_wait(prev_stream, next_stream)
|
||||
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_copy_wait_cpu_cuda(cuda_sleep):
|
||||
prev_stream = CPUStream
|
||||
next_stream = current_stream(torch.device("cuda"))
|
||||
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
|
||||
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_copy_wait_cuda_cpu(cuda_sleep):
|
||||
prev_stream = current_stream(torch.device("cuda"))
|
||||
next_stream = CPUStream
|
||||
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
|
||||
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_copy_wait_cuda_cuda(cuda_sleep):
|
||||
prev_stream = current_stream(torch.device("cuda"))
|
||||
next_stream = new_stream(torch.device("cuda"))
|
||||
_test_copy_wait(prev_stream, next_stream, cuda_sleep)
|
||||
|
||||
|
||||
def test_wait_multiple_tensors():
|
||||
a = torch.rand(1, requires_grad=True)
|
||||
b = torch.rand(1, requires_grad=True)
|
||||
|
||||
a, b = Wait.apply(CPUStream, CPUStream, a, b)
|
||||
|
||||
assert a.grad_fn is b.grad_fn
|
||||
assert a.grad_fn.__class__ is Wait._backward_cls
|
192
test/distributed/_pipeline/sync/test_deferred_batch_norm.py
Normal file
192
test/distributed/_pipeline/sync/test_deferred_batch_norm.py
Normal file
@ -0,0 +1,192 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
|
||||
from torch.distributed._pipeline.sync.batchnorm import DeferredBatchNorm
|
||||
|
||||
CHUNKS = 4
|
||||
|
||||
|
||||
def tilt_dist(input):
|
||||
# Tilt variance by channel.
|
||||
rgb = input.transpose(0, 1)
|
||||
rgb[0] *= 1
|
||||
rgb[1] *= 10
|
||||
rgb[2] *= 100
|
||||
|
||||
# Tilt mean by single batch.
|
||||
for i, single in enumerate(input):
|
||||
single += 2 ** i
|
||||
|
||||
return input
|
||||
|
||||
|
||||
def chunked_forward(model, input, chunks=CHUNKS):
|
||||
output_chunks = []
|
||||
|
||||
for chunk in input.chunk(chunks):
|
||||
output_chunks.append(model(chunk))
|
||||
|
||||
return torch.cat(output_chunks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("chunks", [1, 4])
|
||||
@pytest.mark.parametrize("input_requires_grad", [True, False])
|
||||
def test_transparency(chunks, input_requires_grad):
|
||||
bn = nn.BatchNorm2d(3)
|
||||
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks)
|
||||
|
||||
input1 = torch.rand(16, 3, 224, 224)
|
||||
input1 = tilt_dist(input1)
|
||||
input2 = input1.clone()
|
||||
input1.requires_grad = input_requires_grad
|
||||
input2.requires_grad = input_requires_grad
|
||||
|
||||
output1 = chunked_forward(bn, input1, chunks=chunks)
|
||||
output2 = chunked_forward(dbn, input2, chunks=chunks)
|
||||
|
||||
assert torch.allclose(output1, output2, atol=1e-4)
|
||||
|
||||
output1.mean().backward()
|
||||
output2.mean().backward()
|
||||
|
||||
assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4)
|
||||
|
||||
if input_requires_grad:
|
||||
assert input1.grad is not None
|
||||
assert input2.grad is not None
|
||||
assert torch.allclose(input1.grad, input2.grad, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("momentum", [0.1, None])
|
||||
def test_running_stats(momentum):
|
||||
bn = nn.BatchNorm2d(3, momentum=momentum)
|
||||
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
|
||||
|
||||
input = torch.rand(16, 3, 224, 224)
|
||||
input = tilt_dist(input)
|
||||
|
||||
bn(input)
|
||||
chunked_forward(dbn, input)
|
||||
|
||||
assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4)
|
||||
assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4)
|
||||
|
||||
|
||||
def test_convert_deferred_batch_norm():
|
||||
bn = nn.BatchNorm2d(3, track_running_stats=False)
|
||||
bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS)
|
||||
assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False
|
||||
|
||||
dbn = DeferredBatchNorm(3, chunks=CHUNKS)
|
||||
dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS)
|
||||
assert dbn is dbn_again
|
||||
|
||||
dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1)
|
||||
assert dbn is not dbn_again # because of different chunks
|
||||
|
||||
|
||||
def test_eval():
|
||||
bn = nn.BatchNorm2d(3)
|
||||
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
|
||||
|
||||
input = torch.rand(16, 3, 224, 224)
|
||||
input = tilt_dist(input)
|
||||
|
||||
bn(input)
|
||||
chunked_forward(dbn, input)
|
||||
|
||||
bn.eval()
|
||||
dbn.eval()
|
||||
|
||||
assert torch.allclose(bn(input), dbn(input), atol=1e-4)
|
||||
|
||||
|
||||
def test_optimize():
|
||||
bn = nn.BatchNorm2d(3)
|
||||
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
|
||||
|
||||
opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0)
|
||||
|
||||
for i in range(5):
|
||||
input = torch.rand(16, 3, 224, 224)
|
||||
input = tilt_dist(input)
|
||||
|
||||
# train
|
||||
y = bn(input)
|
||||
a = y.sum()
|
||||
a.backward()
|
||||
|
||||
y = chunked_forward(dbn, input)
|
||||
b = y.sum()
|
||||
b.backward()
|
||||
|
||||
opt.step()
|
||||
|
||||
# eval
|
||||
bn.eval()
|
||||
dbn.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10 ** i))
|
||||
|
||||
|
||||
def test_conv_bn():
|
||||
bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
|
||||
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)
|
||||
|
||||
input = torch.rand(16, 3, 224, 224)
|
||||
input = tilt_dist(input)
|
||||
|
||||
opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1)
|
||||
|
||||
# 1st step
|
||||
a = bn(input)
|
||||
b = chunked_forward(dbn, input)
|
||||
|
||||
# Outputs are different. (per-mini-batch vs. per-micro-batch)
|
||||
assert not torch.allclose(a, b)
|
||||
|
||||
a.sum().backward()
|
||||
b.sum().backward()
|
||||
opt.step()
|
||||
opt.zero_grad()
|
||||
|
||||
# Conv layers are also trained differently because of their different outputs.
|
||||
assert not torch.allclose(bn[0].weight, dbn[0].weight)
|
||||
|
||||
# But BNs track identical running stats.
|
||||
assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
|
||||
assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3)
|
||||
|
||||
# 2nd step
|
||||
a = bn(input)
|
||||
b = chunked_forward(dbn, input)
|
||||
a.sum().backward()
|
||||
b.sum().backward()
|
||||
|
||||
# BNs can't track identical running stats due to the different conv layers.
|
||||
assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
|
||||
assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3)
|
||||
|
||||
|
||||
def test_input_requiring_grad():
|
||||
dbn = DeferredBatchNorm(3, chunks=CHUNKS)
|
||||
|
||||
input = torch.rand(16, 3, 224, 224)
|
||||
input = tilt_dist(input)
|
||||
input.requires_grad = True
|
||||
|
||||
chunked_forward(dbn, input)
|
||||
|
||||
assert not dbn.sum.requires_grad
|
||||
assert dbn.sum.grad_fn is None
|
144
test/distributed/_pipeline/sync/test_dependency.py
Normal file
144
test/distributed/_pipeline/sync/test_dependency.py
Normal file
@ -0,0 +1,144 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from torch.distributed._pipeline.sync.dependency import Fork, Join, fork, join
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
|
||||
def test_fork_join():
|
||||
logs = []
|
||||
|
||||
class Log(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, number, tensor):
|
||||
ctx.number = number
|
||||
return tensor.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
logs.append(ctx.number)
|
||||
return None, grad
|
||||
|
||||
a = torch.rand(1, device="cpu", requires_grad=True)
|
||||
b = torch.rand(1, device="cuda", requires_grad=True)
|
||||
|
||||
a = Log.apply(1, a)
|
||||
|
||||
a, phony = fork(a)
|
||||
b = join(a, phony)
|
||||
|
||||
b = Log.apply(2, b)
|
||||
b = b.to("cpu")
|
||||
|
||||
(a + b).backward()
|
||||
|
||||
assert logs == [2, 1]
|
||||
|
||||
|
||||
def test_fork_join_enable_grad():
|
||||
x = torch.rand(1, requires_grad=True)
|
||||
|
||||
with torch.enable_grad():
|
||||
x2, p = fork(x)
|
||||
|
||||
assert p.requires_grad
|
||||
assert x2 is not x
|
||||
x = x2
|
||||
|
||||
assert x.requires_grad
|
||||
assert p.requires_grad
|
||||
assert x.grad_fn.__class__ is Fork._backward_cls
|
||||
assert p.grad_fn.__class__ is Fork._backward_cls
|
||||
|
||||
with torch.enable_grad():
|
||||
x2 = join(x, p)
|
||||
|
||||
assert x2 is not x
|
||||
x = x2
|
||||
|
||||
assert x.requires_grad
|
||||
assert x.grad_fn.__class__ is Join._backward_cls
|
||||
|
||||
|
||||
def test_fork_join_no_grad(monkeypatch):
|
||||
def do_not_apply(*args):
|
||||
raise AssertionError("Function.apply called")
|
||||
|
||||
monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply)
|
||||
|
||||
x = torch.rand(1, requires_grad=True)
|
||||
|
||||
with torch.no_grad():
|
||||
x2, p = fork(x)
|
||||
|
||||
assert not p.requires_grad
|
||||
assert x2 is x
|
||||
x = x2
|
||||
|
||||
with torch.no_grad():
|
||||
x2 = join(x, p)
|
||||
|
||||
assert x2 is x
|
||||
x = x2
|
||||
|
||||
|
||||
def test_fork_leak():
|
||||
leak = None
|
||||
|
||||
class F(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
nonlocal leak
|
||||
leak = weakref.ref(ctx)
|
||||
return grad
|
||||
|
||||
x = torch.rand(1, requires_grad=True)
|
||||
x = F.apply(x)
|
||||
x, phony = fork(x)
|
||||
x = join(x, phony)
|
||||
|
||||
x.backward()
|
||||
del x, phony
|
||||
|
||||
assert leak() is None
|
||||
|
||||
|
||||
def test_join_when_fork_not_requires_grad():
|
||||
x = torch.rand(2, 1)
|
||||
a, b = x.chunk(2)
|
||||
|
||||
assert not a.requires_grad
|
||||
a, p = fork(a)
|
||||
assert not a.requires_grad
|
||||
assert not p.requires_grad
|
||||
|
||||
assert not b.requires_grad
|
||||
b = join(b, p)
|
||||
assert not b.requires_grad
|
||||
|
||||
|
||||
def test_join_when_fork_requires_grad():
|
||||
x = torch.rand(2, 1)
|
||||
a, b = x.chunk(2)
|
||||
|
||||
a.requires_grad_()
|
||||
assert a.requires_grad
|
||||
a, p = fork(a)
|
||||
assert a.requires_grad
|
||||
assert p.requires_grad
|
||||
|
||||
assert not b.requires_grad
|
||||
b = join(b, p)
|
||||
assert b.requires_grad
|
71
test/distributed/_pipeline/sync/test_inplace.py
Normal file
71
test/distributed/_pipeline/sync/test_inplace.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync import Pipe
|
||||
|
||||
|
||||
def test_inplace_on_requires_grad():
|
||||
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
|
||||
model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always")
|
||||
|
||||
x = torch.rand(1)
|
||||
y = model(x)
|
||||
|
||||
message = r"a leaf Variable that requires grad .* used in an in-place operation."
|
||||
with pytest.raises(RuntimeError, match=message):
|
||||
y.backward()
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=True)
|
||||
def test_inplace_on_not_requires_grad():
|
||||
# In-place operation on a tensor not requiring grad doesn't cause a
|
||||
# RuntimeError. Currently, we cannot detect this case.
|
||||
model = nn.Sequential(nn.ReLU(inplace=True))
|
||||
model = Pipe(model, [1], devices=["cpu"], checkpoint="always")
|
||||
|
||||
x = torch.rand(1)
|
||||
y = model(x)
|
||||
del model
|
||||
|
||||
message = r"a leaf Variable that requires grad .* used in an in-place operation."
|
||||
with pytest.raises(RuntimeError, match=message):
|
||||
y.backward()
|
||||
|
||||
|
||||
@pytest.mark.xfail(strict=True)
|
||||
def test_inplace_incorrect_grad():
|
||||
class M(nn.Module):
|
||||
def forward(self, foo_bar):
|
||||
# 'foo' requires grad but 'bar' does not. In-place operation on
|
||||
# 'bar' won't cause a RuntimeError.
|
||||
foo, bar = foo_bar
|
||||
|
||||
# add_(1) is not idempotent, in contrast to relu_(). If it is
|
||||
# executed multiple times, it will accumulates each difference onto
|
||||
# 'bar'.
|
||||
bar.add_(1)
|
||||
|
||||
# 'bar' is still captured by checkpointing. 'foo' will get
|
||||
# incorrect grad.
|
||||
return foo * bar
|
||||
|
||||
model = nn.Sequential(M())
|
||||
model = Pipe(model, [1], devices=["cpu"], checkpoint="always")
|
||||
|
||||
foo = torch.tensor([1.0], requires_grad=True)
|
||||
bar = torch.tensor([1.0])
|
||||
|
||||
output = model((foo, bar))
|
||||
del model
|
||||
output.backward()
|
||||
|
||||
# The gradient of 'foo' should be 2, but it is 3 actually because
|
||||
# bar.add_(1) was executed twice due to checkpointing.
|
||||
assert foo.grad.item() == 2.0
|
138
test/distributed/_pipeline/sync/test_microbatch.py
Normal file
138
test/distributed/_pipeline/sync/test_microbatch.py
Normal file
@ -0,0 +1,138 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
import torch.cuda
|
||||
|
||||
from torch.distributed._pipeline.sync.microbatch import Batch, check, gather, scatter
|
||||
|
||||
|
||||
def test_batch_atomic():
|
||||
x = torch.tensor(42)
|
||||
b = Batch(x)
|
||||
|
||||
assert b.atomic
|
||||
|
||||
assert b.tensor is x
|
||||
with pytest.raises(AttributeError):
|
||||
b.tensors
|
||||
|
||||
assert list(b) == [x]
|
||||
assert len(b) == 1
|
||||
assert b[0] is x
|
||||
|
||||
|
||||
def test_batch_non_atomic():
|
||||
x, y = torch.tensor(42), torch.tensor(21)
|
||||
b = Batch((x, y))
|
||||
|
||||
assert not b.atomic
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
b.tensor
|
||||
assert b.tensors == (x, y)
|
||||
|
||||
assert list(b) == [x, y]
|
||||
assert len(b) == 2
|
||||
assert b[0] is x
|
||||
assert b[1] is y
|
||||
|
||||
|
||||
def test_batch_call():
|
||||
a = Batch(torch.tensor(42))
|
||||
b = Batch((torch.tensor(42), torch.tensor(21)))
|
||||
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
assert a.call(f).atomic
|
||||
assert not b.call(f).atomic
|
||||
|
||||
|
||||
def test_batch_setitem_by_index():
|
||||
a = Batch(torch.tensor(42))
|
||||
b = Batch((torch.tensor(42), torch.tensor(21)))
|
||||
|
||||
a[0] = torch.tensor(0)
|
||||
b[0] = torch.tensor(0)
|
||||
|
||||
assert a.atomic
|
||||
assert a[0].item() == 0
|
||||
|
||||
assert not b.atomic
|
||||
assert len(b) == 2
|
||||
assert b[0].item() == 0
|
||||
assert b[1].item() == 21
|
||||
|
||||
|
||||
def test_batch_setitem_by_slice():
|
||||
a = Batch(torch.tensor(42))
|
||||
b = Batch((torch.tensor(42), torch.tensor(21)))
|
||||
|
||||
a[:] = (torch.tensor(0),)
|
||||
b[:] = (torch.tensor(0),)
|
||||
|
||||
assert a.atomic
|
||||
assert a[0].item() == 0
|
||||
|
||||
assert not b.atomic
|
||||
assert len(b) == 1
|
||||
assert b[0].item() == 0
|
||||
|
||||
|
||||
def test_check():
|
||||
check(torch.tensor(42))
|
||||
check((torch.tensor(4), torch.tensor(2)))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
check(42)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
check("str")
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
check((torch.tensor(4), 2))
|
||||
|
||||
|
||||
def test_gather_tensors():
|
||||
a = torch.zeros(1, 1)
|
||||
b = torch.zeros(1, 1)
|
||||
|
||||
ab = gather([Batch(a), Batch(b)])
|
||||
|
||||
assert ab.size() == (2, 1)
|
||||
|
||||
|
||||
def test_gather_tuples():
|
||||
a = (torch.zeros(1, 1), torch.zeros(2, 2))
|
||||
b = (torch.zeros(1, 1), torch.zeros(2, 2))
|
||||
|
||||
ab = gather([Batch(a), Batch(b)])
|
||||
|
||||
assert isinstance(ab, tuple)
|
||||
assert ab[0].size() == (2, 1)
|
||||
assert ab[1].size() == (4, 2)
|
||||
|
||||
|
||||
def test_scatter_tensor():
|
||||
ab = torch.zeros(2, 1)
|
||||
|
||||
a, b = scatter(ab, chunks=2)
|
||||
|
||||
assert a.tensor.size() == (1, 1)
|
||||
assert b.tensor.size() == (1, 1)
|
||||
|
||||
|
||||
def test_scatter_tuple():
|
||||
ab = (torch.zeros(2, 1), torch.zeros(4, 2))
|
||||
|
||||
a, b = scatter(ab, chunks=2)
|
||||
|
||||
assert a.tensors[0].size() == (1, 1)
|
||||
assert b.tensors[0].size() == (1, 1)
|
||||
assert a.tensors[1].size() == (2, 2)
|
||||
assert b.tensors[1].size() == (2, 2)
|
50
test/distributed/_pipeline/sync/test_phony.py
Normal file
50
test/distributed/_pipeline/sync/test_phony.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
|
||||
from torch.distributed._pipeline.sync.phony import get_phony
|
||||
|
||||
|
||||
def test_phony_size():
|
||||
p = get_phony(torch.device("cpu"), requires_grad=False)
|
||||
assert p.size() == (0,)
|
||||
|
||||
|
||||
def test_phony_requires_grad():
|
||||
p1 = get_phony(torch.device("cpu"), requires_grad=True)
|
||||
p2 = get_phony(torch.device("cpu"), requires_grad=False)
|
||||
assert p1.requires_grad
|
||||
assert not p2.requires_grad
|
||||
|
||||
|
||||
def test_cached_phony():
|
||||
p1 = get_phony(torch.device("cpu"), requires_grad=True)
|
||||
p2 = get_phony(torch.device("cpu"), requires_grad=True)
|
||||
assert p1 is p2
|
||||
|
||||
p3 = get_phony(torch.device("cpu"), requires_grad=False)
|
||||
p4 = get_phony(torch.device("cpu"), requires_grad=False)
|
||||
assert p3 is p4
|
||||
|
||||
assert p1 is not p3
|
||||
|
||||
|
||||
def test_phony_in_autograd_function():
|
||||
class Phonify(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
phony = get_phony(input.device, requires_grad=False)
|
||||
return phony.detach()
|
||||
|
||||
x = torch.rand(1, requires_grad=True)
|
||||
|
||||
p1 = Phonify.apply(x)
|
||||
p2 = get_phony(torch.device("cpu"), requires_grad=True)
|
||||
|
||||
assert p1 is not p2
|
||||
assert p1.grad_fn is not None
|
||||
assert p2.grad_fn is None
|
608
test/distributed/_pipeline/sync/test_pipe.py
Normal file
608
test/distributed/_pipeline/sync/test_pipe.py
Normal file
@ -0,0 +1,608 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync import Pipe
|
||||
|
||||
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
|
||||
|
||||
|
||||
def test_parameters():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
pipe = Pipe(model, balance=[1], devices=["cpu"], chunks=1)
|
||||
assert list(pipe.parameters()) != []
|
||||
|
||||
|
||||
def test_public_attrs():
|
||||
class MyString:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
pipe = Pipe(model, balance=(1,), devices=("cpu",), chunks=42.000, checkpoint=MyString("always"))
|
||||
|
||||
assert pipe.balance == [1]
|
||||
assert pipe.devices == [torch.device("cpu")]
|
||||
assert pipe.chunks == 42
|
||||
assert isinstance(pipe.chunks, int)
|
||||
assert pipe.checkpoint == "always"
|
||||
assert isinstance(pipe.checkpoint, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("balance", [[2], [1, 1]])
|
||||
def test_sequential_like(balance):
|
||||
a = nn.Linear(1, 1)
|
||||
b = nn.Linear(1, 1)
|
||||
|
||||
model = nn.Sequential(a, b)
|
||||
model = Pipe(model, balance, devices=["cpu", "cpu"])
|
||||
|
||||
assert len(model) == 2
|
||||
assert list(model) == [a, b]
|
||||
|
||||
assert model[0] is a
|
||||
assert model[1] is b
|
||||
with pytest.raises(IndexError):
|
||||
_ = model[2]
|
||||
|
||||
assert model[-1] is b
|
||||
assert model[-2] is a
|
||||
|
||||
|
||||
def test_balance_wrong_length():
|
||||
a = nn.Linear(1, 1)
|
||||
b = nn.Linear(1, 1)
|
||||
|
||||
model = nn.Sequential(a, b)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Pipe(model, balance=[1])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Pipe(model, balance=[3])
|
||||
|
||||
|
||||
def test_balance_less_than_1():
|
||||
a = nn.Linear(1, 1)
|
||||
b = nn.Linear(1, 1)
|
||||
|
||||
model = nn.Sequential(a, b)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Pipe(model, balance=[0, 2])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Pipe(model, balance=[-1, 3])
|
||||
|
||||
|
||||
def test_chunks_less_than_1():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Pipe(model, balance=[1], devices=["cpu"], chunks=0)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Pipe(model, balance=[1], devices=["cpu"], chunks=-1)
|
||||
|
||||
|
||||
def test_too_few_devices():
|
||||
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
# len(balance) > len(devices)
|
||||
model = Pipe(model, balance=[1, 1, 1, 1], devices=["cpu"])
|
||||
|
||||
|
||||
def test_batch_size_indivisible():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
model = Pipe(model, balance=[1], devices=["cpu"], chunks=4)
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
model(torch.rand(7, 1))
|
||||
|
||||
# Indivisible batch size is legal.
|
||||
assert not record
|
||||
|
||||
|
||||
def test_batch_size_small():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
model = Pipe(model, balance=[1], devices=["cpu"], chunks=4)
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
model(torch.rand(2, 1))
|
||||
|
||||
# Batch size smaller than chunks is legal.
|
||||
assert not record
|
||||
|
||||
|
||||
def test_checkpoint_mode():
|
||||
def count_grad_fn(grad_fn, name, visited=None):
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if grad_fn in visited:
|
||||
return 0
|
||||
visited.add(grad_fn)
|
||||
|
||||
if grad_fn is None:
|
||||
return 0
|
||||
if grad_fn.__class__.__name__ == name:
|
||||
return 1
|
||||
|
||||
counter = 0
|
||||
for next_grad_fn, _ in grad_fn.next_functions:
|
||||
counter += count_grad_fn(next_grad_fn, name, visited=visited)
|
||||
return counter
|
||||
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
input = torch.rand(2, 1)
|
||||
|
||||
always = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="always")
|
||||
except_last = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="except_last")
|
||||
never = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="never")
|
||||
|
||||
always_output = always(input)
|
||||
except_last_output = except_last(input)
|
||||
never_output = never(input)
|
||||
|
||||
assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2
|
||||
assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1
|
||||
assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0
|
||||
|
||||
|
||||
def test_checkpoint_mode_invalid():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
|
||||
with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
|
||||
Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="INVALID_CHECKPOINT")
|
||||
|
||||
|
||||
def test_checkpoint_mode_when_chunks_1():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
|
||||
# All checkpoint modes are fine.
|
||||
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="except_last")
|
||||
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="always")
|
||||
Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="never")
|
||||
|
||||
|
||||
def test_checkpoint_eval():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
|
||||
input = torch.rand(2, 1)
|
||||
|
||||
def find_grad_fn(grad_fn, name):
|
||||
if grad_fn is None:
|
||||
return False
|
||||
if grad_fn.__class__.__name__ == name:
|
||||
return True
|
||||
for next_grad_fn, _ in grad_fn.next_functions:
|
||||
if find_grad_fn(next_grad_fn, name):
|
||||
return True
|
||||
return False
|
||||
|
||||
model.train()
|
||||
train_output = model(input)
|
||||
assert find_grad_fn(train_output.grad_fn, "CheckpointBackward")
|
||||
assert find_grad_fn(train_output.grad_fn, "RecomputeBackward")
|
||||
|
||||
model.eval()
|
||||
eval_output = model(input)
|
||||
assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward")
|
||||
assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")
|
||||
|
||||
|
||||
def test_checkpoint_non_float_input():
|
||||
class ForkNonFloat(nn.Module):
|
||||
def forward(self, input):
|
||||
return (input * 2, torch.tensor([False]))
|
||||
|
||||
class JoinNonFloat(nn.Module):
|
||||
def forward(self, input):
|
||||
return input[0] * 2
|
||||
|
||||
model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
|
||||
model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=1, checkpoint="always")
|
||||
|
||||
input = torch.rand(1, requires_grad=True)
|
||||
output = model(input)
|
||||
output.backward()
|
||||
|
||||
|
||||
def test_no_grad():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
|
||||
input = torch.rand(2, 1)
|
||||
|
||||
latent = None
|
||||
|
||||
def hook(module, input, output):
|
||||
_ = module
|
||||
_ = input
|
||||
|
||||
nonlocal latent
|
||||
latent = output
|
||||
|
||||
partition = model.partitions[0]
|
||||
partition.register_forward_hook(hook)
|
||||
|
||||
with torch.no_grad():
|
||||
model(input)
|
||||
|
||||
assert latent.grad_fn is None
|
||||
|
||||
|
||||
def test_exception():
|
||||
class ExpectedException(Exception):
|
||||
pass
|
||||
|
||||
class Raise(nn.Module):
|
||||
def forward(self, *_):
|
||||
raise ExpectedException()
|
||||
|
||||
model = nn.Sequential(Raise())
|
||||
model = Pipe(model, balance=[1], devices=["cpu"], chunks=1)
|
||||
|
||||
with pytest.raises(ExpectedException):
|
||||
model(torch.rand(1))
|
||||
|
||||
|
||||
def test_exception_early_stop_asap():
|
||||
"""Even the first partitions have finished to process, the partition before
|
||||
the failed partition should be killed as soon as possible.
|
||||
"""
|
||||
|
||||
class ExpectedException(Exception):
|
||||
pass
|
||||
|
||||
class Pass(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
counter = 0
|
||||
|
||||
class Counter(nn.Module):
|
||||
def forward(self, x):
|
||||
time.sleep(0.1)
|
||||
|
||||
nonlocal counter
|
||||
counter += 1
|
||||
|
||||
return x
|
||||
|
||||
class Raise(nn.Module):
|
||||
def forward(self, x):
|
||||
raise ExpectedException()
|
||||
|
||||
model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
|
||||
model = Pipe(model, [1, 1, 1, 1], devices=["cpu", "cpu", "cpu", "cpu"], chunks=3)
|
||||
|
||||
with pytest.raises(ExpectedException):
|
||||
model(torch.rand(3))
|
||||
|
||||
# If the early stop doesn't work, it would be 3 instead.
|
||||
assert counter == 2
|
||||
|
||||
|
||||
def test_input_pair():
|
||||
class Two(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc_a = nn.Linear(1, 1)
|
||||
self.fc_b = nn.Linear(1, 1)
|
||||
|
||||
def forward(self, a_and_b):
|
||||
a, b = a_and_b
|
||||
return (self.fc_a(a), self.fc_b(b))
|
||||
|
||||
model = nn.Sequential(Two())
|
||||
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
|
||||
|
||||
a = torch.rand(10, 1, requires_grad=True)
|
||||
b = torch.rand(10, 1, requires_grad=True)
|
||||
|
||||
a_out, b_out = model((a, b))
|
||||
loss = (a_out + b_out).mean()
|
||||
loss.backward()
|
||||
|
||||
assert a.grad is not None
|
||||
assert b.grad is not None
|
||||
|
||||
|
||||
def test_input_singleton():
|
||||
class One(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(1, 1)
|
||||
|
||||
def forward(self, only_a):
|
||||
(a,) = only_a
|
||||
return (self.fc(a),)
|
||||
|
||||
model = nn.Sequential(One())
|
||||
model = Pipe(model, balance=[1], devices=["cpu"], chunks=2)
|
||||
|
||||
a = torch.rand(10, 1, requires_grad=True)
|
||||
|
||||
(a_out,) = model((a,))
|
||||
loss = a_out.mean()
|
||||
loss.backward()
|
||||
|
||||
assert all(p.grad is not None for p in model.parameters())
|
||||
assert a.grad is not None
|
||||
|
||||
|
||||
def test_input_varargs():
|
||||
model = nn.Sequential(nn.Linear(1, 1))
|
||||
model = Pipe(model, balance=[1], devices=["cpu"])
|
||||
|
||||
a = torch.rand(1)
|
||||
b = torch.rand(1)
|
||||
|
||||
# TypeError: forward() takes 2 positional arguments but 3 were given
|
||||
with pytest.raises(TypeError):
|
||||
model(a, b)
|
||||
|
||||
|
||||
def test_non_tensor():
|
||||
class NonTensor(nn.Module):
|
||||
def forward(self, _):
|
||||
return "hello"
|
||||
|
||||
model = nn.Sequential(NonTensor())
|
||||
model = Pipe(model, balance=[1], devices=["cpu"])
|
||||
x = torch.rand(1)
|
||||
|
||||
# TypeError: expected Tensor as element 0 in argument 0, but got str
|
||||
with pytest.raises(TypeError):
|
||||
model(x)
|
||||
|
||||
# TypeError: expected Tensor to scatter, but got str
|
||||
with pytest.raises(TypeError):
|
||||
model("hello")
|
||||
|
||||
|
||||
def test_non_tensor_tuple():
|
||||
class NonTensorTuple(nn.Module):
|
||||
def forward(self, x):
|
||||
return (x, "hello")
|
||||
|
||||
model = nn.Sequential(NonTensorTuple())
|
||||
model = Pipe(model, balance=[1], devices=["cpu"])
|
||||
x = torch.rand(1)
|
||||
|
||||
# TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1
|
||||
with pytest.raises(TypeError):
|
||||
model(x)
|
||||
|
||||
# TypeError: expected Tensor to scatter, but got str
|
||||
with pytest.raises(TypeError):
|
||||
model((x, "hello"))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
|
||||
def test_deferred_batch_norm(checkpoint):
|
||||
bn = nn.BatchNorm2d(3)
|
||||
pipe_bn = deepcopy(bn)
|
||||
pipe = Pipe(
|
||||
nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=2, checkpoint=checkpoint, deferred_batch_norm=True
|
||||
)
|
||||
|
||||
x = torch.rand(4, 3, 10, 10)
|
||||
pipe(x).mean().backward()
|
||||
bn(x).mean().backward()
|
||||
|
||||
assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4)
|
||||
assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("checkpoint", ["never", "always"])
|
||||
def test_deferred_batch_norm_params(checkpoint):
|
||||
bn = nn.BatchNorm2d(3)
|
||||
pipe_bn = deepcopy(bn)
|
||||
pipe = Pipe(
|
||||
nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=1, checkpoint=checkpoint, deferred_batch_norm=True
|
||||
)
|
||||
|
||||
x = torch.rand(4, 3, 10, 10)
|
||||
pipe(x).mean().backward()
|
||||
bn(x).mean().backward()
|
||||
|
||||
assert pipe[0].weight.grad is not None
|
||||
assert pipe[0].bias.grad is not None
|
||||
|
||||
assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4)
|
||||
assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4)
|
||||
|
||||
|
||||
def test_devices():
|
||||
a = nn.Linear(1, 1)
|
||||
b = nn.Linear(1, 1)
|
||||
c = nn.Linear(1, 1)
|
||||
|
||||
# There are extra two devices.
|
||||
devices = ["cpu", "cpu", "cpu", "cpu", "cpu"]
|
||||
|
||||
model = nn.Sequential(a, b, c)
|
||||
model = Pipe(model, [1, 1, 1], devices=devices)
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
# Extra devices must be discarded.
|
||||
assert model.devices == [cpu, cpu, cpu]
|
||||
|
||||
|
||||
def test_partitions():
|
||||
a = nn.Linear(1, 1)
|
||||
b = nn.Linear(1, 1)
|
||||
|
||||
model = nn.Sequential(a, b)
|
||||
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
|
||||
|
||||
assert isinstance(model.partitions, nn.ModuleList)
|
||||
assert isinstance(model.partitions[0], nn.Sequential)
|
||||
assert isinstance(model.partitions[1], nn.Sequential)
|
||||
|
||||
assert "partitions.0.0.weight" in model.state_dict()
|
||||
|
||||
|
||||
def test_deny_moving():
|
||||
a = nn.Linear(1, 1)
|
||||
b = nn.Linear(1, 1)
|
||||
|
||||
model = nn.Sequential(a, b)
|
||||
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
|
||||
|
||||
# Moving is denied.
|
||||
with pytest.raises(TypeError):
|
||||
model.cuda()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
model.cpu()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
model.to(torch.device("cuda"))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
model.to(0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
model.to("cuda")
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
model.to(device=0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
model.to(torch.rand(1))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
model.to(tensor=torch.rand(1))
|
||||
|
||||
# Casting is allowed.
|
||||
model.half()
|
||||
model.to(torch.double)
|
||||
model.to(dtype=torch.float)
|
||||
|
||||
|
||||
def test_empty_module():
|
||||
# Empty sequential module is not illegal.
|
||||
model = nn.Sequential()
|
||||
model = Pipe(model, [])
|
||||
|
||||
assert model(torch.tensor(42)) == torch.tensor(42)
|
||||
assert model((torch.tensor(42),)) == (torch.tensor(42),)
|
||||
|
||||
# But only tensor or tensors is legal in Pipe.
|
||||
with pytest.raises(TypeError):
|
||||
model(42)
|
||||
|
||||
|
||||
def test_named_children():
|
||||
a = nn.Linear(1, 1)
|
||||
b = nn.Linear(1, 1)
|
||||
|
||||
model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
|
||||
model = Pipe(model, [1, 1], devices=["cpu", "cpu"])
|
||||
|
||||
names = set(n for n, _ in model.named_modules())
|
||||
assert "partitions.0.a" in names
|
||||
assert "partitions.1.b" in names
|
||||
|
||||
# Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
|
||||
# several methods in its namespace.
|
||||
with pytest.raises(AttributeError):
|
||||
model.a
|
||||
|
||||
|
||||
def test_recommend_auto_balance():
|
||||
with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"):
|
||||
# balance is required
|
||||
Pipe(nn.Sequential())
|
||||
|
||||
with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"):
|
||||
# module and sum of balance have differen length (module: 0, sum of balance: 1)
|
||||
Pipe(nn.Sequential(), [1])
|
||||
|
||||
with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"):
|
||||
# module and sum of balance have different length (module: 2, sum of balance: 1)
|
||||
Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
|
||||
|
||||
|
||||
def test_verify_module_non_sequential():
|
||||
with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"):
|
||||
Pipe(nn.Module(), [1])
|
||||
|
||||
|
||||
def test_verify_module_duplicate_children():
|
||||
conv = nn.Conv2d(3, 3, 1)
|
||||
model = nn.Sequential(conv, conv)
|
||||
|
||||
with pytest.raises(ValueError, match="module with duplicate children is not supported"):
|
||||
Pipe(model, [1, 1])
|
||||
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_verify_module_duplicate_parameters_on_distinct_devices():
|
||||
class Surrogate(nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
conv = nn.Conv2d(3, 3, 1)
|
||||
model = nn.Sequential(Surrogate(conv), Surrogate(conv))
|
||||
|
||||
with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
|
||||
Pipe(model, [1, 1], devices=["cpu", "cuda"])
|
||||
|
||||
|
||||
def test_verify_module_duplicate_parameters_on_same_device():
|
||||
class Surrogate(nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
conv = nn.Conv2d(3, 3, 1)
|
||||
model = nn.Sequential(Surrogate(conv), Surrogate(conv))
|
||||
|
||||
Pipe(model, [1, 1], devices=["cpu", "cpu"])
|
||||
|
||||
|
||||
def test_forward_lockstep():
|
||||
timeline = []
|
||||
|
||||
class DelayedLog(nn.Module):
|
||||
def __init__(self, j, seconds):
|
||||
super().__init__()
|
||||
self.i = 0
|
||||
self.j = j
|
||||
self.seconds = seconds
|
||||
|
||||
def forward(self, x):
|
||||
time.sleep(self.seconds)
|
||||
|
||||
timeline.append((self.i, self.j))
|
||||
self.i += 1
|
||||
|
||||
return x
|
||||
|
||||
model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1))
|
||||
model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=3)
|
||||
model(torch.rand(3, 1))
|
||||
|
||||
# Expected timeline: (Logs are recorded at !)
|
||||
#
|
||||
# Partition #0: 0! 1! 2!
|
||||
# Partition #1: 000! 111! 222!
|
||||
#
|
||||
assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]
|
29
test/distributed/_pipeline/sync/test_pipeline.py
Normal file
29
test/distributed/_pipeline/sync/test_pipeline.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from torch.distributed._pipeline.sync.pipeline import clock_cycles
|
||||
|
||||
|
||||
def test_clock_cycles():
|
||||
assert list(clock_cycles(1, 1)) == [[(0, 0)]]
|
||||
assert list(clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]]
|
||||
assert list(clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]]
|
||||
|
||||
assert list(clock_cycles(3, 3)) == [ # noqa
|
||||
[(0, 0)],
|
||||
[(1, 0), (0, 1)],
|
||||
[(2, 0), (1, 1), (0, 2)],
|
||||
[(2, 1), (1, 2)],
|
||||
[(2, 2)],
|
||||
]
|
||||
|
||||
assert list(clock_cycles(4, 2)) == [ # noqa
|
||||
[(0, 0)],
|
||||
[(1, 0), (0, 1)],
|
||||
[(2, 0), (1, 1)],
|
||||
[(3, 0), (2, 1)],
|
||||
[(3, 1)],
|
||||
]
|
188
test/distributed/_pipeline/sync/test_stream.py
Normal file
188
test/distributed/_pipeline/sync/test_stream.py
Normal file
@ -0,0 +1,188 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from torch.distributed._pipeline.sync.stream import (
|
||||
CPUStream,
|
||||
current_stream,
|
||||
default_stream,
|
||||
get_device,
|
||||
is_cuda,
|
||||
new_stream,
|
||||
record_stream,
|
||||
use_device,
|
||||
use_stream,
|
||||
wait_stream,
|
||||
)
|
||||
|
||||
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
|
||||
|
||||
|
||||
class TestNewStream:
|
||||
def test_new_stream_cpu(self):
|
||||
stream = new_stream(torch.device("cpu"))
|
||||
assert stream is CPUStream
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_new_stream_cuda(self):
|
||||
stream = new_stream(torch.device("cuda"))
|
||||
assert isinstance(stream, torch.cuda.Stream)
|
||||
assert stream != torch.cuda.default_stream()
|
||||
|
||||
|
||||
class TestCurrentStream:
|
||||
def test_current_stream_cpu(self):
|
||||
stream = current_stream(torch.device("cpu"))
|
||||
assert stream is CPUStream
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_current_stream_cuda(self):
|
||||
stream = current_stream(torch.device("cuda"))
|
||||
assert isinstance(stream, torch.cuda.Stream)
|
||||
assert stream == torch.cuda.current_stream()
|
||||
|
||||
|
||||
class TestDefaultStream:
|
||||
def test_default_stream_cpu(self):
|
||||
stream = default_stream(torch.device("cpu"))
|
||||
assert stream is CPUStream
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_default_stream_cuda(self):
|
||||
stream = default_stream(torch.device("cuda"))
|
||||
assert isinstance(stream, torch.cuda.Stream)
|
||||
assert stream == torch.cuda.default_stream()
|
||||
|
||||
|
||||
class TestUseDevice:
|
||||
def test_use_device_cpu(self):
|
||||
with use_device(torch.device("cpu")):
|
||||
pass
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_use_device_cuda(self):
|
||||
with use_device(torch.device("cuda")):
|
||||
pass
|
||||
|
||||
|
||||
class TestUseStream:
|
||||
def test_use_stream_cpu(self):
|
||||
with use_stream(CPUStream):
|
||||
pass
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_use_stream_cuda(self):
|
||||
stream = new_stream(torch.device("cuda"))
|
||||
with use_stream(stream):
|
||||
assert current_stream(torch.device("cuda")) == stream
|
||||
|
||||
|
||||
class TestGetDevice:
|
||||
def test_get_device_cpu(self):
|
||||
assert get_device(CPUStream).type == "cpu"
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_get_device_cuda(self):
|
||||
stream = current_stream(torch.device("cuda"))
|
||||
assert get_device(stream).type == "cuda"
|
||||
|
||||
|
||||
class TestWaitStream:
|
||||
def _test_wait_stream(self, source, target, cuda_sleep=None):
|
||||
with use_stream(target):
|
||||
if is_cuda(target):
|
||||
cuda_sleep(0.5)
|
||||
x = torch.ones(100, 100, device=get_device(target))
|
||||
|
||||
wait_stream(source, target)
|
||||
|
||||
with use_stream(source):
|
||||
assert x.sum().item() == 10000
|
||||
|
||||
def test_wait_stream_cpu_cpu(self):
|
||||
source = CPUStream
|
||||
target = CPUStream
|
||||
self._test_wait_stream(source, target)
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_wait_stream_cpu_cuda(self, cuda_sleep):
|
||||
source = CPUStream
|
||||
target = new_stream(torch.device("cuda"))
|
||||
self._test_wait_stream(source, target, cuda_sleep)
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_wait_stream_cuda_cpu(self, cuda_sleep):
|
||||
source = new_stream(torch.device("cuda"))
|
||||
target = CPUStream
|
||||
self._test_wait_stream(source, target, cuda_sleep)
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_wait_stream_cuda_cuda(self, cuda_sleep):
|
||||
source = current_stream(torch.device("cuda"))
|
||||
target = new_stream(torch.device("cuda"))
|
||||
self._test_wait_stream(source, target, cuda_sleep)
|
||||
|
||||
|
||||
class TestRecordStream:
|
||||
def test_record_stream_cpu(self):
|
||||
# It should silently ignore CPU tensors.
|
||||
x = torch.rand(1, device=torch.device("cpu"))
|
||||
record_stream(x, CPUStream)
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_record_stream_cuda(self, cuda_sleep):
|
||||
# This test detects unexpected block reallocation. For reliable test,
|
||||
# the stream to allocate tensors is isolated. The allocator will not
|
||||
# reuse free blocks which were allocated from another stream.
|
||||
stream_alloc = new_stream(torch.device("cuda"))
|
||||
with torch.cuda.stream(stream_alloc):
|
||||
x = torch.rand(1, device=torch.device("cuda"))
|
||||
|
||||
stream = new_stream(torch.device("cuda"))
|
||||
record_stream(x, stream)
|
||||
with use_stream(stream):
|
||||
cuda_sleep(0.5)
|
||||
|
||||
# 'x' is deleted at Python's perspective. But the block of 'x' is still
|
||||
# required for 'stream'. 'y' shouldn't be allocated to the block.
|
||||
data_ptr = x.data_ptr()
|
||||
del x
|
||||
stream_alloc.synchronize()
|
||||
with torch.cuda.stream(stream_alloc):
|
||||
y = torch.rand(1, device=torch.device("cuda"))
|
||||
assert y.data_ptr() != data_ptr
|
||||
|
||||
# Pause Python until 'stream' finishes tasks queued. Now the block of
|
||||
# 'x' is free to be reallocated.
|
||||
wait_stream(CPUStream, stream)
|
||||
with torch.cuda.stream(stream_alloc):
|
||||
z = torch.rand(1, device=torch.device("cuda"))
|
||||
assert z.data_ptr() == data_ptr
|
||||
|
||||
@skip_if_no_cuda
|
||||
def test_record_stream_shifted_view(self, cuda_sleep):
|
||||
# Issue: https://github.com/pytorch/pytorch/issues/27366
|
||||
stream_alloc = new_stream(torch.device("cuda"))
|
||||
with torch.cuda.stream(stream_alloc):
|
||||
x = torch.rand(2, device=torch.device("cuda"))
|
||||
|
||||
y = x[1:]
|
||||
assert y.data_ptr() > x.data_ptr()
|
||||
|
||||
stream = new_stream(torch.device("cuda"))
|
||||
with use_stream(stream):
|
||||
cuda_sleep(0.5)
|
||||
record_stream(y, stream)
|
||||
|
||||
data_ptr = x.data_ptr()
|
||||
del x, y
|
||||
|
||||
stream_alloc.synchronize()
|
||||
with torch.cuda.stream(stream_alloc):
|
||||
z = torch.rand(2, device=torch.device("cuda"))
|
||||
assert z.data_ptr() != data_ptr
|
43
test/distributed/_pipeline/sync/test_transparency.py
Normal file
43
test/distributed/_pipeline/sync/test_transparency.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from torch.distributed._pipeline.sync import Pipe
|
||||
|
||||
|
||||
def test_simple_linears():
|
||||
def sum_grad(parameters):
|
||||
return sum([p.grad.sum() for p in parameters if p.grad is not None])
|
||||
|
||||
def zero_grad(parameters):
|
||||
for p in parameters:
|
||||
p.grad = None
|
||||
|
||||
inputs = torch.rand(8, 1)
|
||||
model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)
|
||||
|
||||
# Without Pipe
|
||||
outputs = model(inputs)
|
||||
loss = outputs.mean()
|
||||
loss.backward()
|
||||
|
||||
grad_without_pipe = sum_grad(model.parameters())
|
||||
|
||||
zero_grad(model.parameters())
|
||||
|
||||
# With Pipe
|
||||
model = Pipe(model, [2, 2], devices=["cpu", "cpu"], chunks=4)
|
||||
|
||||
outputs = model(inputs)
|
||||
loss = outputs.mean()
|
||||
loss.backward()
|
||||
|
||||
grad_with_pipe = sum_grad(model.parameters())
|
||||
|
||||
# Both grads should be identical.
|
||||
assert torch.allclose(grad_with_pipe, grad_without_pipe)
|
163
test/distributed/_pipeline/sync/test_worker.py
Normal file
163
test/distributed/_pipeline/sync/test_worker.py
Normal file
@ -0,0 +1,163 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from torch.distributed._pipeline.sync.microbatch import Batch
|
||||
from torch.distributed._pipeline.sync.stream import CPUStream
|
||||
from torch.distributed._pipeline.sync.worker import Task, spawn_workers
|
||||
|
||||
|
||||
class fake_device:
|
||||
"""A test double for :class:`torch.device`. Every fake device is different
|
||||
with each other.
|
||||
"""
|
||||
|
||||
type = "fake"
|
||||
index = None
|
||||
|
||||
|
||||
def test_join_running_workers():
|
||||
count = 0
|
||||
|
||||
def counter():
|
||||
nonlocal count
|
||||
time.sleep(0.1)
|
||||
count += 1
|
||||
return Batch(())
|
||||
|
||||
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
|
||||
|
||||
def call_in_worker(i, f):
|
||||
task = Task(CPUStream, compute=f, finalize=None)
|
||||
in_queues[i].put(task)
|
||||
|
||||
for i in range(10):
|
||||
call_in_worker(i, counter)
|
||||
|
||||
# There's no nondeterminism because 'spawn_workers' joins all running
|
||||
# workers.
|
||||
assert count == 10
|
||||
|
||||
|
||||
def test_join_running_workers_with_exception():
|
||||
class ExpectedException(Exception):
|
||||
pass
|
||||
|
||||
count = 0
|
||||
|
||||
def counter():
|
||||
nonlocal count
|
||||
time.sleep(0.1)
|
||||
count += 1
|
||||
return Batch(())
|
||||
|
||||
with pytest.raises(ExpectedException):
|
||||
with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues):
|
||||
|
||||
def call_in_worker(i, f):
|
||||
task = Task(CPUStream, compute=f, finalize=None)
|
||||
in_queues[i].put(task)
|
||||
|
||||
for i in range(10):
|
||||
call_in_worker(i, counter)
|
||||
|
||||
raise ExpectedException
|
||||
|
||||
# There's no nondeterminism because only 1 task can be placed in input
|
||||
# queues.
|
||||
assert count == 10
|
||||
|
||||
|
||||
def test_compute_multithreading():
|
||||
"""Task.compute should be executed on multiple threads."""
|
||||
thread_ids = set()
|
||||
|
||||
def log_thread_id():
|
||||
thread_id = threading.current_thread().ident
|
||||
thread_ids.add(thread_id)
|
||||
return Batch(())
|
||||
|
||||
with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues):
|
||||
for i in range(2):
|
||||
t = Task(CPUStream, compute=log_thread_id, finalize=None)
|
||||
in_queues[i].put(t)
|
||||
for i in range(2):
|
||||
out_queues[i].get()
|
||||
|
||||
assert len(thread_ids) == 2
|
||||
|
||||
|
||||
def test_compute_success():
|
||||
"""Task.compute returns (True, (task, batch)) on success."""
|
||||
|
||||
def _42():
|
||||
return Batch(torch.tensor(42))
|
||||
|
||||
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
|
||||
t = Task(CPUStream, compute=_42, finalize=None)
|
||||
in_queues[0].put(t)
|
||||
ok, (task, batch) = out_queues[0].get()
|
||||
|
||||
assert ok
|
||||
assert task is t
|
||||
assert isinstance(batch, Batch)
|
||||
assert batch[0].item() == 42
|
||||
|
||||
|
||||
def test_compute_exception():
|
||||
"""Task.compute returns (False, exc_info) on failure."""
|
||||
|
||||
def zero_div():
|
||||
0 / 0
|
||||
|
||||
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
|
||||
t = Task(CPUStream, compute=zero_div, finalize=None)
|
||||
in_queues[0].put(t)
|
||||
ok, exc_info = out_queues[0].get()
|
||||
|
||||
assert not ok
|
||||
assert isinstance(exc_info, tuple)
|
||||
assert issubclass(exc_info[0], ZeroDivisionError)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("grad_mode", [True, False])
|
||||
def test_grad_mode(grad_mode):
|
||||
def detect_grad_enabled():
|
||||
x = torch.rand(1, requires_grad=torch.is_grad_enabled())
|
||||
return Batch(x)
|
||||
|
||||
with torch.set_grad_enabled(grad_mode):
|
||||
with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues):
|
||||
task = Task(CPUStream, compute=detect_grad_enabled, finalize=None)
|
||||
in_queues[0].put(task)
|
||||
|
||||
ok, (_, batch) = out_queues[0].get()
|
||||
|
||||
assert ok
|
||||
assert batch[0].requires_grad == grad_mode
|
||||
|
||||
|
||||
def test_worker_per_device():
|
||||
cpu = torch.device("cpu")
|
||||
cpu0 = torch.device("cpu", index=0)
|
||||
fake1 = fake_device()
|
||||
fake2 = fake_device()
|
||||
|
||||
with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues):
|
||||
assert len(in_queues) == len(out_queues) == 5
|
||||
|
||||
# 0: cpu, 1: cpu, 2: cpu0
|
||||
assert in_queues[0] is in_queues[1] is in_queues[2]
|
||||
assert out_queues[0] is out_queues[1] is out_queues[2]
|
||||
|
||||
# 3: fake1, 4: fake2
|
||||
assert in_queues[3] is not in_queues[4]
|
||||
assert out_queues[3] is not out_queues[4]
|
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import importlib
|
||||
import modulefinder
|
||||
@ -95,6 +96,54 @@ TESTS = [
|
||||
'test_fx_experimental',
|
||||
'test_functional_autograd_benchmark',
|
||||
'test_package',
|
||||
'distributed/_pipeline/sync/skip/test_api',
|
||||
'distributed/_pipeline/sync/skip/test_gpipe',
|
||||
'distributed/_pipeline/sync/skip/test_inspect_skip_layout',
|
||||
'distributed/_pipeline/sync/skip/test_leak',
|
||||
'distributed/_pipeline/sync/skip/test_portal',
|
||||
'distributed/_pipeline/sync/skip/test_stash_pop',
|
||||
'distributed/_pipeline/sync/skip/test_tracker',
|
||||
'distributed/_pipeline/sync/skip/test_verify_skippables',
|
||||
'distributed/_pipeline/sync/test_balance',
|
||||
'distributed/_pipeline/sync/test_bugs',
|
||||
'distributed/_pipeline/sync/test_checkpoint',
|
||||
'distributed/_pipeline/sync/test_copy',
|
||||
'distributed/_pipeline/sync/test_deferred_batch_norm',
|
||||
'distributed/_pipeline/sync/test_dependency',
|
||||
'distributed/_pipeline/sync/test_inplace',
|
||||
'distributed/_pipeline/sync/test_microbatch',
|
||||
'distributed/_pipeline/sync/test_phony',
|
||||
'distributed/_pipeline/sync/test_pipe',
|
||||
'distributed/_pipeline/sync/test_pipeline',
|
||||
'distributed/_pipeline/sync/test_stream',
|
||||
'distributed/_pipeline/sync/test_transparency',
|
||||
'distributed/_pipeline/sync/test_worker',
|
||||
]
|
||||
|
||||
# Tests need to be run with pytest.
|
||||
USE_PYTEST_LIST = [
|
||||
'distributed/_pipeline/sync/skip/test_api',
|
||||
'distributed/_pipeline/sync/skip/test_gpipe',
|
||||
'distributed/_pipeline/sync/skip/test_inspect_skip_layout',
|
||||
'distributed/_pipeline/sync/skip/test_leak',
|
||||
'distributed/_pipeline/sync/skip/test_portal',
|
||||
'distributed/_pipeline/sync/skip/test_stash_pop',
|
||||
'distributed/_pipeline/sync/skip/test_tracker',
|
||||
'distributed/_pipeline/sync/skip/test_verify_skippables',
|
||||
'distributed/_pipeline/sync/test_balance',
|
||||
'distributed/_pipeline/sync/test_bugs',
|
||||
'distributed/_pipeline/sync/test_checkpoint',
|
||||
'distributed/_pipeline/sync/test_copy',
|
||||
'distributed/_pipeline/sync/test_deferred_batch_norm',
|
||||
'distributed/_pipeline/sync/test_dependency',
|
||||
'distributed/_pipeline/sync/test_inplace',
|
||||
'distributed/_pipeline/sync/test_microbatch',
|
||||
'distributed/_pipeline/sync/test_phony',
|
||||
'distributed/_pipeline/sync/test_pipe',
|
||||
'distributed/_pipeline/sync/test_pipeline',
|
||||
'distributed/_pipeline/sync/test_stream',
|
||||
'distributed/_pipeline/sync/test_transparency',
|
||||
'distributed/_pipeline/sync/test_worker',
|
||||
]
|
||||
|
||||
WINDOWS_BLOCKLIST = [
|
||||
@ -170,6 +219,28 @@ SLOW_TESTS = [
|
||||
'test_quantization',
|
||||
'test_determination',
|
||||
'test_futures',
|
||||
'distributed/_pipeline/sync/skip/test_api',
|
||||
'distributed/_pipeline/sync/skip/test_gpipe',
|
||||
'distributed/_pipeline/sync/skip/test_inspect_skip_layout',
|
||||
'distributed/_pipeline/sync/skip/test_leak',
|
||||
'distributed/_pipeline/sync/skip/test_portal',
|
||||
'distributed/_pipeline/sync/skip/test_stash_pop',
|
||||
'distributed/_pipeline/sync/skip/test_tracker',
|
||||
'distributed/_pipeline/sync/skip/test_verify_skippables',
|
||||
'distributed/_pipeline/sync/test_balance',
|
||||
'distributed/_pipeline/sync/test_bugs',
|
||||
'distributed/_pipeline/sync/test_checkpoint',
|
||||
'distributed/_pipeline/sync/test_copy',
|
||||
'distributed/_pipeline/sync/test_deferred_batch_norm',
|
||||
'distributed/_pipeline/sync/test_dependency',
|
||||
'distributed/_pipeline/sync/test_inplace',
|
||||
'distributed/_pipeline/sync/test_microbatch',
|
||||
'distributed/_pipeline/sync/test_phony',
|
||||
'distributed/_pipeline/sync/test_pipe',
|
||||
'distributed/_pipeline/sync/test_pipeline',
|
||||
'distributed/_pipeline/sync/test_stream',
|
||||
'distributed/_pipeline/sync/test_transparency',
|
||||
'distributed/_pipeline/sync/test_worker',
|
||||
]
|
||||
_DEP_MODULES_CACHE: Dict[str, set] = {}
|
||||
|
||||
@ -762,12 +833,15 @@ def main():
|
||||
failure_messages = []
|
||||
try:
|
||||
for test in selected_tests:
|
||||
err_message = run_test_module(test, test_directory, options)
|
||||
options_clone = copy.deepcopy(options)
|
||||
if test in USE_PYTEST_LIST:
|
||||
options_clone.pytest = True
|
||||
err_message = run_test_module(test, test_directory, options_clone)
|
||||
if err_message is None:
|
||||
continue
|
||||
has_failed = True
|
||||
failure_messages.append(err_message)
|
||||
if not options.continue_through_error:
|
||||
if not options_clone.continue_through_error:
|
||||
raise RuntimeError(err_message)
|
||||
print_to_stderr(err_message)
|
||||
finally:
|
||||
|
0
torch/distributed/_pipeline/__init__.py
Normal file
0
torch/distributed/_pipeline/__init__.py
Normal file
27
torch/distributed/_pipeline/sync/LICENSE
Normal file
27
torch/distributed/_pipeline/sync/LICENSE
Normal file
@ -0,0 +1,27 @@
|
||||
Copyright 2019-2020 Kakao Brain
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from this
|
||||
software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
11
torch/distributed/_pipeline/sync/__init__.py
Normal file
11
torch/distributed/_pipeline/sync/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""A Pipe implementation in PyTorch."""
|
||||
from .checkpoint import is_checkpointing, is_recomputing
|
||||
from .pipe import Pipe
|
||||
|
||||
__all__ = ["Pipe", "is_checkpointing", "is_recomputing"]
|
164
torch/distributed/_pipeline/sync/balance/__init__.py
Normal file
164
torch/distributed/_pipeline/sync/balance/__init__.py
Normal file
@ -0,0 +1,164 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""A helper to roughly balance a sequential module.
|
||||
|
||||
Usage::
|
||||
|
||||
import torch
|
||||
from torch.distributed._pipeline.sync import Pipe
|
||||
from torch.distributed._pipeline.sync.balance import balance_by_time
|
||||
|
||||
sample = torch.empty(128, 3, 224, 224)
|
||||
balance = balance_by_time(torch.cuda.device_count(), model, sample)
|
||||
|
||||
pipe = Pipe(model, balance, chunks=8)
|
||||
|
||||
"""
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
from . import blockpartition
|
||||
from .profile import profile_sizes, profile_times
|
||||
|
||||
__all__ = ["balance_by_time", "balance_by_size"]
|
||||
|
||||
|
||||
Device = Union[torch.device, int, str]
|
||||
|
||||
Tensors = Tuple[Tensor, ...]
|
||||
TensorOrTensors = Union[Tensor, Tensors]
|
||||
|
||||
|
||||
def balance_cost(cost: List[int], partitions: int) -> List[int]:
|
||||
partitioned = blockpartition.solve(cost, partitions)
|
||||
return [len(p) for p in partitioned]
|
||||
|
||||
|
||||
def balance_by_time(
|
||||
partitions: int,
|
||||
module: nn.Sequential,
|
||||
sample: TensorOrTensors,
|
||||
*,
|
||||
timeout: float = 1.0,
|
||||
device: Device = torch.device("cuda"),
|
||||
) -> List[int]:
|
||||
"""Naive automatic balancing by elapsed time per layer.
|
||||
::
|
||||
|
||||
sample = torch.empty(128, 3, 224, 224)
|
||||
balance = balance_by_time(torch.cuda.device_count(), model, sample)
|
||||
pipe = Pipe(model, balance, chunks=8)
|
||||
|
||||
Args:
|
||||
partitions (int):
|
||||
intended number of partitions
|
||||
module (torch.nn.Sequential):
|
||||
sequential module to be partitioned
|
||||
sample (torch.Tensor):
|
||||
example input with arbitrary batch size
|
||||
|
||||
Keyword Args:
|
||||
timeout (float):
|
||||
profiling iterates again if the timeout (in second) is not exceeded
|
||||
(default: ``1.0``)
|
||||
device ('cpu' or 'cuda' device):
|
||||
CPU or CUDA device where each layer is profiled (default: the
|
||||
current CUDA device)
|
||||
|
||||
Returns:
|
||||
A list of number of layers in each partition. Use it for the `balance`
|
||||
parameter of :class:`~torchpipe.Pipe`.
|
||||
|
||||
.. note::
|
||||
`module` and `sample` must be placed on the same device.
|
||||
|
||||
"""
|
||||
times = profile_times(module, sample, timeout, torch.device(device))
|
||||
return balance_cost(times, partitions)
|
||||
|
||||
|
||||
def balance_by_size(
|
||||
partitions: int,
|
||||
module: nn.Sequential,
|
||||
input: TensorOrTensors,
|
||||
*,
|
||||
chunks: int = 1,
|
||||
param_scale: float = 2.0,
|
||||
device: Device = torch.device("cuda"),
|
||||
) -> List[int]:
|
||||
"""Naive automatic balancing by CUDA memory usage per layer.
|
||||
|
||||
During training, required memory for parameters depends on which optimizer
|
||||
is used. Optimizers may use buffers for each parameter to track
|
||||
optimization statistics internally, such as momentum buffer in SGD.
|
||||
|
||||
To get more reliable size based balance, you should specify `param_scale`
|
||||
with regard to your optimizer. The default `param_scale` is 2 instead of 1
|
||||
due to gradient accumulation which is necessary for every optimizer.
|
||||
|
||||
Follow this guide to choose correct `param_scale` for typical optimizers:
|
||||
|
||||
========= ============= =========================================
|
||||
Optimizer `param_scale` Internal State
|
||||
========= ============= =========================================
|
||||
SGD 2--3 (momentum_buffer)
|
||||
Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq)
|
||||
Adadelta 4 square_avg, acc_delta
|
||||
Adagrad 3 sum
|
||||
RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg)
|
||||
========= ============= =========================================
|
||||
|
||||
Here's a simple example with the Adam optimizer::
|
||||
|
||||
balance = balance_by_size(
|
||||
torch.cuda.device_count(),
|
||||
model,
|
||||
|
||||
# Same size with mini-batch to train
|
||||
torch.empty(1024, 3, 224, 224),
|
||||
|
||||
# Number of micro-batches to train with Pipe
|
||||
chunks=8,
|
||||
|
||||
# 4 for Adam
|
||||
param_scale=4.0,
|
||||
)
|
||||
|
||||
pipe = Pipe(model, balance, chunks=8)
|
||||
adam = Adam(pipe.parameters())
|
||||
|
||||
Args:
|
||||
partitions (int):
|
||||
intended number of partitions
|
||||
module (torch.nn.Sequential):
|
||||
sequential module to be partitioned
|
||||
input (torch.Tensor):
|
||||
example mini-batch with the same size to train
|
||||
|
||||
Keyword Args:
|
||||
chunks (int):
|
||||
number of micro-batches will be used to train (default: ``1``)
|
||||
param_scale (float):
|
||||
how many copies of parameters would be allocated for training. It
|
||||
depends on optimizer. See the above guide. (default: ``2.0``)
|
||||
device ('cuda' device):
|
||||
CUDA device where each layer is profiled (default: the current CUDA
|
||||
device)
|
||||
|
||||
Returns:
|
||||
A list of number of layers in each partition. Use it for the `balance`
|
||||
parameter of :class:`~torchpipe.Pipe`.
|
||||
|
||||
.. note::
|
||||
`module` and `input` must be placed on the same CUDA device.
|
||||
|
||||
"""
|
||||
sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device))
|
||||
return balance_cost(sizes, partitions)
|
95
torch/distributed/_pipeline/sync/balance/blockpartition.py
Normal file
95
torch/distributed/_pipeline/sync/balance/blockpartition.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Implements "Block Partitions of Sequences" by Imre Bárány et al.
|
||||
|
||||
Paper: https://arxiv.org/pdf/1308.2452.pdf
|
||||
|
||||
"""
|
||||
from typing import Iterator, List, Tuple
|
||||
|
||||
__all__ = ["solve"]
|
||||
|
||||
|
||||
def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]:
|
||||
"""Splits a sequence into several partitions to minimize variance for each
|
||||
partition.
|
||||
|
||||
The result might not be optimal. However, it can be done only in O(kn³),
|
||||
where k is the number of partitions and n is the length of the sequence.
|
||||
|
||||
"""
|
||||
if partitions < 1:
|
||||
raise ValueError(f"partitions must be a positive integer ({partitions} < 1)")
|
||||
|
||||
n = len(sequence)
|
||||
if n < partitions:
|
||||
raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})")
|
||||
|
||||
# Normalize the sequence in [0, 1].
|
||||
minimum = min(sequence)
|
||||
maximum = max(sequence) - minimum
|
||||
|
||||
normal_sequence: List[float]
|
||||
if maximum == 0:
|
||||
normal_sequence = [0 for _ in sequence]
|
||||
else:
|
||||
normal_sequence = [(x - minimum) / maximum for x in sequence]
|
||||
|
||||
splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n]
|
||||
|
||||
def block_size(i: int) -> float:
|
||||
start = splits[i - 1] if i > 0 else 0
|
||||
stop = splits[i]
|
||||
return sum(normal_sequence[start:stop])
|
||||
|
||||
def leaderboard() -> Iterator[Tuple[float, int]]:
|
||||
return ((block_size(i), i) for i in range(partitions))
|
||||
|
||||
while True:
|
||||
"""
|
||||
(1) Fix p ∈ [k] with M(P) = bp. So Bp is a maximal block of P.
|
||||
"""
|
||||
# max_size: M(P)
|
||||
max_size, p = max(leaderboard())
|
||||
|
||||
while True:
|
||||
"""
|
||||
(2) If M(P) ≤ m(P) + 1, then stop.
|
||||
"""
|
||||
# min_size: m(P)
|
||||
min_size, q = min(leaderboard())
|
||||
|
||||
if max_size <= min_size + 1:
|
||||
return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)]
|
||||
|
||||
"""
|
||||
(3) If M(P) > m(P) + 1, then let m(P) = bq for the q ∈ [k] which is
|
||||
closest to p (ties broken arbitrarily). Thus Bq is a minimal block
|
||||
of P. Let Bh be the block next to Bq between Bp and Bq. (Note that
|
||||
Bh is a non-empty block: if it were, then m(P) = 0 and we should
|
||||
have chosen Bh instead of Bq.)
|
||||
"""
|
||||
if p < q:
|
||||
"""
|
||||
So either p < q and then h = q−1 and we define P ∗ by moving
|
||||
the last element from Bh = Bq−1 to Bq,
|
||||
"""
|
||||
h = q - 1
|
||||
splits[h] -= 1
|
||||
else:
|
||||
"""
|
||||
or q < p, and then h = q + 1 and P ∗ is obtained by moving the
|
||||
first element of Bh = Bq+1 to Bq.
|
||||
"""
|
||||
h = q + 1
|
||||
splits[q] += 1
|
||||
|
||||
"""
|
||||
Set P = P ∗ . If p = h, then go to (1), else go to (2).
|
||||
"""
|
||||
if p == h:
|
||||
break
|
114
torch/distributed/_pipeline/sync/balance/profile.py
Normal file
114
torch/distributed/_pipeline/sync/balance/profile.py
Normal file
@ -0,0 +1,114 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Per-layer profilers."""
|
||||
import copy
|
||||
import time
|
||||
from typing import Generator, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
from ..microbatch import Batch
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
Device = Union[torch.device, int, str]
|
||||
|
||||
Tensors = Tuple[Tensor, ...]
|
||||
TensorOrTensors = Union[Tensor, Tensors]
|
||||
|
||||
|
||||
def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]:
|
||||
"""Copies layers for ease to profile. It doesn't modify the given
|
||||
module.
|
||||
"""
|
||||
for layer in module:
|
||||
layer_copy = copy.deepcopy(layer)
|
||||
layer_copy.to(device)
|
||||
layer_copy.train()
|
||||
yield layer_copy
|
||||
|
||||
|
||||
def detach(batch: Batch) -> None:
|
||||
"""Detaches from autograd graph."""
|
||||
for i, x in enumerate(batch):
|
||||
batch[i] = x.detach().requires_grad_(x.requires_grad)
|
||||
|
||||
|
||||
def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]:
|
||||
"""Profiles elapsed times per layer."""
|
||||
if any(p.grad is not None for p in module.parameters()):
|
||||
raise ValueError("some parameter already has gradient")
|
||||
|
||||
_batch = Batch(sample)
|
||||
for i, x in enumerate(_batch):
|
||||
_batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)
|
||||
|
||||
time_bufs: List[List[float]] = [[] for _ in module]
|
||||
begun_at = time.time()
|
||||
|
||||
while time.time() - begun_at < timeout:
|
||||
batch = _batch
|
||||
|
||||
for i, layer in enumerate(layerwise_sandbox(module, device)):
|
||||
detach(batch)
|
||||
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize(device)
|
||||
tick = time.time()
|
||||
|
||||
# Forward
|
||||
batch = batch.call(layer)
|
||||
|
||||
# Backward
|
||||
backward_tensors = tuple(y for y in batch if y.requires_grad)
|
||||
if backward_tensors:
|
||||
torch.autograd.backward(backward_tensors, backward_tensors)
|
||||
|
||||
if device.type == "cuda":
|
||||
torch.cuda.synchronize(device)
|
||||
tock = time.time()
|
||||
|
||||
time_bufs[i].append(tock - tick)
|
||||
|
||||
us = 1_000_000
|
||||
return [sum(int(t * us) for t in buf) for buf in time_bufs]
|
||||
|
||||
|
||||
def profile_sizes(
|
||||
module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device,
|
||||
) -> List[int]:
|
||||
"""Profiles CUDA memory usage per layer."""
|
||||
if device.type != "cuda":
|
||||
raise ValueError("size profiler supports only CUDA device")
|
||||
|
||||
batch = Batch(input)
|
||||
sizes: List[int] = []
|
||||
|
||||
latent_scale = batch[0].size(0) / chunks
|
||||
for i, x in enumerate(batch):
|
||||
batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad)
|
||||
|
||||
for layer in layerwise_sandbox(module, device):
|
||||
detach(batch)
|
||||
|
||||
# Detect memory usage at forward.
|
||||
memory_before = torch.cuda.memory_allocated(device)
|
||||
batch = batch.call(layer)
|
||||
memory_after = torch.cuda.memory_allocated(device)
|
||||
latent_size = memory_after - memory_before
|
||||
|
||||
# Analyze size of parameters.
|
||||
param_size = sum(p.storage().size() * p.storage().element_size() for p in layer.parameters())
|
||||
|
||||
# Combine size of parameters and activations with normalize scales.
|
||||
size = latent_size * latent_scale + param_size * param_scale
|
||||
sizes.append(int(size))
|
||||
|
||||
return sizes
|
6
torch/distributed/_pipeline/sync/balance/py.typed
Normal file
6
torch/distributed/_pipeline/sync/balance/py.typed
Normal file
@ -0,0 +1,6 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
159
torch/distributed/_pipeline/sync/batchnorm.py
Normal file
159
torch/distributed/_pipeline/sync/batchnorm.py
Normal file
@ -0,0 +1,159 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Tracks the running statistics per mini-batch instead of micro-batch."""
|
||||
from typing import Optional, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from .checkpoint import is_recomputing
|
||||
|
||||
__all__ = ["DeferredBatchNorm"]
|
||||
|
||||
|
||||
TModule = TypeVar("TModule", bound=nn.Module)
|
||||
|
||||
|
||||
class DeferredBatchNorm(_BatchNorm):
|
||||
"""A BatchNorm layer tracks multiple micro-batches to update running
|
||||
statistics per mini-batch.
|
||||
"""
|
||||
|
||||
sum: Tensor
|
||||
sum_squares: Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
eps: float = 1e-5,
|
||||
momentum: Optional[float] = 0.1,
|
||||
affine: bool = True,
|
||||
chunks: int = 1,
|
||||
) -> None:
|
||||
super().__init__(num_features, eps, momentum, affine, track_running_stats=True)
|
||||
|
||||
self.register_buffer("sum", torch.zeros_like(self.running_mean))
|
||||
self.register_buffer("sum_squares", torch.zeros_like(self.running_var))
|
||||
|
||||
self.counter = 0
|
||||
self.tracked = 0
|
||||
self.chunks = chunks
|
||||
|
||||
def _check_input_dim(self, input: Tensor) -> None:
|
||||
# It's the typical _check_input_dim() implementation in PyTorch.
|
||||
if input.dim() <= 2:
|
||||
raise ValueError("expected at least 3D input (got %dD input)" % input.dim())
|
||||
|
||||
def _track(self, input: Tensor) -> bool:
|
||||
"""Tracks statistics of a micro-batch."""
|
||||
# Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
|
||||
dim = [0]
|
||||
dim.extend(range(2, input.dim()))
|
||||
|
||||
with torch.no_grad():
|
||||
self.sum += input.sum(dim)
|
||||
self.sum_squares += (input ** 2).sum(dim)
|
||||
|
||||
size = input.size().numel() // input.size(1)
|
||||
self.counter += size
|
||||
self.tracked += 1
|
||||
|
||||
return self.tracked == self.chunks
|
||||
|
||||
def _commit(self) -> None:
|
||||
"""Updates the running statistics of a mini-batch."""
|
||||
exponential_average_factor = 0.0
|
||||
self.num_batches_tracked += 1
|
||||
if self.momentum is None: # use cumulative moving average
|
||||
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
|
||||
else: # use exponential moving average
|
||||
exponential_average_factor = self.momentum
|
||||
|
||||
mean = self.sum / self.counter
|
||||
var = self.sum_squares / self.counter - mean ** 2
|
||||
|
||||
# Calculate the exponential moving average here.
|
||||
m = exponential_average_factor
|
||||
|
||||
self.running_mean *= 1 - m
|
||||
self.running_mean += mean * m
|
||||
|
||||
self.running_var *= 1 - m
|
||||
self.running_var += var * m
|
||||
|
||||
self.sum.zero_()
|
||||
self.sum_squares.zero_()
|
||||
self.counter = 0
|
||||
self.tracked = 0
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor: # type: ignore
|
||||
if not self.training:
|
||||
# Don't train parameters on the evaluation mode.
|
||||
return F.batch_norm(
|
||||
input,
|
||||
running_mean=self.running_mean,
|
||||
running_var=self.running_var,
|
||||
weight=self.weight,
|
||||
bias=self.bias,
|
||||
training=False,
|
||||
momentum=0.0,
|
||||
eps=self.eps,
|
||||
)
|
||||
|
||||
if not is_recomputing():
|
||||
# Track a micro-batch on the training mode
|
||||
# but not under a recomputation.
|
||||
tracked_enough = self._track(input)
|
||||
|
||||
# Update the running statistics for a mini-batch
|
||||
# if it has tracked enough micro-batches.
|
||||
if tracked_enough:
|
||||
self._commit()
|
||||
|
||||
# Normalize a micro-batch and train the parameters.
|
||||
return F.batch_norm(
|
||||
input,
|
||||
running_mean=None,
|
||||
running_var=None,
|
||||
weight=self.weight,
|
||||
bias=self.bias,
|
||||
training=True,
|
||||
momentum=0.0,
|
||||
eps=self.eps,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
|
||||
"""Converts a :class:`nn.BatchNorm` or underlying
|
||||
:class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`::
|
||||
|
||||
from torchvision.models.resnet import resnet101
|
||||
from torchpipe.batchnorm import DeferredBatchNorm
|
||||
model = resnet101()
|
||||
model = DeferredBatchNorm.convert_deferred_batch_norm(model)
|
||||
|
||||
"""
|
||||
if isinstance(module, DeferredBatchNorm) and module.chunks is chunks:
|
||||
return cast(TModule, module)
|
||||
|
||||
module_output: nn.Module = module
|
||||
|
||||
if isinstance(module, _BatchNorm) and module.track_running_stats:
|
||||
module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks)
|
||||
if module.affine:
|
||||
module_output.register_parameter("weight", module.weight)
|
||||
module_output.register_parameter("bias", module.bias)
|
||||
module_output.register_buffer("running_mean", module.running_mean)
|
||||
module_output.register_buffer("running_var", module.running_var)
|
||||
module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)
|
||||
|
||||
for name, child in module.named_children():
|
||||
module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks))
|
||||
|
||||
return cast(TModule, module_output)
|
317
torch/distributed/_pipeline/sync/checkpoint.py
Normal file
317
torch/distributed/_pipeline/sync/checkpoint.py
Normal file
@ -0,0 +1,317 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Checkpointing with preceding recomputation.
|
||||
|
||||
PyTorch already provides the official checkpointing utilities in
|
||||
:mod:`torch.utils.checkpoint`. The official checkpointing combines
|
||||
recomputation and recursive backpropagation into one autograd function named
|
||||
``CheckpointFunction``. Hence, the recomputation can be started only when the
|
||||
gradients arrive to the function. In Pipe, the recomputation needs to precede
|
||||
the gradient arrival to minimize the GPU idle time.
|
||||
|
||||
We solve this problem by introducing separate autograd functions named
|
||||
:class:`Recompute` and :class:`Checkpoint`. Each function represents
|
||||
recomputation and recursive backpropagation, respectively. We can manipulate
|
||||
the control flow in aspect of both the autograd engine and CUDA with a pair of
|
||||
the functions.
|
||||
|
||||
Specifically, we place CUDA stream synchronization between :class:`Recompute`
|
||||
and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is
|
||||
copied entirely.
|
||||
|
||||
"""
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import ByteTensor, Tensor
|
||||
import torch.autograd
|
||||
|
||||
from .dependency import fork, join
|
||||
from .microbatch import Batch
|
||||
from .phony import get_phony
|
||||
|
||||
__all__ = ["is_checkpointing", "is_recomputing"]
|
||||
|
||||
|
||||
Tensors = Tuple[Tensor, ...]
|
||||
TensorOrTensors = Union[Tensor, Tensors]
|
||||
|
||||
# Types for shared memory between Checkpoint and Recompute.
|
||||
Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf)
|
||||
RNGStates = Tuple[ByteTensor, Optional[ByteTensor]] # (cpu_rng_state, gpu_rng_state)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Protocol
|
||||
else:
|
||||
Protocol = object
|
||||
|
||||
|
||||
# Protocol with __call__ instead of Callable can be used as an attribute type.
|
||||
# See: https://github.com/python/mypy/issues/708#issuecomment-561735949
|
||||
class Function(Protocol):
|
||||
def __call__(self, input: TensorOrTensors) -> TensorOrTensors:
|
||||
...
|
||||
|
||||
|
||||
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
|
||||
"""Makes a checkpoint with a simple interface like
|
||||
:func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
|
||||
:class:`Checkpoint` and :class:`Recompute` without boilerplate.
|
||||
"""
|
||||
batch = Batch(input)
|
||||
|
||||
chk = Checkpointing(function, batch)
|
||||
batch = chk.checkpoint()
|
||||
chk.recompute(batch)
|
||||
|
||||
return batch.tensor_or_tensors
|
||||
|
||||
|
||||
class Checkpointing:
|
||||
"""Generates a pair of :class:`Checkpoint` and :class:`Recompute`."""
|
||||
|
||||
def __init__(self, function: Function, batch: Batch) -> None:
|
||||
self.function = function
|
||||
self.batch = batch
|
||||
|
||||
# Shared memory between Checkpoint and Recompute. 1-length deque is
|
||||
# used for mutability and length limitation.
|
||||
self.recomputed: Deque[Recomputed] = deque(maxlen=1)
|
||||
self.rng_states: Deque[RNGStates] = deque(maxlen=1)
|
||||
|
||||
def checkpoint(self) -> Batch:
|
||||
"""Returns a batch applied by :class:`Checkpoint`."""
|
||||
input_atomic = self.batch.atomic
|
||||
input = tuple(self.batch)
|
||||
|
||||
# Use a phony which requires grad to ensure that Checkpoint can be
|
||||
# tracked by the autograd engine even when none of the input tensors
|
||||
# require grad.
|
||||
phony = get_phony(self.batch[0].device, requires_grad=True)
|
||||
|
||||
output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input)
|
||||
|
||||
# Gradients are only supported for float Tensors.
|
||||
if isinstance(output, tuple):
|
||||
output = tuple([x if x.is_floating_point() else x.detach() for x in output])
|
||||
|
||||
return Batch(output)
|
||||
|
||||
def recompute(self, batch: Batch) -> None:
|
||||
"""Applies :class:`Recompute` to the batch in place."""
|
||||
input_atomic = self.batch.atomic
|
||||
input = tuple(self.batch)
|
||||
|
||||
# batch[0] is always requiring grad, because it has been passed
|
||||
# checkpoint with a phony requiring grad.
|
||||
batch[0], phony = fork(batch[0])
|
||||
phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input)
|
||||
batch[0] = join(batch[0], phony)
|
||||
|
||||
|
||||
class ThreadLocal(threading.local):
|
||||
def __init__(self) -> None:
|
||||
self.is_checkpointing = False
|
||||
self.is_recomputing = False
|
||||
|
||||
|
||||
thread_local = ThreadLocal()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def enable_checkpointing() -> Generator[None, None, None]:
|
||||
"""Makes :func:`is_checkpointing` return :data:`True` within a context."""
|
||||
orig = thread_local.is_checkpointing
|
||||
thread_local.is_checkpointing = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
thread_local.is_checkpointing = orig
|
||||
|
||||
|
||||
@contextmanager
|
||||
def enable_recomputing() -> Generator[None, None, None]:
|
||||
"""Makes :func:`is_recomputing` return :data:`True` within a context."""
|
||||
orig = thread_local.is_recomputing
|
||||
thread_local.is_recomputing = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
thread_local.is_recomputing = orig
|
||||
|
||||
|
||||
def is_checkpointing() -> bool:
|
||||
"""Whether the current forward propagation is under checkpointing.
|
||||
|
||||
Returns:
|
||||
bool: :data:`True` if it's under checkpointing.
|
||||
|
||||
"""
|
||||
return thread_local.is_checkpointing
|
||||
|
||||
|
||||
def is_recomputing() -> bool:
|
||||
"""Whether the current forward propagation is under checkpoint
|
||||
recomputation. Use this to prevent duplicated side-effects at forward
|
||||
propagation::
|
||||
|
||||
class Counter(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.counter = 0
|
||||
|
||||
def forward(self, input):
|
||||
if not is_recomputing():
|
||||
self.counter += 1
|
||||
return input
|
||||
|
||||
Returns:
|
||||
bool: :data:`True` if it's under checkpoint recomputation.
|
||||
|
||||
.. seealso:: :ref:`Detecting Recomputation`
|
||||
|
||||
"""
|
||||
return thread_local.is_recomputing
|
||||
|
||||
|
||||
class Context:
|
||||
"""The common interface between the :class:`Checkpoint` and
|
||||
:class:`Recompute` context.
|
||||
"""
|
||||
|
||||
recomputed: Deque[Recomputed]
|
||||
rng_states: Deque[RNGStates]
|
||||
function: Function
|
||||
input_atomic: bool
|
||||
|
||||
saved_tensors: Tuple[Tensor, ...]
|
||||
|
||||
def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None:
|
||||
""":meth:`Checkpoint.forward` captures the current PyTorch's random number
|
||||
generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.
|
||||
|
||||
.. seealso:: :ref:`Referential Transparency`
|
||||
|
||||
"""
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
|
||||
gpu_rng_state: Optional[ByteTensor]
|
||||
if device.type == "cuda":
|
||||
gpu_rng_state = torch.cuda.get_rng_state(device)
|
||||
else:
|
||||
gpu_rng_state = None
|
||||
|
||||
rng_states.append((cpu_rng_state, gpu_rng_state))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]:
|
||||
""":meth:`Recompute.backward` restores the random number generator states
|
||||
captured by :func:`save_rng_states` within its context.
|
||||
|
||||
.. seealso:: :ref:`Referential Transparency`
|
||||
|
||||
"""
|
||||
cpu_rng_state, gpu_rng_state = rng_states.pop()
|
||||
|
||||
gpu_devices: List[torch.device] = []
|
||||
if device.type == "cuda":
|
||||
gpu_devices.append(device)
|
||||
|
||||
with torch.random.fork_rng(gpu_devices):
|
||||
torch.set_rng_state(cpu_rng_state)
|
||||
if gpu_rng_state is not None:
|
||||
torch.cuda.set_rng_state(gpu_rng_state, device)
|
||||
yield
|
||||
|
||||
|
||||
class Checkpoint(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(
|
||||
ctx: Context,
|
||||
phony: Tensor,
|
||||
recomputed: Deque[Recomputed],
|
||||
rng_states: Deque[RNGStates],
|
||||
function: Function,
|
||||
input_atomic: bool,
|
||||
*input: Tensor,
|
||||
) -> TensorOrTensors:
|
||||
ctx.recomputed = recomputed
|
||||
ctx.rng_states = rng_states
|
||||
|
||||
save_rng_states(input[0].device, ctx.rng_states)
|
||||
|
||||
ctx.function = function
|
||||
ctx.input_atomic = input_atomic
|
||||
ctx.save_for_backward(*input)
|
||||
|
||||
with torch.no_grad(), enable_checkpointing():
|
||||
output = function(input[0] if input_atomic else input)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover
|
||||
output, input_leaf = ctx.recomputed.pop()
|
||||
|
||||
if isinstance(output, tuple):
|
||||
tensors = output
|
||||
else:
|
||||
tensors = (output,)
|
||||
if any(y.requires_grad for y in tensors):
|
||||
tensors = tuple([x for x in tensors if x.requires_grad])
|
||||
torch.autograd.backward(tensors, grad_output)
|
||||
|
||||
grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
|
||||
grad_input.extend(x.grad for x in input_leaf)
|
||||
return tuple(grad_input)
|
||||
|
||||
|
||||
class Recompute(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(
|
||||
ctx: Context,
|
||||
phony: Tensor,
|
||||
recomputed: Deque[Recomputed],
|
||||
rng_states: Deque[RNGStates],
|
||||
function: Function,
|
||||
input_atomic: bool,
|
||||
*input: Tensor,
|
||||
) -> Tensor:
|
||||
ctx.recomputed = recomputed
|
||||
ctx.rng_states = rng_states
|
||||
|
||||
ctx.function = function
|
||||
ctx.input_atomic = input_atomic
|
||||
ctx.save_for_backward(*input)
|
||||
|
||||
return phony
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover
|
||||
input = ctx.saved_tensors
|
||||
input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)
|
||||
|
||||
with restore_rng_states(input[0].device, ctx.rng_states):
|
||||
with torch.enable_grad(), enable_recomputing():
|
||||
output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)
|
||||
|
||||
ctx.recomputed.append((output, input_leaf))
|
||||
|
||||
grad_input: List[None] = [None, None, None, None, None]
|
||||
grad_input.extend(None for _ in ctx.saved_tensors)
|
||||
return tuple(grad_input)
|
104
torch/distributed/_pipeline/sync/copy.py
Normal file
104
torch/distributed/_pipeline/sync/copy.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Autograd functions for stream-aware CUDA copy. It is used to overlap copy
|
||||
and computation on the same GPU.
|
||||
"""
|
||||
from collections import deque
|
||||
from typing import Deque, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
Tensors = Tuple[Tensor, ...]
|
||||
|
||||
|
||||
# Common interface between :class:`Copy` and :class:`Wait`.
|
||||
class Context:
|
||||
prev_stream: AbstractStream
|
||||
next_stream: AbstractStream
|
||||
|
||||
|
||||
class Copy(torch.autograd.Function):
|
||||
"""Copies tensors on specific streams."""
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
|
||||
ctx.prev_stream = prev_stream
|
||||
ctx.next_stream = next_stream
|
||||
|
||||
output = []
|
||||
output_stream = current_stream(get_device(next_stream))
|
||||
|
||||
with use_stream(prev_stream), use_stream(next_stream):
|
||||
for x in input:
|
||||
y = x.to(get_device(next_stream), non_blocking=True)
|
||||
output.append(y)
|
||||
|
||||
# 'prev_stream' is not where 'x' has been allocated.
|
||||
record_stream(x, prev_stream)
|
||||
# 'y' has been allocated on 'next_stream'.
|
||||
# It might be used on the current stream captured as 'output_stream'.
|
||||
record_stream(y, output_stream)
|
||||
|
||||
return tuple(output)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
|
||||
prev_stream = ctx.prev_stream
|
||||
next_stream = ctx.next_stream
|
||||
|
||||
grad_input: Deque[Tensor] = deque(maxlen=len(grad_output))
|
||||
input_stream = current_stream(get_device(prev_stream))
|
||||
|
||||
with use_stream(prev_stream), use_stream(next_stream):
|
||||
for x in reversed(grad_output):
|
||||
y = x.to(get_device(prev_stream), non_blocking=True)
|
||||
grad_input.appendleft(y)
|
||||
|
||||
# 'next_stream' is not where 'x' has been allocated.
|
||||
record_stream(x, next_stream)
|
||||
# 'y' has been allocated on 'prev_stream'.
|
||||
# It might be used on the current stream captured as 'input_stream'.
|
||||
record_stream(y, input_stream)
|
||||
|
||||
grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
|
||||
return grad_streams + tuple(grad_input)
|
||||
|
||||
|
||||
class Wait(torch.autograd.Function):
|
||||
"""Synchronizes a stream to another stream.
|
||||
|
||||
Place it just before you want to start an operation on the next stream,
|
||||
provided that all operations on the previous stream are done.
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors:
|
||||
ctx.prev_stream = prev_stream
|
||||
ctx.next_stream = next_stream
|
||||
|
||||
wait_stream(next_stream, prev_stream)
|
||||
|
||||
return tuple(x.detach() for x in input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]:
|
||||
prev_stream = ctx.prev_stream
|
||||
next_stream = ctx.next_stream
|
||||
|
||||
wait_stream(prev_stream, next_stream)
|
||||
|
||||
grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
|
||||
return grad_streams + grad_input
|
54
torch/distributed/_pipeline/sync/dependency.py
Normal file
54
torch/distributed/_pipeline/sync/dependency.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Arbitrary dependency between two autograd lanes."""
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .phony import get_phony
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Branches out from an autograd lane of the given tensor."""
|
||||
if torch.is_grad_enabled() and input.requires_grad:
|
||||
input, phony = Fork.apply(input)
|
||||
else:
|
||||
phony = get_phony(input.device, requires_grad=False)
|
||||
|
||||
return input, phony
|
||||
|
||||
|
||||
class Fork(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
|
||||
phony = get_phony(input.device, requires_grad=False)
|
||||
return input.detach(), phony.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore
|
||||
return grad_input
|
||||
|
||||
|
||||
def join(input: Tensor, phony: Tensor) -> Tensor:
|
||||
"""Merges two autograd lanes."""
|
||||
if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
|
||||
input = Join.apply(input, phony)
|
||||
|
||||
return input
|
||||
|
||||
|
||||
class Join(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore
|
||||
return input.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore
|
||||
return grad_input, None
|
185
torch/distributed/_pipeline/sync/microbatch.py
Normal file
185
torch/distributed/_pipeline/sync/microbatch.py
Normal file
@ -0,0 +1,185 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Manipulation of micro-batches."""
|
||||
import typing
|
||||
from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.cuda.comm
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
Tensors = Tuple[Tensor, ...]
|
||||
TensorOrTensors = Union[Tensor, Tensors]
|
||||
Function = Callable[[TensorOrTensors], TensorOrTensors]
|
||||
|
||||
|
||||
class Batch:
|
||||
"""An abstraction of an atomic tensor or a tuple of tensors. This
|
||||
eliminates every boilerplate code to classify an atomic tensor or a tuple
|
||||
of tensors.
|
||||
::
|
||||
|
||||
x = generate_tensor_or_tensors()
|
||||
x = Batch(x)
|
||||
|
||||
# in-place update
|
||||
x[0] = F.apply(x[0])
|
||||
x[:] = F.apply(*x)
|
||||
|
||||
# f(x) if x is a tensor.
|
||||
# f(*x) if x is a tuple of tensors.
|
||||
# y is also a batch.
|
||||
y = x.call(f)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, value: TensorOrTensors) -> None:
|
||||
self.value = value
|
||||
self.atomic = torch.is_tensor(value)
|
||||
|
||||
@property
|
||||
def tensor(self) -> Tensor:
|
||||
"""Retrieves the underlying tensor."""
|
||||
if not self.atomic:
|
||||
raise AttributeError("not atomic batch")
|
||||
return cast(Tensor, self.value)
|
||||
|
||||
@property
|
||||
def tensors(self) -> Tensors:
|
||||
"""Retrieves the underlying tensors."""
|
||||
if self.atomic:
|
||||
raise AttributeError("batch is atomic")
|
||||
return cast(Tensors, self.value)
|
||||
|
||||
@property
|
||||
def tensor_or_tensors(self) -> TensorOrTensors:
|
||||
"""Retrieves the underlying tensor or tensors regardless of type."""
|
||||
return self.value
|
||||
|
||||
def call(self, function: Function) -> "Batch":
|
||||
"""Calls a function by the underlying tensor or tensors. It also wraps
|
||||
the output with :class:`Batch`.
|
||||
"""
|
||||
return Batch(function(self.value))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Batch[atomic={self.atomic!r}]({self.value!r})"
|
||||
|
||||
def __iter__(self) -> Iterator[Tensor]:
|
||||
if self.atomic:
|
||||
yield self.tensor
|
||||
else:
|
||||
yield from self.tensors
|
||||
|
||||
def __len__(self) -> int:
|
||||
return 1 if self.atomic else len(self.tensors)
|
||||
|
||||
def __getitem__(self, index: int) -> Tensor:
|
||||
if not self.atomic:
|
||||
return self.tensors[index]
|
||||
|
||||
if index != 0:
|
||||
raise IndexError("atomic batch allows index 0 only")
|
||||
|
||||
return self.tensor
|
||||
|
||||
# NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
|
||||
@typing.overload
|
||||
def __setitem__(self, index: int, value: Tensor) -> None:
|
||||
...
|
||||
|
||||
@typing.overload
|
||||
def __setitem__(self, index: slice, value: Tensors) -> None:
|
||||
...
|
||||
|
||||
def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None:
|
||||
if isinstance(index, int):
|
||||
value = cast(Tensor, value)
|
||||
self._setitem_by_index(index, value)
|
||||
else:
|
||||
value = cast(Tensors, value)
|
||||
self._setitem_by_slice(index, value)
|
||||
|
||||
def _setitem_by_index(self, index: int, value: Tensor) -> None:
|
||||
if not self.atomic:
|
||||
i = index
|
||||
self.value = self.value[:i] + (value,) + self.value[i + 1 :]
|
||||
return
|
||||
|
||||
if index != 0:
|
||||
raise IndexError("atomic batch allows index 0 only")
|
||||
|
||||
self.value = value
|
||||
|
||||
def _setitem_by_slice(self, index: slice, value: Tensors) -> None:
|
||||
if not (index.start is index.stop is index.step is None):
|
||||
raise NotImplementedError("only slice [:] supported")
|
||||
|
||||
if not self.atomic:
|
||||
self.value = value
|
||||
return
|
||||
|
||||
if len(value) != 1:
|
||||
raise IndexError("atomic batch cannot be replaced with multiple tensors")
|
||||
|
||||
self.value = value[0]
|
||||
|
||||
|
||||
def check(input: TensorOrTensors) -> None:
|
||||
"""Checks whether the input is a tensor or tensors.
|
||||
|
||||
Raises:
|
||||
TypeError: input is not a tensor or tensors.
|
||||
|
||||
"""
|
||||
if isinstance(input, tuple):
|
||||
for x in input:
|
||||
check(x)
|
||||
return
|
||||
|
||||
if not isinstance(input, Tensor):
|
||||
raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")
|
||||
|
||||
|
||||
def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]:
|
||||
"""Splits an input mini-batch into multiple micro-batches."""
|
||||
inputs: Iterable[TensorOrTensors]
|
||||
|
||||
if isinstance(input, Tensor):
|
||||
inputs = input.chunk(chunks)
|
||||
else:
|
||||
rotated: List[Tensors] = []
|
||||
|
||||
for tensor in input:
|
||||
tensors = tensor.chunk(chunks)
|
||||
rotated.append(cast(Tensors, tensors))
|
||||
|
||||
inputs = zip(*rotated)
|
||||
|
||||
return [Batch(x) for x in inputs]
|
||||
|
||||
|
||||
def gather(outputs: List[Batch]) -> TensorOrTensors:
|
||||
"""Concatenates output micro-batches into a mini-batch."""
|
||||
output: TensorOrTensors
|
||||
|
||||
if outputs[0].atomic:
|
||||
tensors = tuple(b.tensor for b in outputs)
|
||||
output = torch.cat(tensors)
|
||||
else:
|
||||
rotated = [b.tensors for b in outputs]
|
||||
output_buf = []
|
||||
|
||||
for tensors in zip(*rotated):
|
||||
output_buf.append(torch.cat(tensors))
|
||||
|
||||
output = tuple(output_buf)
|
||||
|
||||
return output
|
49
torch/distributed/_pipeline/sync/phony.py
Normal file
49
torch/distributed/_pipeline/sync/phony.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Provides phony for arbitrary dependency in a autograd graph."""
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .stream import default_stream, use_stream
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}
|
||||
|
||||
|
||||
def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
|
||||
"""Gets a phony. Phony is tensor without space. It is useful to make
|
||||
arbitrary dependency in a autograd graph because it doesn't require any
|
||||
gradient accumulation.
|
||||
|
||||
.. note::
|
||||
|
||||
Phonies for each device are cached. If an autograd function gets a phony
|
||||
internally, the phony must be detached to be returned. Otherwise, the
|
||||
autograd engine will mutate the cached phony in-place::
|
||||
|
||||
class Phonify(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
phony = get_phony(input.device, requires_grad=False)
|
||||
return phony.detach() # detach() is necessary.
|
||||
|
||||
"""
|
||||
key = (device, requires_grad)
|
||||
|
||||
try:
|
||||
phony = _phonies[key]
|
||||
except KeyError:
|
||||
with use_stream(default_stream(device)):
|
||||
phony = torch.empty(0, device=device, requires_grad=requires_grad)
|
||||
|
||||
_phonies[key] = phony
|
||||
|
||||
return phony
|
394
torch/distributed/_pipeline/sync/pipe.py
Normal file
394
torch/distributed/_pipeline/sync/pipe.py
Normal file
@ -0,0 +1,394 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""The Pipe interface."""
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.autograd
|
||||
import torch.cuda
|
||||
|
||||
from . import microbatch
|
||||
from .batchnorm import DeferredBatchNorm
|
||||
from .pipeline import Pipeline
|
||||
from .skip.layout import inspect_skip_layout
|
||||
from .skip.skippable import verify_skippables
|
||||
from .stream import AbstractStream, new_stream
|
||||
|
||||
__all__ = ["Pipe"]
|
||||
|
||||
|
||||
Device = Union[torch.device, int, str]
|
||||
Devices = Union[Iterable[Device], List[Device]]
|
||||
|
||||
Tensors = Tuple[Tensor, ...]
|
||||
TensorOrTensors = Union[Tensor, Tensors]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
Module = nn.Module[TensorOrTensors]
|
||||
NamedModules = OrderedDict[str, Module]
|
||||
else:
|
||||
Module = nn.Module
|
||||
NamedModules = OrderedDict
|
||||
|
||||
|
||||
def recommend_auto_balance(message: str) -> str:
|
||||
"""Expands a message with recommendation to :mod:`torchpipe.balance`."""
|
||||
return f"""{message}
|
||||
|
||||
If your model is still under development, its optimal balance would change
|
||||
frequently. In this case, we highly recommend 'torch.distributed._pipeline.sync.balance' for
|
||||
naive automatic balancing:
|
||||
|
||||
from torch.distributed._pipeline.sync import Pipe
|
||||
from torch.distributed._pipeline.sync.balance import balance_by_time
|
||||
|
||||
partitions = torch.cuda.device_count()
|
||||
sample = torch.empty(...)
|
||||
balance = balance_by_time(partitions, model, sample)
|
||||
|
||||
model = Pipe(model, balance, ...)
|
||||
"""
|
||||
|
||||
|
||||
def verify_module(module: nn.Sequential) -> None:
|
||||
if not isinstance(module, nn.Sequential):
|
||||
raise TypeError("module must be nn.Sequential to be partitioned")
|
||||
|
||||
named_children = list(module.named_children())
|
||||
if len(named_children) != len(module):
|
||||
raise ValueError("module with duplicate children is not supported")
|
||||
|
||||
|
||||
def verify_splitting(
|
||||
module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device]
|
||||
) -> None:
|
||||
num_parameters = len(list(module.parameters()))
|
||||
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
|
||||
if num_parameters == num_child_parameters:
|
||||
return
|
||||
|
||||
for i in range(len(partitions)):
|
||||
for j in range(i + 1, len(partitions)):
|
||||
parti = partitions[i]
|
||||
partj = partitions[j]
|
||||
if devices[i] == devices[j]:
|
||||
continue
|
||||
for p in parti.parameters():
|
||||
for q in partj.parameters():
|
||||
if p is q:
|
||||
raise ValueError("module with duplicate parameters on distinct devices is not supported")
|
||||
|
||||
|
||||
class BalanceError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def split_module(
|
||||
module: nn.Sequential, balance: Iterable[int], devices: List[torch.device],
|
||||
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
|
||||
"""Splits a module into multiple partitions.
|
||||
|
||||
Returns:
|
||||
A tuple of (partitions, balance, devices).
|
||||
|
||||
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
|
||||
item is a partition. All layers in a partition are placed in the
|
||||
same device.
|
||||
|
||||
Raises:
|
||||
BalanceError:
|
||||
wrong balance
|
||||
IndexError:
|
||||
the number of devices is fewer than the number of partitions.
|
||||
|
||||
"""
|
||||
balance = list(balance)
|
||||
|
||||
if len(module) != sum(balance):
|
||||
raise BalanceError(
|
||||
"module and sum of balance have different length "
|
||||
f"(module: {len(module)}, sum of balance: {sum(balance)})"
|
||||
)
|
||||
|
||||
if any(x <= 0 for x in balance):
|
||||
raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
|
||||
|
||||
if len(balance) > len(devices):
|
||||
raise IndexError(
|
||||
"too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})"
|
||||
)
|
||||
|
||||
j = 0
|
||||
partitions = []
|
||||
layers: NamedModules = OrderedDict()
|
||||
|
||||
for name, layer in module.named_children():
|
||||
layers[name] = layer
|
||||
|
||||
if len(layers) == balance[j]:
|
||||
# Group buffered layers as a partition.
|
||||
partition = nn.Sequential(layers)
|
||||
|
||||
device = devices[j]
|
||||
partition.to(device)
|
||||
|
||||
partitions.append(partition)
|
||||
|
||||
# Prepare for the next partition.
|
||||
layers.clear()
|
||||
j += 1
|
||||
|
||||
partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
|
||||
del devices[j:]
|
||||
|
||||
return partitions, balance, devices
|
||||
|
||||
|
||||
MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement")
|
||||
|
||||
|
||||
class Pipe(Module):
|
||||
"""Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
|
||||
to train on Pipe_. If the module requires lots of memory, Pipe will be
|
||||
very efficient.
|
||||
::
|
||||
|
||||
model = nn.Sequential(a, b, c, d)
|
||||
model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
|
||||
output = model(input)
|
||||
|
||||
.. _Pipe: https://arxiv.org/abs/1811.06965
|
||||
|
||||
Pipe combines pipeline parallelism with checkpointing to reduce peak
|
||||
memory required to train while minimizing device under-utilization.
|
||||
|
||||
You should determine the balance when defining a :class:`Pipe` module, as
|
||||
balancing will not be done automatically. The module will be partitioned
|
||||
into multiple devices according to the given balance. You may rely on
|
||||
heuristics to find your own optimal configuration.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Sequential):
|
||||
sequential module to be parallelized
|
||||
balance (ints):
|
||||
list of number of layers in each partition
|
||||
|
||||
Keyword Args:
|
||||
devices (iterable of devices):
|
||||
devices to use (default: all CUDA devices)
|
||||
chunks (int):
|
||||
number of micro-batches (default: ``1``)
|
||||
checkpoint (str):
|
||||
when to enable checkpointing, one of ``'always'``,
|
||||
``'except_last'``, or ``'never'`` (default: ``'except_last'``)
|
||||
deferred_batch_norm (bool):
|
||||
whether to use deferred BatchNorm moving statistics (default:
|
||||
:data:`False`, see :ref:`Deferred Batch Normalization` for more
|
||||
details)
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
|
||||
ValueError:
|
||||
invalid arguments, or wrong balance
|
||||
IndexError:
|
||||
the number of devices is fewer than the number of partitions.
|
||||
|
||||
"""
|
||||
|
||||
#: The number of layers in each partition.
|
||||
balance: List[int] = []
|
||||
# ^^
|
||||
# The default value [] required for Sphinx's autoattribute.
|
||||
|
||||
#: The devices mapped to each partition.
|
||||
#:
|
||||
#: ``devices[-1]`` refers to the device of the last partition, which means
|
||||
#: it is the output device. Probably, you need to use it to transfer the
|
||||
#: target to calculate the loss without a device mismatch
|
||||
#: :exc:`RuntimeError`. For example::
|
||||
#:
|
||||
#: out_device = pipe.devices[-1]
|
||||
#:
|
||||
#: for input, target in loader:
|
||||
#: target = target.to(out_device, non_blocking=True)
|
||||
#: output = pipe(input)
|
||||
#: loss = F.cross_entropy(output, target)
|
||||
#:
|
||||
devices: List[torch.device] = []
|
||||
|
||||
#: The number of micro-batches.
|
||||
chunks: int = 1
|
||||
|
||||
#: The checkpoint mode to determine when to enable checkpointing. It is one
|
||||
#: of ``'always'``, ``'except_last'``, or ``'never'``.
|
||||
checkpoint: str = "except_last"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Sequential,
|
||||
balance: Optional[Iterable[int]] = None,
|
||||
*,
|
||||
devices: Optional[Devices] = None,
|
||||
chunks: int = chunks,
|
||||
checkpoint: str = checkpoint,
|
||||
deferred_batch_norm: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
chunks = int(chunks)
|
||||
checkpoint = str(checkpoint)
|
||||
|
||||
if balance is None:
|
||||
raise ValueError(recommend_auto_balance("balance is required"))
|
||||
if chunks <= 0:
|
||||
raise ValueError("number of chunks must be positive integer")
|
||||
if checkpoint not in ["always", "except_last", "never"]:
|
||||
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
|
||||
|
||||
verify_module(module)
|
||||
|
||||
# Verify if the underlying skippable modules satisfy integrity. The
|
||||
# integrity can be verified before forward() because it is static.
|
||||
verify_skippables(module)
|
||||
|
||||
self.chunks = chunks
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
if deferred_batch_norm:
|
||||
module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
|
||||
|
||||
if devices is None:
|
||||
devices = range(torch.cuda.device_count())
|
||||
devices = [torch.device(d) for d in devices]
|
||||
devices = cast(List[torch.device], devices)
|
||||
|
||||
try:
|
||||
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
|
||||
except BalanceError as exc:
|
||||
raise ValueError(recommend_auto_balance(str(exc)))
|
||||
|
||||
verify_splitting(module, self.partitions, self.balance, self.devices)
|
||||
|
||||
self._copy_streams: List[List[AbstractStream]] = []
|
||||
self._skip_layout = inspect_skip_layout(self.partitions)
|
||||
|
||||
# Separate CUDA streams for copy.
|
||||
copy_streams = self._ensure_copy_streams()
|
||||
|
||||
# The micro-batch index where the checkpointing stops.
|
||||
checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
|
||||
|
||||
self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Counts the length of the underlying sequential module."""
|
||||
return sum(len(p) for p in self.partitions)
|
||||
|
||||
def __getitem__(self, index: int) -> nn.Module:
|
||||
"""Gets a layer in the underlying sequential module."""
|
||||
partitions = self.partitions
|
||||
if index < 0:
|
||||
partitions = partitions[::-1]
|
||||
|
||||
for partition in partitions:
|
||||
try:
|
||||
return partition[index]
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
shift = len(partition)
|
||||
|
||||
if index < 0:
|
||||
index += shift
|
||||
else:
|
||||
index -= shift
|
||||
|
||||
raise IndexError
|
||||
|
||||
def __iter__(self) -> Iterable[nn.Module]:
|
||||
"""Iterates over children of the underlying sequential module."""
|
||||
for partition in self.partitions:
|
||||
yield from partition
|
||||
|
||||
# Pipe should manage the device of each partition.
|
||||
# Deny cuda(), cpu(), and to() with device, by TypeError.
|
||||
def cuda(self, device: Optional[Device] = None) -> "Pipe":
|
||||
raise MOVING_DENIED
|
||||
|
||||
def cpu(self) -> "Pipe":
|
||||
raise MOVING_DENIED
|
||||
|
||||
def to(self, *args: Any, **kwargs: Any) -> "Pipe":
|
||||
# Deny these usages:
|
||||
#
|
||||
# - to(device[, dtype, non_blocking])
|
||||
# - to(tensor[, non_blocking])
|
||||
#
|
||||
# But allow this:
|
||||
#
|
||||
# - to(dtype[, non_blocking])
|
||||
#
|
||||
if "device" in kwargs or "tensor" in kwargs:
|
||||
raise MOVING_DENIED
|
||||
|
||||
if args:
|
||||
if isinstance(args[0], (torch.device, int, str)):
|
||||
raise MOVING_DENIED
|
||||
if torch.is_tensor(args[0]):
|
||||
raise MOVING_DENIED
|
||||
|
||||
return super().to(*args, **kwargs)
|
||||
|
||||
def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
|
||||
"""Ensures that :class:`Pipe` caches CUDA streams for copy.
|
||||
|
||||
It's worth to cache CUDA streams although PyTorch already manages a
|
||||
pool of pre-allocated CUDA streams, because it may reduce GPU memory
|
||||
fragementation when the number of micro-batches is small.
|
||||
|
||||
"""
|
||||
if not self._copy_streams:
|
||||
for device in self.devices:
|
||||
self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
|
||||
|
||||
return self._copy_streams
|
||||
|
||||
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
|
||||
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't
|
||||
modify the input and output signature of the underlying module. But
|
||||
there's type restriction. Input and output have to be a
|
||||
:class:`~torch.Tensor` or a tuple of tensors. This restriction is
|
||||
applied at partition boundaries too.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor or tensors): input mini-batch
|
||||
|
||||
Returns:
|
||||
tensor or tensors: output mini-batch
|
||||
|
||||
Raises:
|
||||
TypeError: input is not a tensor or tensors.
|
||||
|
||||
"""
|
||||
microbatch.check(input)
|
||||
|
||||
if not self.devices:
|
||||
# Empty sequential module is not illegal.
|
||||
return input
|
||||
|
||||
# Divide a mini-batch into micro-batches.
|
||||
batches = microbatch.scatter(input, self.chunks)
|
||||
|
||||
# Run pipeline parallelism.
|
||||
self.pipeline.run(batches)
|
||||
|
||||
# Merge the micro-batches into one mini-batch.
|
||||
output = microbatch.gather(batches)
|
||||
return output
|
257
torch/distributed/_pipeline/sync/pipeline.py
Normal file
257
torch/distributed/_pipeline/sync/pipeline.py
Normal file
@ -0,0 +1,257 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""The pipeline parallelism of Pipe."""
|
||||
from queue import Queue
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.autograd.profiler import record_function
|
||||
|
||||
from .checkpoint import Checkpointing
|
||||
from .copy import Copy, Wait
|
||||
from .dependency import fork, join
|
||||
from .microbatch import Batch
|
||||
from .skip.layout import SkipLayout
|
||||
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
|
||||
from .stream import AbstractStream, current_stream, use_device
|
||||
from .worker import Task, create_workers, join_workers
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
Tensors = Tuple[Tensor, ...]
|
||||
TensorOrTensors = Union[Tensor, Tensors]
|
||||
|
||||
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
|
||||
|
||||
# Queue is generic only in stubs.
|
||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
|
||||
if TYPE_CHECKING:
|
||||
InQueue = Queue[Optional["Task"]]
|
||||
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
|
||||
else:
|
||||
InQueue = Queue
|
||||
OutQueue = Queue
|
||||
|
||||
|
||||
def depend(fork_from: Batch, join_to: Batch) -> None:
|
||||
fork_from[0], phony = fork(fork_from[0])
|
||||
join_to[0] = join(join_to[0], phony)
|
||||
|
||||
|
||||
def copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
|
||||
batch[:] = Copy.apply(prev_stream, next_stream, *batch)
|
||||
# Gradients are only supported for float Tensors.
|
||||
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
|
||||
|
||||
|
||||
def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
|
||||
batch[:] = Wait.apply(prev_stream, next_stream, *batch)
|
||||
# Gradients are only supported for float Tensors.
|
||||
batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch])
|
||||
|
||||
|
||||
def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
|
||||
"""Generates schedules for each clock cycle."""
|
||||
# m: number of micro-batches
|
||||
# n: number of partitions
|
||||
# i: index of micro-batch
|
||||
# j: index of partition
|
||||
# k: clock number
|
||||
#
|
||||
# k (i,j) (i,j) (i,j)
|
||||
# - ----- ----- -----
|
||||
# 0 (0,0)
|
||||
# 1 (1,0) (0,1)
|
||||
# 2 (2,0) (1,1) (0,2)
|
||||
# 3 (2,1) (1,2)
|
||||
# 4 (2,2)
|
||||
for k in range(m + n - 1):
|
||||
yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""The pipeline parallelism for Pipe."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
partitions: List[nn.Sequential],
|
||||
devices: List[torch.device],
|
||||
copy_streams: List[List[AbstractStream]],
|
||||
skip_layout: SkipLayout,
|
||||
checkpoint_stop: int,
|
||||
) -> None:
|
||||
self.partitions = partitions
|
||||
self.devices = devices
|
||||
self.copy_streams = copy_streams
|
||||
self.skip_layout = skip_layout
|
||||
self.checkpoint_stop = checkpoint_stop
|
||||
(self.in_queues, self.out_queues) = create_workers(devices)
|
||||
|
||||
def __del__(self) -> None:
|
||||
join_workers(self.in_queues, self.out_queues)
|
||||
|
||||
def run(self, batches: List[Batch]) -> None:
|
||||
"""Runs pipeline parallelism.
|
||||
|
||||
It modifies the given batches in place.
|
||||
|
||||
"""
|
||||
partitions = self.partitions
|
||||
devices = self.devices
|
||||
skip_layout = self.skip_layout
|
||||
|
||||
m = len(batches)
|
||||
n = len(partitions)
|
||||
|
||||
skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]
|
||||
|
||||
for schedule in clock_cycles(m, n):
|
||||
self.fence(batches, schedule, skip_trackers)
|
||||
self.compute(batches, schedule, skip_trackers)
|
||||
|
||||
def fence(
|
||||
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
|
||||
) -> None:
|
||||
"""Copies micro-batches after computation for the previous
|
||||
micro-batches.
|
||||
"""
|
||||
copy_streams = self.copy_streams
|
||||
skip_layout = self.skip_layout
|
||||
|
||||
for i, j in schedule:
|
||||
# Ensure that batches[i-1] is executed after batches[i] in
|
||||
# backpropagation by an explicit dependency.
|
||||
if i != 0 and j != 0:
|
||||
depend(batches[i - 1], batches[i])
|
||||
|
||||
next_stream = copy_streams[j][i]
|
||||
|
||||
for prev_j, ns, name in skip_layout.copy_policy(j):
|
||||
prev_stream = copy_streams[prev_j][i]
|
||||
skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)
|
||||
|
||||
if j != 0:
|
||||
prev_stream = copy_streams[j - 1][i]
|
||||
copy(batches[i], prev_stream, next_stream)
|
||||
|
||||
def compute(
|
||||
self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
|
||||
) -> None:
|
||||
"""Runs tasks with synchronization to copy streams."""
|
||||
partitions = self.partitions
|
||||
devices = self.devices
|
||||
copy_streams = self.copy_streams
|
||||
checkpoint_stop = self.checkpoint_stop
|
||||
|
||||
# Disable checkpointing if in eval mode.
|
||||
if not self.partitions[0].training:
|
||||
checkpoint_stop = 0
|
||||
|
||||
n = len(partitions)
|
||||
streams = [current_stream(d) for d in devices]
|
||||
exc_info: Optional[ExcInfo] = None
|
||||
|
||||
# With checkpointing, the autograd graph looks like this diagram:
|
||||
# ┌─────┸──────┐
|
||||
# │ Copy │
|
||||
# └─────┰──────┘ (fence)
|
||||
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
|
||||
# ┃ (compute)
|
||||
# ┌─────┸──────┐
|
||||
# │ Wait │ [1] Synchronize the current stream with the copy stream.
|
||||
# └─────┰──────┘
|
||||
# ┌─────┸──────┐
|
||||
# │ Checkpoint │ [2] Compute a partition within checkpointing.
|
||||
# └─────┰──────┘
|
||||
# ┌─────┸──────┐
|
||||
# │ Wait │ [3] Synchronize the copy stream with the current stream.
|
||||
# └─────┰──────┘
|
||||
# ┠ ─ ─ ─ ┐
|
||||
# ┃ ┌─────┴─────┐
|
||||
# ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
|
||||
# ┃ └─────┬─────┘
|
||||
# ┠ ─ ─ ─ ┘
|
||||
# ┃
|
||||
# ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
|
||||
# ┌─────┸──────┐ (fence)
|
||||
# │ Copy │
|
||||
# └─────┰──────┘
|
||||
for i, j in schedule:
|
||||
batch = batches[i]
|
||||
partition = partitions[j]
|
||||
|
||||
# Synchronize with the copied input. ([1] in the diagram)
|
||||
if j != 0:
|
||||
wait(batch, copy_streams[j][i], streams[j])
|
||||
|
||||
# Determine whether checkpointing or not.
|
||||
checkpoint = i < checkpoint_stop
|
||||
if checkpoint:
|
||||
|
||||
def function(
|
||||
input: TensorOrTensors,
|
||||
partition: nn.Sequential = partition,
|
||||
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
|
||||
chunk_id: int = i,
|
||||
part_id: int = j,
|
||||
) -> TensorOrTensors:
|
||||
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
|
||||
return partition(input)
|
||||
|
||||
chk = Checkpointing(function, batch)
|
||||
task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
|
||||
del function, chk
|
||||
|
||||
else:
|
||||
|
||||
def compute(
|
||||
batch: Batch = batch,
|
||||
partition: nn.Sequential = partition,
|
||||
skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
|
||||
chunk_id: int = i,
|
||||
part_id: int = j,
|
||||
) -> Batch:
|
||||
with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
|
||||
return batch.call(partition)
|
||||
|
||||
task = Task(streams[j], compute=compute, finalize=None)
|
||||
del compute
|
||||
|
||||
# Compute tasks in parallel. ([2] in the diagram)
|
||||
self.in_queues[j].put(task)
|
||||
|
||||
for i, j in schedule:
|
||||
ok, payload = self.out_queues[j].get()
|
||||
|
||||
# Hold the first exception.
|
||||
if exc_info is not None:
|
||||
continue
|
||||
elif not ok:
|
||||
exc_info = cast(ExcInfo, payload)
|
||||
continue
|
||||
|
||||
task, batch = cast(Tuple[Task, Batch], payload)
|
||||
|
||||
# The copy stream synchronizes to copy the output. ([3] in the
|
||||
# diagram)
|
||||
if j != n - 1:
|
||||
wait(batch, streams[j], copy_streams[j][i])
|
||||
|
||||
# Finalize tasks. If checkpointing is enabled, here the
|
||||
# recomputation is scheduled at backpropagation. ([4] in the
|
||||
# diagram)
|
||||
with use_device(devices[j]):
|
||||
task.finalize(batch)
|
||||
|
||||
batches[i] = batch
|
||||
|
||||
# Fail at the first exception.
|
||||
if exc_info is not None:
|
||||
raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
|
6
torch/distributed/_pipeline/sync/py.typed
Normal file
6
torch/distributed/_pipeline/sync/py.typed
Normal file
@ -0,0 +1,6 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
11
torch/distributed/_pipeline/sync/skip/__init__.py
Normal file
11
torch/distributed/_pipeline/sync/skip/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Supports efficiency with skip connections."""
|
||||
from .namespace import Namespace
|
||||
from .skippable import pop, skippable, stash, verify_skippables
|
||||
|
||||
__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"]
|
86
torch/distributed/_pipeline/sync/skip/layout.py
Normal file
86
torch/distributed/_pipeline/sync/skip/layout.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Static skip connection layout of ``@skippable`` modules."""
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .namespace import Namespace
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
class SkipLayout:
|
||||
"""Represents a skip connection layout across partitions."""
|
||||
|
||||
# Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...}
|
||||
by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]]
|
||||
|
||||
# Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
|
||||
by_partition: List[List[Tuple[int, Namespace, str]]]
|
||||
|
||||
def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
|
||||
# The skip routes are already indexed by 'ns, name'.
|
||||
self.by_ns_name = skip_routes
|
||||
|
||||
# Index skip routes by partition number 'j'.
|
||||
self.by_partition = [[] for _ in range(num_partitions)]
|
||||
|
||||
for (ns, name), (prev_j, next_j) in skip_routes.items():
|
||||
self.by_partition[next_j].append((prev_j, ns, name))
|
||||
|
||||
for p in self.by_partition:
|
||||
p.sort()
|
||||
|
||||
def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
|
||||
"""Generates skip routes for the given destination partition number.
|
||||
The skip routes are sorted by source partition number in ascending
|
||||
order.
|
||||
|
||||
Yields:
|
||||
Each tuple of (source partition number, namespace, name).
|
||||
|
||||
"""
|
||||
for prev_j, ns, name in self.by_partition[next_j]:
|
||||
if prev_j == next_j:
|
||||
# This skip tensor will be popped at the same partition where
|
||||
# it is stashed. In this case, copy is not required.
|
||||
continue
|
||||
|
||||
yield (prev_j, ns, name)
|
||||
|
||||
def requires_copy(self, ns: Namespace, name: str) -> bool:
|
||||
"""Whether the given namespace and name requires partition-to-partition
|
||||
copy or not.
|
||||
"""
|
||||
prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1))
|
||||
return prev_j != next_j
|
||||
|
||||
|
||||
def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout:
|
||||
"""Inspects the skip connection layout in the given partitions."""
|
||||
# NOTE(sublee): Hide circular import inside this subroutine. Circular
|
||||
# import is not ideal but placing this logic near to SkipLayout may
|
||||
# increase cohesion of code.
|
||||
from .skippable import Skippable
|
||||
|
||||
skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {}
|
||||
stashed_at: Dict[Tuple[Namespace, str], int] = {}
|
||||
|
||||
for j, partition in enumerate(partitions):
|
||||
for layer in partition:
|
||||
if not isinstance(layer, Skippable):
|
||||
continue
|
||||
|
||||
for ns, name in layer.stashable():
|
||||
stashed_at[(ns, name)] = j
|
||||
|
||||
for ns, name in layer.poppable():
|
||||
prev_j = stashed_at.pop((ns, name))
|
||||
skip_routes[(ns, name)] = (prev_j, j)
|
||||
|
||||
return SkipLayout(len(partitions), skip_routes)
|
50
torch/distributed/_pipeline/sync/skip/namespace.py
Normal file
50
torch/distributed/_pipeline/sync/skip/namespace.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Provides isolated namespace of skip tensors."""
|
||||
import abc
|
||||
from functools import total_ordering
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
__all__ = ["Namespace"]
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Namespace(metaclass=abc.ABCMeta):
|
||||
"""Namespace for isolating skip tensors used by :meth:`isolate()
|
||||
<torchpipe.skip.skippable.Skippable.isolate>`.
|
||||
"""
|
||||
|
||||
__slots__ = ("id",)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.id = uuid.uuid4()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Namespace '{self.id}'>"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
# Namespaces should support ordering, since SkipLayout will sort tuples
|
||||
# including a namespace. But actual order between namespaces is not
|
||||
# important. That's why they are ordered by version 4 UUID which generates
|
||||
# random numbers.
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if isinstance(other, Namespace):
|
||||
return self.id < other.id
|
||||
return False
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Namespace):
|
||||
return self.id == other.id
|
||||
return False
|
||||
|
||||
|
||||
# 'None' is the default namespace,
|
||||
# which means that 'isinstance(None, Namespace)' is 'True'.
|
||||
Namespace.register(type(None))
|
231
torch/distributed/_pipeline/sync/skip/portal.py
Normal file
231
torch/distributed/_pipeline/sync/skip/portal.py
Normal file
@ -0,0 +1,231 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the
|
||||
autograd engine. The shared context of three functions (:class:`PortalBlue`,
|
||||
:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is
|
||||
one of the most important feature of :mod:`torchpipe.skip`.
|
||||
|
||||
The metaphor is inspired by Portal™ from Valve.
|
||||
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from ..copy import Context as CopyContext
|
||||
from ..copy import Copy
|
||||
from ..phony import get_phony
|
||||
from ..stream import AbstractStream, get_device
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
class Portal:
|
||||
"""A portal for a tensor."""
|
||||
|
||||
def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None:
|
||||
self.put_tensor(tensor, tensor_life)
|
||||
self.grad: Optional[Tensor] = None
|
||||
|
||||
def blue(self) -> Tensor:
|
||||
"""Creates a :class:`PortalBlue` which hides the underlying tensor from
|
||||
the autograd engine.
|
||||
|
||||
Join the returning phony to the main lane of the autograd graph to
|
||||
assure the correct backpropagation::
|
||||
|
||||
PortalBlue --+
|
||||
|
|
||||
---------- Join --
|
||||
|
||||
"""
|
||||
tensor = self.use_tensor()
|
||||
|
||||
if tensor is None:
|
||||
return get_phony(torch.device("cpu"), requires_grad=False)
|
||||
|
||||
return PortalBlue.apply(self, tensor)
|
||||
|
||||
def orange(self, phony: Tensor) -> Optional[Tensor]:
|
||||
"""Creates a :class:`PortalOrange` which retrieves the hidden tensor
|
||||
without losing ability of backpropagation.
|
||||
|
||||
Give a phony forked from the main lane of an autograd graph::
|
||||
|
||||
+-- PortalOrange --+
|
||||
| |
|
||||
-- Fork --------- f(a, b) --
|
||||
|
||||
"""
|
||||
self.check_tensor_life()
|
||||
|
||||
if self.tensor is None:
|
||||
return self.use_tensor()
|
||||
|
||||
return PortalOrange.apply(self, phony)
|
||||
|
||||
def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor:
|
||||
"""Copies the hidden tensor by a :class:`PortalCopy`.
|
||||
|
||||
Give a phony and use the returning phony to keep backpropagation::
|
||||
|
||||
+-- PortalCopy --+
|
||||
| |
|
||||
-- Fork ---------- Join --
|
||||
|
||||
"""
|
||||
if self.tensor is None:
|
||||
return get_phony(torch.device("cpu"), requires_grad=False)
|
||||
|
||||
return PortalCopy.apply(self, prev_stream, next_stream, phony)
|
||||
|
||||
def check_tensor_life(self) -> None:
|
||||
if self.tensor_life <= 0:
|
||||
raise RuntimeError("tensor in portal has been removed")
|
||||
|
||||
def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None:
|
||||
"""Stores a tensor into this portal."""
|
||||
# [Life of Tensor through Portal]
|
||||
#
|
||||
# The tensor can be retrieved by use_tensor() up to 'tensor_life'
|
||||
# times. When the life becomes 0, the tensor will be deleted for
|
||||
# deallocation in CUDA memory.
|
||||
#
|
||||
# The below events participate in a tensor through a portal.
|
||||
# Note that [x] denotes the events which call use_tensor():
|
||||
#
|
||||
# 1. [x] blue()
|
||||
# 2. [ ] PortalBlue.forward
|
||||
# 3. [ ] copy()
|
||||
# 4. [ ] PortalCopy.forward
|
||||
# 5. [ ] orange()
|
||||
# 6. [x] PortalOrange.forward
|
||||
# - - - - - - - - - - - - - - - - - - - - - - - - - - -
|
||||
# 7. [ ] orange() (recomputed)
|
||||
# 8. [x] PortalOrange.forward (recomputed)
|
||||
# 9. [ ] PortalOrange.backward
|
||||
# 10. [ ] PortalCopy.backward
|
||||
# 11. [x] blue() (recomputed)
|
||||
# 12. [ ] PortalBlue.forward (recomputed)
|
||||
# 13. [ ] PortalBlue.backward
|
||||
#
|
||||
self.tensor_life = tensor_life
|
||||
|
||||
if tensor_life > 0:
|
||||
self.tensor = tensor
|
||||
else:
|
||||
self.tensor = None
|
||||
|
||||
def use_tensor(self) -> Optional[Tensor]:
|
||||
"""Retrieves the underlying tensor and decreases the tensor life. When
|
||||
the life becomes 0, it the tensor will be removed.
|
||||
"""
|
||||
self.check_tensor_life()
|
||||
|
||||
tensor = self.tensor
|
||||
|
||||
self.tensor_life -= 1
|
||||
|
||||
if self.tensor_life <= 0:
|
||||
self.tensor = None
|
||||
|
||||
return tensor
|
||||
|
||||
def put_grad(self, grad: Tensor) -> None:
|
||||
"""Stores a gradient into this portal."""
|
||||
self.grad = grad
|
||||
|
||||
def use_grad(self) -> Tensor:
|
||||
"""Retrieves and removes the underlying gradient. The gradient is
|
||||
always ephemeral.
|
||||
"""
|
||||
if self.grad is None:
|
||||
raise RuntimeError("grad in portal has been removed or never set")
|
||||
|
||||
grad = self.grad
|
||||
self.grad = None
|
||||
return grad
|
||||
|
||||
|
||||
# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and
|
||||
# :class:`PortalCopy`.
|
||||
class Context(CopyContext):
|
||||
portal: Portal
|
||||
|
||||
|
||||
class PortalBlue(torch.autograd.Function):
|
||||
"""Hides a tensor from the autograd engine by a :class:`Portal`."""
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(
|
||||
ctx: Context,
|
||||
portal: Portal,
|
||||
# This tensor must be retrieved by portal.use_tensor().
|
||||
tensor: Tensor,
|
||||
) -> Tensor:
|
||||
ctx.portal = portal
|
||||
|
||||
phony = get_phony(tensor.device, requires_grad=False)
|
||||
return phony.detach()
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]:
|
||||
# The paired PortalOrange should keep the gradient.
|
||||
grad = ctx.portal.use_grad()
|
||||
return None, grad
|
||||
|
||||
|
||||
class PortalOrange(torch.autograd.Function):
|
||||
"""Retrieves the hidden tensor from a :class:`Portal`."""
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor:
|
||||
ctx.portal = portal
|
||||
|
||||
tensor = portal.use_tensor()
|
||||
assert tensor is not None
|
||||
|
||||
return tensor.detach()
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore
|
||||
# The paired PortalBlue will use the gradient.
|
||||
ctx.portal.put_grad(grad)
|
||||
return None, None
|
||||
|
||||
|
||||
class PortalCopy(torch.autograd.Function):
|
||||
"""Copies the hidden tensor in a :class:`Portal`. It replaces the hidden
|
||||
tensor with copied one.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(
|
||||
ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,
|
||||
) -> Tensor:
|
||||
ctx.portal = portal
|
||||
|
||||
assert portal.tensor is not None
|
||||
(portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor)
|
||||
|
||||
phony = get_phony(get_device(next_stream), requires_grad=False)
|
||||
return phony.detach()
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]:
|
||||
portal = ctx.portal
|
||||
|
||||
assert portal.grad is not None
|
||||
_, _, portal.grad = Copy.backward(ctx, portal.grad)
|
||||
|
||||
return None, None, None, None
|
439
torch/distributed/_pipeline/sync/skip/skippable.py
Normal file
439
torch/distributed/_pipeline/sync/skip/skippable.py
Normal file
@ -0,0 +1,439 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""The user interface to define skip connections."""
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..microbatch import Batch
|
||||
from .namespace import Namespace
|
||||
from .tracker import current_skip_tracker
|
||||
|
||||
__all__ = ["skippable", "stash", "pop", "verify_skippables"]
|
||||
|
||||
|
||||
Tensors = Tuple[Tensor, ...]
|
||||
TensorOrTensors = Union[Tensor, Tensors]
|
||||
|
||||
StashPop = Union["stash", "pop"]
|
||||
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
|
||||
if TYPE_CHECKING:
|
||||
SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]
|
||||
else:
|
||||
SkippableModule = nn.Module
|
||||
|
||||
T = TypeVar("T", bound="Skippable")
|
||||
|
||||
|
||||
class Skippable(nn.Module):
|
||||
"""The base class for skippable modules.
|
||||
|
||||
Do not use this class directly. Define a subclass by :func:`skippable`
|
||||
instead.
|
||||
|
||||
"""
|
||||
|
||||
module_cls: ClassVar[Type[SkippableModule]]
|
||||
stashable_names: ClassVar[FrozenSet[str]]
|
||||
poppable_names: ClassVar[FrozenSet[str]]
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
self.module = self.module_cls(*args, **kwargs) # type: ignore
|
||||
self.namespaces: Dict[str, Namespace] = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"@skippable({self.module})"
|
||||
|
||||
def namespaced(self, name: str) -> Tuple[Namespace, str]:
|
||||
"""Prepends namespace for the given skip name."""
|
||||
ns = self.namespaces.get(name)
|
||||
ns = cast(Namespace, ns)
|
||||
return (ns, name)
|
||||
|
||||
def stashable(self) -> Iterable[Tuple[Namespace, str]]:
|
||||
"""Iterates over namespaced skip names to be stashed."""
|
||||
for name in self.stashable_names:
|
||||
yield self.namespaced(name)
|
||||
|
||||
def poppable(self) -> Iterable[Tuple[Namespace, str]]:
|
||||
"""Iterates over namespaced skip names to be popped."""
|
||||
for name in self.poppable_names:
|
||||
yield self.namespaced(name)
|
||||
|
||||
def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T:
|
||||
r"""Isolates a specified subset or the whole set of skip tensors into a
|
||||
namespace. In a single sequential module, skip tensors with the same
|
||||
name are not allowed unless they are isolated by different namespaces.
|
||||
|
||||
Here's an example using the same name for skip tensors twice. Each pair
|
||||
of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1``
|
||||
and ``ns2``. There is no conflict anymore::
|
||||
|
||||
ns1 = Namespace()
|
||||
ns2 = Namespace()
|
||||
|
||||
model = nn.Sequential(
|
||||
Layer1().isolate(ns1),
|
||||
Layer1().isolate(ns2),
|
||||
Layer2(),
|
||||
Layer3().isolate(ns2),
|
||||
Layer3().isolate(ns1),
|
||||
)
|
||||
|
||||
When `only` parameter is omitted, all skip tensors are isolated. You
|
||||
can isolate a subset of skip tensors by passing `only` parameter::
|
||||
|
||||
ns_alice = Namespace()
|
||||
ns_bob = Namespace()
|
||||
|
||||
model = nn.Sequential(
|
||||
...
|
||||
StashStashPop().isolate(ns_alice, only=['alice']) \
|
||||
.isolate(ns_bob, only=['bob']),
|
||||
...
|
||||
)
|
||||
|
||||
Args:
|
||||
ns (Namespace):
|
||||
namespace for isolation
|
||||
|
||||
Keyword Args:
|
||||
only (iterable of strs):
|
||||
names of specific skip tensors to be isolated (omit this option
|
||||
to isolate all skip tensors declared in this module)
|
||||
|
||||
Returns:
|
||||
this module itself
|
||||
|
||||
"""
|
||||
names: Iterable[str]
|
||||
|
||||
if only is None:
|
||||
names = self.stashable_names | self.poppable_names
|
||||
else:
|
||||
names = set(only)
|
||||
|
||||
for name in names:
|
||||
self.namespaces[name] = ns
|
||||
|
||||
return self
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
input: TensorOrTensors,
|
||||
handle_stash: Callable[[str, Optional[Tensor]], None],
|
||||
handle_pop: Callable[[str], Optional[Tensor]],
|
||||
) -> TensorOrTensors:
|
||||
"""Dispatches :class:`stash` or :class:`pop` commands generated by the
|
||||
module's ``forward()``.
|
||||
"""
|
||||
generator = self.module(input)
|
||||
|
||||
if not isinstance(generator, Generator):
|
||||
# The underlying module returned output without any yield.
|
||||
output = generator
|
||||
return output
|
||||
|
||||
try:
|
||||
op = next(generator)
|
||||
|
||||
while True:
|
||||
if isinstance(op, stash):
|
||||
handle_stash(op.name, op.tensor)
|
||||
op = next(generator)
|
||||
continue
|
||||
|
||||
if isinstance(op, pop):
|
||||
tensor = handle_pop(op.name)
|
||||
op = generator.send(tensor)
|
||||
continue
|
||||
|
||||
raise TypeError("%r is not a command from @skippable" % op)
|
||||
|
||||
except StopIteration as stop:
|
||||
output = stop.args[0]
|
||||
return output
|
||||
|
||||
def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
|
||||
"""Performs the forward propagation. :class:`stash` or :class:`pop`
|
||||
commands will be handled by portals silently. The portals won't be
|
||||
exposed to users.
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
illegal 'stash' or 'pop' is found.
|
||||
|
||||
"""
|
||||
skip_tracker = current_skip_tracker()
|
||||
stashed_tensors: Dict[str, Optional[Tensor]] = {}
|
||||
|
||||
# Load skip tensors that might be popped.
|
||||
poppable_tensors = {}
|
||||
batch = Batch(input)
|
||||
for ns, name in self.poppable():
|
||||
try:
|
||||
poppable_tensors[name] = skip_tracker.load(batch, ns, name)
|
||||
except KeyError:
|
||||
raise RuntimeError(f"'{name}' has not been stashed")
|
||||
input = batch.tensor_or_tensors
|
||||
|
||||
# Handle skip commands.
|
||||
def handle_stash(name: str, tensor: Optional[Tensor]) -> None:
|
||||
if name not in self.stashable_names:
|
||||
raise RuntimeError(f"'{name}' has not been declared as stashable")
|
||||
stashed_tensors[name] = tensor
|
||||
|
||||
def handle_pop(name: str) -> Optional[Tensor]:
|
||||
if name not in self.poppable_names:
|
||||
raise RuntimeError(f"'{name}' has not been declared as poppable")
|
||||
return poppable_tensors.pop(name)
|
||||
|
||||
output = self.dispatch(input, handle_stash, handle_pop)
|
||||
|
||||
# All declared skips must be stashed or popped.
|
||||
not_stashed = self.stashable_names - stashed_tensors.keys()
|
||||
if not_stashed:
|
||||
comma_names = ", ".join("'%s'" % n for n in not_stashed)
|
||||
raise RuntimeError(f"{comma_names} must be stashed but have not")
|
||||
|
||||
not_popped = poppable_tensors.keys()
|
||||
if not_popped:
|
||||
comma_names = ", ".join("'%s'" % n for n in not_popped)
|
||||
raise RuntimeError(f"{comma_names} must be popped but have not")
|
||||
|
||||
# Save stashed skip tensors.
|
||||
batch = Batch(output)
|
||||
for ns, name in self.stashable():
|
||||
tensor = stashed_tensors[name]
|
||||
skip_tracker.save(batch, ns, name, tensor)
|
||||
output = batch.tensor_or_tensors
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# TODO(sublee): Move to above of Skippable class for better read flow.
|
||||
def skippable(
|
||||
stash: Iterable[str] = (), pop: Iterable[str] = (),
|
||||
) -> Callable[[Type[SkippableModule]], Type[Skippable]]:
|
||||
"""The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
|
||||
connections. Decorated modules are called "skippable". This functionality
|
||||
works perfectly fine even when the module is not wrapped by
|
||||
:class:`~torchpipe.Pipe`.
|
||||
|
||||
Each skip tensor is managed by its name. Before manipulating skip tensors,
|
||||
a skippable module must statically declare the names for skip tensors by
|
||||
`stash` and/or `pop` parameters. Skip tensors with pre-declared name can be
|
||||
stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield
|
||||
pop(name)``.
|
||||
|
||||
Here is an example with three layers. A skip tensor named "1to3" is stashed
|
||||
and popped at the first and last layer, respectively::
|
||||
|
||||
@skippable(stash=['1to3'])
|
||||
class Layer1(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash('1to3', input)
|
||||
return f1(input)
|
||||
|
||||
class Layer2(nn.Module):
|
||||
def forward(self, input):
|
||||
return f2(input)
|
||||
|
||||
@skippable(pop=['1to3'])
|
||||
class Layer3(nn.Module):
|
||||
def forward(self, input):
|
||||
skip_1to3 = yield pop('1to3')
|
||||
return f3(input) + skip_1to3
|
||||
|
||||
model = nn.Sequential(Layer1(), Layer2(), Layer3())
|
||||
|
||||
One skippable module can stash or pop multiple skip tensors::
|
||||
|
||||
@skippable(stash=['alice', 'bob'], pop=['carol'])
|
||||
class StashStashPop(nn.Module):
|
||||
def forward(self, input):
|
||||
yield stash('alice', f_alice(input))
|
||||
yield stash('bob', f_bob(input))
|
||||
carol = yield pop('carol')
|
||||
return input + carol
|
||||
|
||||
Every skip tensor must be associated with exactly one pair of `stash` and
|
||||
`pop`. :class:`~torchpipe.Pipe` checks this restriction automatically
|
||||
when wrapping a module. You can also check the restriction by
|
||||
:func:`~torchpipe.skip.verify_skippables` without
|
||||
:class:`~torchpipe.Pipe`.
|
||||
|
||||
.. note::
|
||||
|
||||
:func:`@skippable <skippable>` changes the type of the wrapped class.
|
||||
But currently (mypy v0.740), mypy could not understand class decorators
|
||||
yet (`#3135 <https://github.com/python/mypy/issues/3135>`_).
|
||||
|
||||
There are two workarounds:
|
||||
|
||||
1. Naively ignore type errors by ``# type: ignore``.
|
||||
2. Use ``skippable()()`` as a function instead of a decorator.
|
||||
|
||||
.. seealso:: :ref:`Long Skip Connections`
|
||||
|
||||
"""
|
||||
stashable_names = frozenset(stash)
|
||||
poppable_names = frozenset(pop)
|
||||
|
||||
def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]:
|
||||
name = module_cls.__name__
|
||||
bases = (Skippable,)
|
||||
attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names}
|
||||
return type(name, bases, attrs)
|
||||
|
||||
return extend_skippable
|
||||
|
||||
|
||||
class stash:
|
||||
"""The command to stash a skip tensor.
|
||||
|
||||
::
|
||||
|
||||
def forward(self, input):
|
||||
yield stash('name', input)
|
||||
return f(input)
|
||||
|
||||
Args:
|
||||
name (str): name of skip tensor
|
||||
input (torch.Tensor or None): tensor to pass to the skip connection
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("name", "tensor")
|
||||
|
||||
def __init__(self, name: str, tensor: Optional[Tensor]) -> None:
|
||||
self.name = name
|
||||
self.tensor = tensor
|
||||
|
||||
|
||||
class pop:
|
||||
"""The command to pop a skip tensor.
|
||||
|
||||
::
|
||||
|
||||
def forward(self, input):
|
||||
skip = yield pop('name')
|
||||
return f(input) + skip
|
||||
|
||||
Args:
|
||||
name (str): name of skip tensor
|
||||
|
||||
Returns:
|
||||
the skip tensor previously stashed by another layer under the same name
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ("name",)
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
|
||||
def verify_skippables(module: nn.Sequential) -> None:
|
||||
"""Verifies if the underlying skippable modules satisfy integrity.
|
||||
|
||||
Every skip tensor must have only one pair of `stash` and `pop`. If there
|
||||
are one or more unmatched pairs, it will raise :exc:`TypeError` with the
|
||||
detailed messages.
|
||||
|
||||
Here are a few failure cases. :func:`verify_skippables` will report failure
|
||||
for these cases::
|
||||
|
||||
# Layer1 stashes "1to3".
|
||||
# Layer3 pops "1to3".
|
||||
|
||||
nn.Sequential(Layer1(), Layer2())
|
||||
# └──── ?
|
||||
|
||||
nn.Sequential(Layer2(), Layer3())
|
||||
# ? ────┘
|
||||
|
||||
nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3())
|
||||
# └───────────────────┘ ^^^^^^
|
||||
|
||||
nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3())
|
||||
# ^^^^^^ └───────────────────┘
|
||||
|
||||
To use the same name for multiple skip tensors, they must be isolated by
|
||||
different namespaces. See :meth:`isolate()
|
||||
<torchpipe.skip.skippable.Skippable.isolate>`.
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
one or more pairs of `stash` and `pop` are not matched.
|
||||
|
||||
"""
|
||||
stashed: Set[Tuple[Namespace, str]] = set()
|
||||
popped: Set[Tuple[Namespace, str]] = set()
|
||||
msgs: List[str] = []
|
||||
|
||||
for layer_name, layer in module.named_children():
|
||||
if not isinstance(layer, Skippable):
|
||||
continue
|
||||
|
||||
for name in layer.stashable_names & layer.poppable_names:
|
||||
msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable"
|
||||
msgs.append(msg)
|
||||
|
||||
for ns, name in layer.stashable():
|
||||
if name in layer.poppable_names:
|
||||
continue
|
||||
|
||||
if (ns, name) in stashed:
|
||||
msg = f"'{layer_name}' redeclared '{name}' as stashable " "but not isolated by namespace"
|
||||
msgs.append(msg)
|
||||
continue
|
||||
|
||||
stashed.add((ns, name))
|
||||
|
||||
for ns, name in layer.poppable():
|
||||
if name in layer.stashable_names:
|
||||
continue
|
||||
|
||||
if (ns, name) in popped:
|
||||
msg = f"'{layer_name}' redeclared '{name}' as poppable " "but not isolated by namespace"
|
||||
msgs.append(msg)
|
||||
continue
|
||||
|
||||
if (ns, name) not in stashed:
|
||||
msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed"
|
||||
msgs.append(msg)
|
||||
continue
|
||||
|
||||
popped.add((ns, name))
|
||||
|
||||
for (_, name) in stashed - popped:
|
||||
msg = f"no module declared '{name}' as poppable but stashed"
|
||||
msgs.append(msg)
|
||||
|
||||
if msgs:
|
||||
raise TypeError(
|
||||
"one or more pairs of stash and pop do not match:\n\n%s" "" % "\n".join("* %s" % x for x in msgs)
|
||||
)
|
177
torch/distributed/_pipeline/sync/skip/tracker.py
Normal file
177
torch/distributed/_pipeline/sync/skip/tracker.py
Normal file
@ -0,0 +1,177 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Tracks skip tensors on a thread."""
|
||||
from contextlib import contextmanager
|
||||
import threading
|
||||
from typing import Dict, Generator, List, Optional, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from ..checkpoint import is_checkpointing
|
||||
from ..dependency import fork, join
|
||||
from ..microbatch import Batch
|
||||
from ..stream import AbstractStream
|
||||
from .layout import SkipLayout
|
||||
from .namespace import Namespace
|
||||
from .portal import Portal
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
class SkipTracker:
|
||||
"""Tracks saved skip tensors.
|
||||
|
||||
It will update the given micro-batch in place. This is because when it
|
||||
manipulates the underlying skip tensors, the current micro-batch also has
|
||||
to be connected with the skip tensors.
|
||||
|
||||
One thread has one skip tracker. Call :func:`current_skip_tracker` to get
|
||||
the skip tracker on the current thread.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {}
|
||||
|
||||
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
|
||||
self.tensors[(ns, name)] = tensor
|
||||
|
||||
def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
|
||||
return self.tensors.pop((ns, name))
|
||||
|
||||
def copy(
|
||||
self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
|
||||
) -> None:
|
||||
raise TypeError("copy is not supported for non-portal skip tensors")
|
||||
|
||||
|
||||
class SkipTrackerThroughPotals(SkipTracker):
|
||||
"""Tracks saved skip tensors through portals. The skip tensors will be
|
||||
hidden in portals so that the autograd engine does not need to track them.
|
||||
|
||||
This tracker is only used when the training or evaluating module is wrapped
|
||||
with :class:`torchpipe.Pipe`.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, skip_layout: SkipLayout) -> None:
|
||||
super().__init__()
|
||||
self.skip_layout = skip_layout
|
||||
self.portals: Dict[Tuple[Namespace, str], Portal] = {}
|
||||
|
||||
def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None:
|
||||
"""Saves the stashed skip tensor in a portal. The portal is then
|
||||
connected to the given micro-batch with :class:`Join`.
|
||||
"""
|
||||
if not self.skip_layout.requires_copy(ns, name):
|
||||
super().save(batch, ns, name, tensor)
|
||||
return
|
||||
|
||||
# See [Tensor Life of Portal] at Portal.put_tensor() to understand the
|
||||
# below tensor_life values. Here are the selected events which retrieve
|
||||
# the tensor in portal:
|
||||
#
|
||||
# 1. [x] blue()
|
||||
# ...
|
||||
# 6. [x] PortalOrange.forward
|
||||
# ...
|
||||
# 8. [x] PortalOrange.forward (recomputed)
|
||||
# ...
|
||||
# 11. [x] blue() (recomputed)
|
||||
#
|
||||
if (ns, name) not in self.portals:
|
||||
if is_checkpointing():
|
||||
# Under checkpointing, the tensor used by the first
|
||||
# PortalOrange should be alive in the portal. This tensor will
|
||||
# be used again by the second PortalOrange during the
|
||||
# recomputation.
|
||||
tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)]
|
||||
else:
|
||||
tensor_life = 2 # Delete at [6. PortalOrange.forward]
|
||||
|
||||
portal = Portal(tensor, tensor_life)
|
||||
self.portals[(ns, name)] = portal
|
||||
|
||||
else:
|
||||
# Under recomputation, the portal already exists.
|
||||
portal = self.portals[(ns, name)]
|
||||
|
||||
# The existing tensor life already became 0. It should be reset as
|
||||
# 1 to delete the tensor after the second PortalBlue immediately.
|
||||
tensor_life = 1 # Delete at [11. blue() (recomputed)]
|
||||
|
||||
portal.put_tensor(tensor, tensor_life)
|
||||
|
||||
phony = portal.blue()
|
||||
batch[0] = join(batch[0], phony)
|
||||
|
||||
def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]:
|
||||
"""Loads a skip tensor from the corresponding portal to pop. The given
|
||||
micro-batch is connected to the portal with :class:`Fork`.
|
||||
"""
|
||||
if not self.skip_layout.requires_copy(ns, name):
|
||||
tensor = super().load(batch, ns, name)
|
||||
return tensor
|
||||
|
||||
portal = self.portals[(ns, name)]
|
||||
batch[0], phony = fork(batch[0])
|
||||
tensor = portal.orange(phony)
|
||||
return tensor
|
||||
|
||||
def copy(
|
||||
self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str,
|
||||
) -> None:
|
||||
"""Copies the skip tensor in the corresponding portal. The given
|
||||
micro-batch and the portal will be tied with :class:`Fork` and
|
||||
:class:`Join`.
|
||||
"""
|
||||
assert self.skip_layout.requires_copy(ns, name)
|
||||
|
||||
batch[0], phony = fork(batch[0])
|
||||
|
||||
portal = self.portals[(ns, name)]
|
||||
phony = portal.copy(prev_stream, next_stream, phony)
|
||||
|
||||
batch[0] = join(batch[0], phony)
|
||||
|
||||
|
||||
class ThreadLocal(threading.local):
|
||||
def __init__(self) -> None:
|
||||
self.skip_tracker: Optional[SkipTracker] = None
|
||||
|
||||
|
||||
thread_local = ThreadLocal()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]:
|
||||
"""Registers the given skip tracker on the current thread within a
|
||||
context::
|
||||
|
||||
with use_skip_tracker(my_skip_tracker):
|
||||
...
|
||||
|
||||
"""
|
||||
orig = thread_local.skip_tracker
|
||||
|
||||
thread_local.skip_tracker = skip_tracker
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
thread_local.skip_tracker = orig
|
||||
|
||||
|
||||
def current_skip_tracker() -> SkipTracker:
|
||||
"""Gets the skip tracker on the current thread."""
|
||||
skip_tracker = thread_local.skip_tracker
|
||||
|
||||
if skip_tracker is None:
|
||||
skip_tracker = SkipTracker()
|
||||
thread_local.skip_tracker = skip_tracker
|
||||
|
||||
return skip_tracker
|
117
torch/distributed/_pipeline/sync/stream.py
Normal file
117
torch/distributed/_pipeline/sync/stream.py
Normal file
@ -0,0 +1,117 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Utilities for eliminating boilerplate code to handle abstract streams with
|
||||
CPU device.
|
||||
"""
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, List, Union, cast
|
||||
|
||||
import torch
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
class CPUStreamType:
|
||||
pass
|
||||
|
||||
|
||||
# The placeholder on place of streams for the CPU device instead of CUDA.
|
||||
CPUStream = CPUStreamType()
|
||||
|
||||
# It represents both CUDA streams and the CPU stream.
|
||||
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]
|
||||
|
||||
|
||||
def new_stream(device: torch.device) -> AbstractStream:
|
||||
"""Creates a new stream for either CPU or CUDA device."""
|
||||
if device.type != "cuda":
|
||||
return CPUStream
|
||||
return torch.cuda.Stream(device)
|
||||
|
||||
|
||||
def current_stream(device: torch.device) -> AbstractStream:
|
||||
""":func:`torch.cuda.current_stream` for either CPU or CUDA device."""
|
||||
if device.type != "cuda":
|
||||
return CPUStream
|
||||
return torch.cuda.current_stream(device)
|
||||
|
||||
|
||||
def default_stream(device: torch.device) -> AbstractStream:
|
||||
""":func:`torch.cuda.default_stream` for either CPU or CUDA device."""
|
||||
if device.type != "cuda":
|
||||
return CPUStream
|
||||
return torch.cuda.default_stream(device)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_device(device: torch.device) -> Generator[None, None, None]:
|
||||
""":func:`torch.cuda.device` for either CPU or CUDA device."""
|
||||
if device.type != "cuda":
|
||||
yield
|
||||
return
|
||||
|
||||
with torch.cuda.device(device):
|
||||
yield
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
|
||||
""":func:`torch.cuda.stream` for either CPU or CUDA stream."""
|
||||
if not is_cuda(stream):
|
||||
yield
|
||||
return
|
||||
|
||||
with torch.cuda.stream(as_cuda(stream)):
|
||||
yield
|
||||
|
||||
|
||||
def get_device(stream: AbstractStream) -> torch.device:
|
||||
"""Gets the device from CPU or CUDA stream."""
|
||||
if is_cuda(stream):
|
||||
return as_cuda(stream).device
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
|
||||
""":meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
|
||||
makes the source stream wait until the target stream completes work queued.
|
||||
"""
|
||||
if is_cuda(target):
|
||||
if is_cuda(source):
|
||||
# A CUDA stream waits another CUDA stream.
|
||||
as_cuda(source).wait_stream(as_cuda(target))
|
||||
else:
|
||||
# CPU waits a CUDA stream.
|
||||
as_cuda(target).synchronize()
|
||||
|
||||
# If the target is CPU, synchronization is not required.
|
||||
|
||||
|
||||
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
|
||||
""":meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
|
||||
if is_cuda(stream):
|
||||
# NOTE(sublee): record_stream() on a shifted view tensor throws
|
||||
# RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
|
||||
# protect the tensor against unexpected reallocation, here we use a
|
||||
# temporal tensor associated with the same storage without shifting as
|
||||
# a workaround.
|
||||
#
|
||||
# Issue: https://github.com/pytorch/pytorch/issues/27366
|
||||
#
|
||||
tensor = tensor.new_empty([0]).set_(tensor.storage())
|
||||
|
||||
tensor.record_stream(as_cuda(stream))
|
||||
|
||||
|
||||
def is_cuda(stream: AbstractStream) -> bool:
|
||||
"""Returns ``True`` if the given stream is a valid CUDA stream."""
|
||||
return stream is not CPUStream
|
||||
|
||||
|
||||
def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
|
||||
"""Casts the given stream as :class:`torch.cuda.Stream`."""
|
||||
return cast(torch.cuda.Stream, stream)
|
151
torch/distributed/_pipeline/sync/worker.py
Normal file
151
torch/distributed/_pipeline/sync/worker.py
Normal file
@ -0,0 +1,151 @@
|
||||
# Copyright 2019 Kakao Brain
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Multithreading in pipeline parallelism."""
|
||||
from contextlib import contextmanager
|
||||
from queue import Queue
|
||||
import sys
|
||||
from threading import Thread
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
import torch
|
||||
|
||||
from .microbatch import Batch
|
||||
from .stream import AbstractStream, use_device, use_stream
|
||||
|
||||
__all__: List[str] = []
|
||||
|
||||
|
||||
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
|
||||
|
||||
# Queue is generic only in stubs.
|
||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
|
||||
if TYPE_CHECKING:
|
||||
InQueue = Queue[Optional["Task"]]
|
||||
OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
|
||||
else:
|
||||
InQueue = Queue
|
||||
OutQueue = Queue
|
||||
|
||||
|
||||
class Task:
|
||||
"""A task represents how to compute a micro-batch on a partition.
|
||||
|
||||
It consists of two parts: :meth:`compute` and :meth:`finalize`.
|
||||
:meth:`compute` should be executed in worker threads concurrently.
|
||||
:meth:`finalize` should be executed after when worker threads complete to
|
||||
execute :meth:`compute`.
|
||||
|
||||
:meth:`compute` might be boosted by worker threads. Because it produces
|
||||
several CUDA API calls by user code. In PyTorch, parallel CUDA API calls
|
||||
are not serialized through GIL. So more than one CUDA API call can be
|
||||
produced at the same time.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]],
|
||||
) -> None:
|
||||
self.stream = stream
|
||||
self._compute = compute
|
||||
self._finalize = finalize
|
||||
self._grad_enabled = torch.is_grad_enabled()
|
||||
|
||||
def compute(self) -> Batch:
|
||||
with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
|
||||
return self._compute()
|
||||
|
||||
def finalize(self, batch: Batch) -> None:
|
||||
if self._finalize is None:
|
||||
return
|
||||
with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled):
|
||||
self._finalize(batch)
|
||||
|
||||
|
||||
def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None:
|
||||
"""The main loop of a worker thread."""
|
||||
with use_device(device):
|
||||
while True:
|
||||
task = in_queue.get()
|
||||
|
||||
if task is None:
|
||||
break
|
||||
|
||||
try:
|
||||
batch = task.compute()
|
||||
except Exception:
|
||||
exc_info = cast(ExcInfo, sys.exc_info())
|
||||
out_queue.put((False, exc_info))
|
||||
continue
|
||||
|
||||
out_queue.put((True, (task, batch)))
|
||||
|
||||
done = (False, None)
|
||||
out_queue.put(done)
|
||||
|
||||
|
||||
def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]:
|
||||
"""Spawns worker threads. A worker thread is bound to a device."""
|
||||
in_queues: List[InQueue] = []
|
||||
out_queues: List[OutQueue] = []
|
||||
|
||||
# Spawn workers.
|
||||
workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {}
|
||||
|
||||
def normalize_device(device: torch.device) -> torch.device:
|
||||
if device.type == "cuda" and device.index is None:
|
||||
return torch.device("cuda", index=torch.cuda.current_device())
|
||||
|
||||
if device.type == "cpu" and device.index is not None:
|
||||
return torch.device("cpu")
|
||||
|
||||
return device
|
||||
|
||||
for device in devices:
|
||||
device = normalize_device(device)
|
||||
|
||||
try:
|
||||
in_queue, out_queue = workers[device]
|
||||
except KeyError:
|
||||
in_queue = Queue()
|
||||
out_queue = Queue()
|
||||
workers[device] = (in_queue, out_queue)
|
||||
|
||||
t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,)
|
||||
t.start()
|
||||
|
||||
in_queues.append(in_queue)
|
||||
out_queues.append(out_queue)
|
||||
|
||||
return (in_queues, out_queues)
|
||||
|
||||
|
||||
def join_workers(in_queues: List[InQueue], out_queues: List[OutQueue]) -> None:
|
||||
# Close workers.
|
||||
for in_queue in set(in_queues):
|
||||
in_queue.put(None)
|
||||
|
||||
# Join running workers.
|
||||
running = set(out_queues)
|
||||
while running:
|
||||
out_queue = running.pop()
|
||||
ok, payload = out_queue.get()
|
||||
|
||||
done = (False, None)
|
||||
if (ok, payload) == done:
|
||||
continue
|
||||
|
||||
running.add(out_queue)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]:
|
||||
try:
|
||||
(in_queues, out_queues) = create_workers(devices)
|
||||
yield (in_queues, out_queues)
|
||||
finally:
|
||||
join_workers(in_queues, out_queues)
|
Reference in New Issue
Block a user