diff --git a/LICENSE b/LICENSE index 4167b929cc74..9cb8cbef5a9f 100644 --- a/LICENSE +++ b/LICENSE @@ -16,23 +16,26 @@ Copyright (c) 2016-present, Facebook Inc. All rights reserved. All contributions by Facebook: Copyright (c) 2016 Facebook Inc. - + All contributions by Google: Copyright (c) 2015 Google Inc. All rights reserved. - + All contributions by Yangqing Jia: Copyright (c) 2015 Yangqing Jia All rights reserved. - + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + All contributions from Caffe: Copyright(c) 2013, 2014, 2015, the respective contributors All rights reserved. - + All other contributions: Copyright(c) 2015, 2016 the respective contributors All rights reserved. - + Caffe2 uses a copyright model similar to Caffe: each contributor holds copyright over their contributions to Caffe2. The project versioning records all such contribution and copyright details. If a contributor wants to further diff --git a/NOTICE b/NOTICE index a346cb891713..020beaea4c46 100644 --- a/NOTICE +++ b/NOTICE @@ -22,6 +22,9 @@ All contributions by Yangqing Jia: Copyright (c) 2015 Yangqing Jia All rights reserved. +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + All other contributions: Copyright(c) 2015, 2016 the respective contributors All rights reserved. diff --git a/test/distributed/_pipeline/sync/LICENSE b/test/distributed/_pipeline/sync/LICENSE new file mode 100644 index 000000000000..e52be240fdc9 --- /dev/null +++ b/test/distributed/_pipeline/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright 2019-2020 Kakao Brain + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/test/distributed/_pipeline/sync/__init__.py b/test/distributed/_pipeline/sync/__init__.py new file mode 100644 index 000000000000..94cd5bcb415e --- /dev/null +++ b/test/distributed/_pipeline/sync/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+# tests/__init__.py lets pytest import the application without custom sys.path or PYTHONPATH.
+# See also: https://docs.pytest.org/en/latest/goodpractices.html
diff --git a/test/distributed/_pipeline/sync/conftest.py b/test/distributed/_pipeline/sync/conftest.py
new file mode 100644
index 000000000000..315431d0b644
--- /dev/null
+++ b/test/distributed/_pipeline/sync/conftest.py
@@ -0,0 +1,37 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+import pytest
+import torch
+
+
+@pytest.fixture(autouse=True)
+def manual_seed_zero():
+    torch.manual_seed(0)
+
+
+@pytest.fixture(scope="session")
+def cuda_sleep():
+    # Warm-up CUDA.
+    torch.empty(1, device="cuda")
+
+    # From test/test_cuda.py in PyTorch.
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    torch.cuda._sleep(1000000)
+    end.record()
+    end.synchronize()
+    cycles_per_ms = 1000000 / start.elapsed_time(end)
+
+    def cuda_sleep(seconds):
+        torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
+
+    return cuda_sleep
+
+
+def pytest_report_header():
+    return f"torch: {torch.__version__}"
diff --git a/test/distributed/_pipeline/sync/skip/__init__.py b/test/distributed/_pipeline/sync/skip/__init__.py
new file mode 100644
index 000000000000..ab03724cafbf
--- /dev/null
+++ b/test/distributed/_pipeline/sync/skip/__init__.py
@@ -0,0 +1,6 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/test/distributed/_pipeline/sync/skip/test_api.py b/test/distributed/_pipeline/sync/skip/test_api.py
new file mode 100644
index 000000000000..fd2176e799bc
--- /dev/null
+++ b/test/distributed/_pipeline/sync/skip/test_api.py
@@ -0,0 +1,45 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+import copy
+
+from torch import nn
+
+from torch.distributed._pipeline.sync.skip import Namespace, skippable, stash
+
+
+def test_namespace_difference():
+    ns1 = Namespace()
+    ns2 = Namespace()
+    assert ns1 != ns2
+
+
+def test_namespace_copy():
+    ns = Namespace()
+    assert copy.copy(ns) == ns
+    assert copy.copy(ns) is not ns
+
+
+def test_skippable_repr():
+    @skippable(stash=["hello"])
+    class Hello(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(1, 1, 1)
+
+        def forward(self, x):
+            yield stash("hello", x)
+            return self.conv(x)  # noqa
+
+    m = Hello()
+    assert (
+        repr(m)
+        == """
+@skippable(Hello(
+  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
+))
+""".strip()
+    )
diff --git a/test/distributed/_pipeline/sync/skip/test_gpipe.py b/test/distributed/_pipeline/sync/skip/test_gpipe.py
new file mode 100644
index 000000000000..293a263439bc
--- /dev/null
+++ b/test/distributed/_pipeline/sync/skip/test_gpipe.py
@@ -0,0 +1,106 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+import pytest +import torch +from torch import nn + +from torch.distributed._pipeline.sync import Pipe +from torch.distributed._pipeline.sync.skip import pop, skippable, stash +from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_1to3(balance, checkpoint): + if torch.cuda.device_count() < len(balance): + pytest.skip("at least %d cuda devices required" % len(balance)) + + @skippable(stash=["1to3"]) + class Layer1(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + yield stash("1to3", input) + output = self.conv(input) + return output # noqa + + class Layer2(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + output = self.conv(input) + return output + + @skippable(pop=["1to3"]) + class Layer3(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def forward(self, input): + skip_1to3 = yield pop("1to3") + output = self.conv(input) + skip_1to3 + return output + + model = nn.Sequential(Layer1(), Layer2(), Layer3()) + model = Pipe(model, balance, chunks=3, checkpoint=checkpoint) + + in_device = model.devices[0] + out_device = model.devices[-1] + + input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) + output = model(input) + loss = output.mean() + loss.backward() + + assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1) + assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device)) + + +def test_none_skip(): + @skippable(stash=["none"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("none", None) + return input # noqa + + @skippable(pop=["none"]) + class Pop(nn.Module): + def forward(self, input): + none = yield pop("none") + assert none is None + return input + + model = nn.Sequential(Stash(), Pop()) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=5) + + input = torch.rand(10, requires_grad=True) + output = model(input) + + def assert_grad_fn_is_not_portal(grad_fn, visited=None): + if visited is None: + visited = set() + if grad_fn in visited or grad_fn is None: + return + + assert not isinstance(grad_fn, PortalBlue._backward_cls) + assert not isinstance(grad_fn, PortalCopy._backward_cls) + assert not isinstance(grad_fn, PortalOrange._backward_cls) + + visited.add(grad_fn) + for next_grad_fn, _ in grad_fn.next_functions: + assert_grad_fn_is_not_portal(next_grad_fn, visited) + + assert_grad_fn_is_not_portal(output.grad_fn) + + output.sum().backward() + assert input.grad.mean().item() == 1 diff --git a/test/distributed/_pipeline/sync/skip/test_inspect_skip_layout.py b/test/distributed/_pipeline/sync/skip/test_inspect_skip_layout.py new file mode 100644 index 000000000000..b47a60f4e889 --- /dev/null +++ b/test/distributed/_pipeline/sync/skip/test_inspect_skip_layout.py @@ -0,0 +1,111 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+from torch import nn + +from torch.distributed._pipeline.sync.skip import Namespace, pop, skippable, stash +from torch.distributed._pipeline.sync.skip.layout import inspect_skip_layout + + +class Pass(nn.Module): + def forward(self, input): + return input + + +@skippable(stash=["foo"]) +class StashFoo(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input # noqa + + +@skippable(pop=["foo"]) +class PopFoo(nn.Module): + def forward(self, input): + foo = yield stash("foo") + return input + foo + + +@skippable(stash=["bar"]) +class StashBar(nn.Module): + def forward(self, input): + yield stash("bar", input) + return input # noqa + + +@skippable(pop=["bar"]) +class PopBar(nn.Module): + def forward(self, input): + bar = yield pop("bar") + return input + bar + + +def test_no_skippables(): + p1 = nn.Sequential(Pass()) + p2 = nn.Sequential(Pass()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], []] + + +def test_inner_partition(): + p1 = nn.Sequential(StashFoo(), PopFoo()) + p2 = nn.Sequential(Pass()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], []] + + +def test_adjoining_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(PopFoo()) + + layout = inspect_skip_layout([p1, p2]) + policy = [list(layout.copy_policy(i)) for i in range(2)] + + assert policy == [[], [(0, None, "foo")]] + + +def test_far_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(Pass()) + p3 = nn.Sequential(PopFoo()) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + assert policy == [[], [], [(0, None, "foo")]] + + +def test_pop_2_from_different_partitions(): + p1 = nn.Sequential(StashFoo()) + p2 = nn.Sequential(StashBar()) + p3 = nn.Sequential(PopBar(), PopFoo()) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. + assert policy == [[], [], [(0, None, "foo"), (1, None, "bar")]] + + +def test_namespace(): + ns1 = Namespace() + ns2 = Namespace() + + p1 = nn.Sequential(StashFoo().isolate(ns1)) + p2 = nn.Sequential(StashFoo().isolate(ns2)) + p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1)) + + layout = inspect_skip_layout([p1, p2, p3]) + policy = [list(layout.copy_policy(i)) for i in range(3)] + + # p3 pops 'bar' before 'foo', but the plan is sorted by source partition index. + assert policy == [[], [], [(0, ns1, "foo"), (1, ns2, "foo")]] diff --git a/test/distributed/_pipeline/sync/skip/test_leak.py b/test/distributed/_pipeline/sync/skip/test_leak.py new file mode 100644 index 000000000000..89e39aa9cedb --- /dev/null +++ b/test/distributed/_pipeline/sync/skip/test_leak.py @@ -0,0 +1,126 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+import pytest +import torch +from torch import nn + +from torch.distributed._pipeline.sync import Pipe, is_checkpointing, is_recomputing +from torch.distributed._pipeline.sync.skip import pop, skippable, stash +from torch.distributed._pipeline.sync.skip.tracker import current_skip_tracker + + +@skippable(stash=["skip"]) +class Stash(nn.Module): + def forward(self, input): + yield stash("skip", input) + return input # noqa + + +@skippable(pop=["skip"]) +class Pop(nn.Module): + def forward(self, input): + skip = yield pop("skip") + return input + skip + + +@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) +@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) +def test_delete_portal_tensor(train, checkpoint): + # Without checkpointing: + # +- Stash --+ +--- Pop ----+ - - - layers + # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function + # +----------+ +------------+ + # + # With checkpointing: + # +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+ + # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | + # +----------+ +------------+ +------------+ +----------+ + + def portal_tensor_life_is(tensor_life, skip_tracker=None): + if skip_tracker is None: + skip_tracker = current_skip_tracker() + + # Get the current portal. + portal = list(skip_tracker.portals.values())[0] + + if tensor_life == 0: + return portal.tensor_life == 0 and portal.tensor is None + else: + return portal.tensor_life == tensor_life and portal.tensor is not None + + # Check the portal tensor after 'Stash'. + stash_ = Stash() + + @stash_.register_forward_hook + def check_portal_tensor_after_stash(*_): + if is_checkpointing(): + assert portal_tensor_life_is(2) + elif is_recomputing(): + assert portal_tensor_life_is(0) + else: + assert portal_tensor_life_is(1) + + pop_ = Pop() + + @pop_.register_forward_hook + def check_portal_tensor_after_pop(*_): + if is_checkpointing(): + assert portal_tensor_life_is(1) + elif is_recomputing(): + assert portal_tensor_life_is(0) + else: + assert portal_tensor_life_is(0) + + class NoPortalTensorAtBackward(nn.Module): + class F(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.skip_tracker = current_skip_tracker() + return input.detach() + + @staticmethod + def backward(ctx, grad): + assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker) + return grad + + def forward(self, input): + return self.F.apply(input) + + model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) + model = Pipe(model, balance=[2, 1], devices=["cpu", "cpu"], chunks=2, checkpoint=checkpoint) + + input = torch.rand(10, requires_grad=True) + + if train: + model.train() + output = model(input) + output.norm().backward() + else: + model.eval() + with torch.no_grad(): + model(input) + + +@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) +def test_no_portal_without_pipe(train, monkeypatch): + def deny(*args, **kwargs): + raise AssertionError("tried to create Portal without Pipe") + + monkeypatch.setattr("torch.distributed._pipeline.sync.skip.portal.Portal.__init__", deny) + + model = nn.Sequential(Stash(), Pop()) + + input = torch.rand(10, requires_grad=True) + + if train: + model.train() + output = model(input) + output.norm().backward() + else: + model.eval() + with torch.no_grad(): + model(input) diff --git a/test/distributed/_pipeline/sync/skip/test_portal.py b/test/distributed/_pipeline/sync/skip/test_portal.py new file mode 100644 index 000000000000..452192ee3da9 --- /dev/null +++ 
b/test/distributed/_pipeline/sync/skip/test_portal.py @@ -0,0 +1,155 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch + +from torch.distributed._pipeline.sync.dependency import fork, join +from torch.distributed._pipeline.sync.skip.portal import Portal +from torch.distributed._pipeline.sync.stream import default_stream + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_copy_returns_on_next_device(): + portal = Portal(torch.rand(1), tensor_life=1) + + prev_stream = default_stream(torch.device("cpu")) + next_stream = default_stream(torch.device("cuda")) + + phony = torch.zeros(0, requires_grad=True) + assert phony.device.type == "cpu" + + phony = portal.copy(prev_stream, next_stream, phony) + assert phony.device.type == "cuda" + + +def test_blue_orange(): + tensor1 = torch.rand(1, requires_grad=True) + tensor2 = torch.rand(1, requires_grad=True) + + # Same with: output = tensor1*2 + tensor2 + # + # +----------------------+ + # | | + # tensor2 -- PortalBlue -+ +- PortalOrange -+ + # | | | + # tensor1 ------------ Join -- Fork --- Mul --- Add -- output + # + main = tensor1 + portal = Portal(tensor2, tensor_life=2) + phony = portal.blue() + main = join(main, phony) + main, phony = fork(main) + sub = portal.orange(phony) + output = main * 2 + sub + + output.backward() + + assert torch.allclose(tensor1.grad, torch.tensor([2.0])) + assert torch.allclose(tensor2.grad, torch.tensor([1.0])) + + +def test_blue_orange_not_requires_grad(): + tensor1 = torch.rand(1, requires_grad=True) + tensor2 = torch.rand(1) + + # Same with: output = tensor1*2 + tensor2 + # + # +----------------------+ + # | | + # tensor2 -- PortalBlue -+ +- PortalOrange -+ + # | | | + # tensor1 ------------ Join -- Fork --- Mul --- Add -- output + # + main = tensor1 + portal = Portal(tensor2, tensor_life=2) + phony = portal.blue() + main = join(main, phony) + main, phony = fork(main) + sub = portal.orange(phony) + output = main * 2 + sub + + output.backward() + + assert torch.allclose(tensor1.grad, torch.tensor([2.0])) + assert tensor2.grad is None + + +def test_use_grad(): + tensor = torch.rand(1, requires_grad=True) + portal = Portal(tensor, tensor_life=1) + + portal.put_grad(tensor) + assert portal.use_grad() is tensor + + # Gradient in a portal is ephemeral. + with pytest.raises(RuntimeError): + portal.use_grad() + + +class TestTensorLife: + @pytest.fixture + def new_portal(self): + portal = None + + def new_portal(tensor_life): + nonlocal portal + tensor = torch.rand(1, requires_grad=True) + portal = Portal(tensor, tensor_life) + return portal, tensor + + yield new_portal + + # A test using this fixture must exhaust the tensor in the portal. 
+ with pytest.raises(RuntimeError): + portal.check_tensor_life() + assert portal.tensor is None + + def test_tensor_life_0(self, new_portal): + portal, tensor = new_portal(0) + assert portal.tensor is None + + def test_tensor_life_1(self, new_portal): + portal, tensor = new_portal(1) + assert portal.tensor is tensor + + portal.blue() + + def test_tensor_life_2(self, new_portal): + portal, tensor = new_portal(2) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + def test_tensor_life_3(self, new_portal): + portal, tensor = new_portal(3) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + def test_tensor_life_4(self, new_portal): + portal, tensor = new_portal(4) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + portal.blue() + + def test_tensor_life_3_plus_1(self, new_portal): + portal, tensor = new_portal(3) + assert portal.tensor is tensor + + phony = portal.blue() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + assert portal.orange(phony).data_ptr() == tensor.data_ptr() + + another_tensor = torch.rand(1, requires_grad=True) + portal.put_tensor(another_tensor, tensor_life=1) + portal.blue() diff --git a/test/distributed/_pipeline/sync/skip/test_stash_pop.py b/test/distributed/_pipeline/sync/skip/test_stash_pop.py new file mode 100644 index 000000000000..7a5b16a39cff --- /dev/null +++ b/test/distributed/_pipeline/sync/skip/test_stash_pop.py @@ -0,0 +1,136 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+import pytest +import torch +from torch import nn + +from torch.distributed._pipeline.sync.skip import pop, skippable, stash +from torch.distributed._pipeline.sync.skip.tracker import SkipTracker, use_skip_tracker + + +@pytest.fixture(autouse=True) +def skip_tracker(): + skip_tracker = SkipTracker() + with use_skip_tracker(skip_tracker): + yield skip_tracker + + +def test_stash(skip_tracker): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + l1 = Stash() + + assert len(skip_tracker.tensors) == 0 + + with use_skip_tracker(skip_tracker): + l1(torch.tensor(42)) + + assert len(skip_tracker.tensors) == 1 + + +def test_pop(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo # noqa + + l1 = Stash() + l2 = Pop() + + output = l2(l1(torch.tensor(42))) + + assert output.item() == 42 + + +def test_declare_but_not_use(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + return input * 2 + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + return input * 3 + + l1 = Stash() + l2 = Pop() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + with pytest.raises(RuntimeError): + l2(torch.tensor(42)) + + +def test_stash_not_declared(): + @skippable() + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + l1 = Stash() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + +def test_pop_not_declared(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + @skippable() + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo # noqa + + l1 = Stash() + l2 = Pop() + + latent = l1(torch.tensor(42)) + + with pytest.raises(RuntimeError): + l2(latent) + + +def test_pop_not_stashed(): + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + yield pop("foo") + + l1 = Pop() + + with pytest.raises(RuntimeError): + l1(torch.tensor(42)) + + +def test_stash_none(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", None) + return input * 2 # noqa + + l1 = Stash() + l1(torch.tensor(42)) diff --git a/test/distributed/_pipeline/sync/skip/test_tracker.py b/test/distributed/_pipeline/sync/skip/test_tracker.py new file mode 100644 index 000000000000..d2e1a8135c37 --- /dev/null +++ b/test/distributed/_pipeline/sync/skip/test_tracker.py @@ -0,0 +1,127 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+from queue import Queue +import threading + +import pytest +import torch +from torch import nn + +from torch.distributed._pipeline.sync.checkpoint import enable_checkpointing, enable_recomputing +from torch.distributed._pipeline.sync.microbatch import Batch +from torch.distributed._pipeline.sync.skip import pop, skippable, stash +from torch.distributed._pipeline.sync.skip.layout import SkipLayout +from torch.distributed._pipeline.sync.skip.tracker import SkipTracker, SkipTrackerThroughPotals, current_skip_tracker + + +def test_default_skip_tracker(): + q = Queue() + + def f(): + q.put(current_skip_tracker()) + + t = threading.Thread(target=f) + t.start() + t.join() + + skip_tracker = q.get() + + assert type(skip_tracker) is SkipTracker + assert type(skip_tracker) is not SkipTrackerThroughPotals + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_default_skip_tracker_by_data_parallel(): + @skippable(stash=["foo"]) + class Stash(nn.Module): + def forward(self, input): + yield stash("foo", input) + return input * 2 # noqa + + @skippable(pop=["foo"]) + class Pop(nn.Module): + def forward(self, input): + foo = yield pop("foo") + return foo + + model = nn.Sequential(Stash(), Pop()) + model = nn.DataParallel(model, device_ids=[0, 0], output_device=0) + + input = torch.rand(10, device=0) + output = model(input) + + assert torch.allclose(output, input) + + +def test_reuse_portal(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + a = torch.tensor([2.0]) + b = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "test", a) + portal = skip_tracker.portals[(None, "test")] + + skip_tracker.save(batch, None, "test", b) + assert portal is skip_tracker.portals[(None, "test")] + + +def test_no_copy_no_portal(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "copy"): (0, 1), (None, "not_copy"): (0, 0)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + a = torch.tensor([2.0]) + b = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "copy", a) + skip_tracker.save(batch, None, "not_copy", b) + + assert (None, "copy") in skip_tracker.portals + assert (None, "copy") not in skip_tracker.tensors + assert (None, "not_copy") in skip_tracker.tensors + assert (None, "not_copy") not in skip_tracker.portals + + +def test_tensor_life_without_checkpointing(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + tensor = torch.tensor([2.0]) + + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 1 + + skip_tracker.load(batch, None, "test") + assert skip_tracker.portals[(None, "test")].tensor_life == 0 + + +def test_tensor_life_with_checkpointing(): + skip_layout = SkipLayout(num_partitions=2, skip_routes={(None, "test"): (0, 1)}) + skip_tracker = SkipTrackerThroughPotals(skip_layout) + + batch = Batch(torch.tensor([1.0])) + tensor = torch.tensor([2.0]) + + with enable_checkpointing(): + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 2 + + with enable_checkpointing(): + skip_tracker.load(batch, None, "test") + assert skip_tracker.portals[(None, "test")].tensor_life == 1 + + with enable_recomputing(): + skip_tracker.load(batch, None, "test") + assert 
skip_tracker.portals[(None, "test")].tensor_life == 0 + + with enable_recomputing(): + skip_tracker.save(batch, None, "test", tensor) + assert skip_tracker.portals[(None, "test")].tensor_life == 0 diff --git a/test/distributed/_pipeline/sync/skip/test_verify_skippables.py b/test/distributed/_pipeline/sync/skip/test_verify_skippables.py new file mode 100644 index 000000000000..a94f9756b976 --- /dev/null +++ b/test/distributed/_pipeline/sync/skip/test_verify_skippables.py @@ -0,0 +1,152 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +from torch import nn + +from torch.distributed._pipeline.sync.skip import Namespace, skippable, verify_skippables + + +def test_matching(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + verify_skippables(nn.Sequential(Layer1(), Layer2())) + + +def test_stash_not_pop(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "no module declared 'foo' as poppable but stashed" in str(e.value) + + +def test_pop_unknown(): + @skippable(pop=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "'0' declared 'foo' as poppable but it was not stashed" in str(e.value) + + +def test_stash_again(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer3(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + assert "'1' redeclared 'foo' as stashable" in str(e.value) + + +def test_pop_again(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer3(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + assert "'2' redeclared 'foo' as poppable" in str(e.value) + + +def test_stash_pop_together_different_names(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"], stash=["bar"]) + class Layer2(nn.Module): + pass + + @skippable(pop=["bar"]) + class Layer3(nn.Module): + pass + + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3())) + + +def test_stash_pop_together_same_name(): + @skippable(stash=["foo"], pop=["foo"]) + class Layer1(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1())) + assert "'0' declared 'foo' both as stashable and as poppable" in str(e.value) + + +def test_double_stash_pop(): + @skippable(stash=["foo"]) + class Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer3(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer4(nn.Module): + pass + + with pytest.raises(TypeError) as e: + verify_skippables(nn.Sequential(Layer1(), Layer2(), Layer3(), Layer4())) + assert "'2' redeclared 'foo' as stashable" in str(e.value) + assert "'3' redeclared 'foo' as poppable" in str(e.value) + + +def test_double_stash_pop_but_isolated(): + @skippable(stash=["foo"]) + class 
Layer1(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer2(nn.Module): + pass + + @skippable(stash=["foo"]) + class Layer3(nn.Module): + pass + + @skippable(pop=["foo"]) + class Layer4(nn.Module): + pass + + ns1 = Namespace() + ns2 = Namespace() + + verify_skippables( + nn.Sequential(Layer1().isolate(ns1), Layer2().isolate(ns1), Layer3().isolate(ns2), Layer4().isolate(ns2),) + ) diff --git a/test/distributed/_pipeline/sync/test_balance.py b/test/distributed/_pipeline/sync/test_balance.py new file mode 100644 index 000000000000..59d7d8d1227d --- /dev/null +++ b/test/distributed/_pipeline/sync/test_balance.py @@ -0,0 +1,225 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import time + +import pytest +import torch +from torch import nn + +from torch.distributed._pipeline.sync.balance import balance_by_size, balance_by_time, blockpartition +from torch.distributed._pipeline.sync.balance.profile import layerwise_sandbox + +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") + +devices = ["cpu"] +if torch.cuda.is_available(): + devices.append("cuda") + + +def test_blockpartition(): + assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [[1, 2, 3, 4], [5, 6]] + + +def test_blockpartition_zeros(): + assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]] + + +def test_blockpartition_non_positive_partitions(): + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=0) + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=-1) + + +def test_blockpartition_short_sequence(): + with pytest.raises(ValueError): + blockpartition.solve([], partitions=1) + with pytest.raises(ValueError): + blockpartition.solve([42], partitions=2) + + +@pytest.mark.parametrize("device", devices) +@pytest.mark.skip(reason="Flaky due to time.sleep()") +def test_balance_by_time(device): + class Delay(nn.Module): + def __init__(self, seconds): + super().__init__() + self.seconds = seconds + + def forward(self, x): + time.sleep(self.seconds) + return x + + model = nn.Sequential(*[Delay(i / 10) for i in [1, 2, 3, 4, 5, 6]]) + sample = torch.rand(1) + balance = balance_by_time(2, model, sample, device=device) + assert balance == [4, 2] + + +def test_balance_by_time_loop_resets_input(): + # nn.Flatten was introduced at PyTorch 1.2.0. 
+ class Flatten(nn.Module): + def forward(self, x): + return x.flatten(1) + + model = nn.Sequential(nn.Conv2d(3, 2, 1), Flatten(), nn.Linear(128, 10)) + sample = torch.rand(10, 3, 8, 8) + balance = balance_by_time(2, model, sample, device="cpu") + assert balance == [1, 2] + + +@skip_if_no_cuda +def test_balance_by_size_latent(): + class Expand(nn.Module): + def __init__(self, times): + super().__init__() + self.times = times + + def forward(self, x): + for i in range(self.times): + x = x + torch.rand_like(x, requires_grad=True) + return x + + sample = torch.rand(10, 100, 100) + + model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]]) + balance = balance_by_size(2, model, sample) + assert balance == [4, 2] + + model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]]) + balance = balance_by_size(2, model, sample) + assert balance == [2, 4] + + +@skip_if_no_cuda +def test_balance_by_size_param(): + model = nn.Sequential(*[nn.Linear(i + 1, i + 2) for i in range(6)]) + sample = torch.rand(7, 1) + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [4, 2] + + model = nn.Sequential(*[nn.Linear(i + 2, i + 1) for i in reversed(range(6))]) + sample = torch.rand(1, 7) + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [2, 4] + + +@skip_if_no_cuda +def test_balance_by_size_param_scale(): + class Tradeoff(nn.Module): + def __init__(self, param_size, latent_size): + super().__init__() + self.fc = nn.Linear(param_size, param_size) + self.latent_size = latent_size + + def forward(self, x): + for i in range(self.latent_size): + x = x + torch.rand_like(x, requires_grad=True) + return x + + model = nn.Sequential( + Tradeoff(param_size=1, latent_size=6), + Tradeoff(param_size=2, latent_size=5), + Tradeoff(param_size=3, latent_size=4), + Tradeoff(param_size=4, latent_size=3), + Tradeoff(param_size=5, latent_size=2), + Tradeoff(param_size=6, latent_size=1), + ) + + sample = torch.rand(1, requires_grad=True) + + balance = balance_by_size(2, model, sample, param_scale=0) + assert balance == [2, 4] + + balance = balance_by_size(2, model, sample, param_scale=100) + assert balance == [4, 2] + + +@pytest.mark.parametrize("device", devices) +def test_layerwise_sandbox(device): + model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) + model.eval() + + for layer in layerwise_sandbox(model, torch.device(device)): + assert layer.training + assert all(p.device.type == device for p in layer.parameters()) + + assert all(not l.training for l in model) + assert all(p.device.type == "cpu" for p in model.parameters()) + + +@pytest.mark.parametrize("device", devices) +def test_sandbox_during_profiling(device): + model = nn.Sequential(nn.BatchNorm2d(3)) + + before = {k: v.clone() for k, v in model.state_dict().items()} + + sample = torch.rand(1, 3, 10, 10) + balance_by_time(1, model, sample, device=device) + + after = model.state_dict() + + assert before.keys() == after.keys() + for key, value in before.items(): + assert torch.allclose(after[key], value), key + + +def test_not_training(): + class AssertTraining(nn.Module): + def forward(self, x): + assert self.training + return x + + model = nn.Sequential(AssertTraining()) + + model.eval() + assert not model.training + + sample = torch.rand(1) + balance_by_time(1, model, sample, device="cpu") + + assert not model.training + + +def test_balance_by_time_tuple(): + class Twin(nn.Module): + def forward(self, x): + return x, x.detach() + + class Add(nn.Module): + def forward(self, a_b): + a, b = a_b + 
return a + b
+
+    model = nn.Sequential(Twin(), Add())
+    sample = torch.rand(1, requires_grad=True)
+    balance_by_time(1, model, sample, device="cpu")
+
+
+@skip_if_no_cuda
+def test_balance_by_size_tuple():
+    class Twin(nn.Module):
+        def forward(self, x):
+            return x, x.detach()
+
+    class Add(nn.Module):
+        def forward(self, a_b):
+            a, b = a_b
+            return a + b
+
+    model = nn.Sequential(Twin(), Add())
+    sample = torch.rand(1, requires_grad=True)
+    balance_by_size(1, model, sample)
+
+
+def test_already_has_grad():
+    model = nn.Sequential(nn.Conv2d(3, 3, 1))
+    sample = torch.rand(1, 3, 32, 32)
+    model(sample).norm().backward()
+
+    with pytest.raises(ValueError, match="some parameter already has gradient"):
+        balance_by_time(1, model, sample, device="cpu")
diff --git a/test/distributed/_pipeline/sync/test_bugs.py b/test/distributed/_pipeline/sync/test_bugs.py
new file mode 100644
index 000000000000..c3152745b5bb
--- /dev/null
+++ b/test/distributed/_pipeline/sync/test_bugs.py
@@ -0,0 +1,128 @@
+# Copyright 2019 Kakao Brain
+#
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+import pytest
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from torch.distributed._pipeline.sync import Pipe
+
+
+def test_python_autograd_function():
+    # A Python autograd function might fail with this error:
+    #
+    #   RuntimeError: Returning Variables sharing storage with other Variables
+    #   that require grad is not supported in Python functions. Please submit a
+    #   feature request if you hit this error.
+    #
+    # It doesn't look like an essential restriction, but it happens on the
+    # current PyTorch version. To avoid it, we should detach the tensor before
+    # returning it from identity autograd functions such as Wait, Fork, and Join.
+    #
+    class Identity(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, input):
+            return input
+
+        @staticmethod
+        def backward(ctx, grad):
+            return grad
+
+    class M(nn.Module):
+        def forward(self, input):
+            return Identity.apply(input)
+
+    model = nn.Sequential(M(), M())
+    model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always")
+
+    x = torch.rand(42)
+    y = model(x)
+    assert torch.allclose(x, y)
+
+
+def test_exception_no_hang():
+    # In v0.0.2, once a failed partition received a normal (non-closing)
+    # message for the next micro-batch, a hang occurred. The reason was
+    # that a failed partition didn't call in_queue.task_done() on a normal
+    # message. So the former partition was blocked at out_queue.join() for the
+    # micro-batch after the next one.
+    class ExpectedException(Exception):
+        pass
+
+    class Pass(nn.Module):
+        def forward(self, x):
+            return x
+
+    class Raise(nn.Module):
+        def forward(self, x):
+            raise ExpectedException()
+
+    model = nn.Sequential(Pass(), Pass(), Raise())
+    model = Pipe(model, [1, 1, 1], devices=["cpu", "cpu", "cpu"], chunks=3)
+
+    with pytest.raises(ExpectedException):
+        model(torch.rand(3))
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
+def test_tuple_wait(cuda_sleep):
+    # In v0.0.3, Wait is applied to only the first tensor of a micro-batch.
+    # Under this behavior, if checkpointing is disabled, there's a possibility
+    # that gradient accumulations on other tensors are not synchronized
+    # properly to the copy stream.
+ class Sleep(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x.detach() + + @staticmethod + def backward(ctx, grad): + with torch.cuda.device(grad.device): + cuda_sleep(0.05) + return grad + + class Layer1(nn.Module): + def forward(self, pair): + a, b = pair + return a * 1, b * 2, b * 3 + + class Layer2(nn.Module): + def forward(self, triple): + a, b, c = triple + b = Sleep.apply(b) + return a + b + c + + model = nn.Sequential(Layer1(), Layer2()) + model = Pipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint="never") + + a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) + b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) + + y = model((a, b)) + y.norm().backward() + + torch.cuda.synchronize(0) + torch.cuda.synchronize(1) + + assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) + + +def test_parallel_randoms(): + class Dropouts(nn.Module): + def forward(self, x): + for _ in range(100): + x = F.dropout(x, p=0.001) + return x + + model = nn.Sequential(Dropouts(), Dropouts()) + + x = torch.rand(10, 10, requires_grad=True) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"], chunks=10, checkpoint="always") + y = model(x) + y.norm().backward() + + assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() diff --git a/test/distributed/_pipeline/sync/test_checkpoint.py b/test/distributed/_pipeline/sync/test_checkpoint.py new file mode 100644 index 000000000000..edbbe78a6e02 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_checkpoint.py @@ -0,0 +1,158 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from functools import partial + +import pytest +import torch +from torch import nn +import torch.cuda + +from torch.distributed._pipeline.sync.checkpoint import Checkpointing, checkpoint, is_checkpointing, is_recomputing +from torch.distributed._pipeline.sync.dependency import fork, join +from torch.distributed._pipeline.sync.microbatch import Batch + +devices = ["cpu"] +if torch.cuda.is_available(): + devices.append("cuda") + + +@pytest.mark.parametrize("device", devices) +def test_serial_checkpoints(device): + # Copied from https://github.com/pytorch/pytorch/pull/18568. + timeline = [] + + class Log(torch.autograd.Function): + @staticmethod + def forward(ctx, name, x): + ctx.name = name + timeline.append(f"{name}:forward") + return x.detach() + + @staticmethod + def backward(ctx, grad_output): + name = ctx.name + timeline.append(f"{name}:backward") + return None, grad_output + + a = torch.rand(1, device=device, requires_grad=True) + b = torch.rand(1, device=device, requires_grad=True) + + # Increase the next function sequence number. 
+ _ = a + 1 + 2 + 3 + 4 + 5 + + a = checkpoint(partial(Log.apply, "a"), a) + + a, phony = fork(a) + b = join(b, phony) + + b = checkpoint(partial(Log.apply, "b"), b) + + c = torch.cat((a, b)) + + out = c.sum() + + # +--> {a} --Checkpoint(Log)--> {a} + # {out} --Sum--> {c} --Cat ^-----------------------------+ + # +--> {b} --Checkpoint(Log)--> {b} --First--> {b} + out.backward() + + assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"] + # |----------------------| |-----------------------| |-----------------------| + # forward pass Checkpoint(Log[b]) Checkpoint(Log[a]) + + +def test_not_requires_grad(): + x = Batch(torch.rand(1, requires_grad=False)) + assert not x[0].requires_grad + + def f(x): + return x * 2 + + chk = Checkpointing(f, x) + x = chk.checkpoint() + assert x[0].requires_grad + + chk.recompute(x) + assert x[0].requires_grad + + x.tensor.backward() + + +def test_not_requires_grad_with_parameter(): + x = torch.rand(1, requires_grad=False) + a = torch.rand(1, requires_grad=True) + + def f(x): + return x * a + + y = checkpoint(f, x) + y.backward() + + assert a.grad is not None + + +@pytest.mark.parametrize("device", devices) +def test_random_in_checkpoint(device): + dropout = nn.Dropout(p=0.5) + + torch.manual_seed(0) + x = torch.randn(3, 3, device=device, requires_grad=True) + y = dropout(x) + y.norm().backward() + + torch.manual_seed(0) + chk_x = torch.randn(3, 3, device=device, requires_grad=True) + chk_y = checkpoint(dropout, chk_x) + chk_y.norm().backward() + + assert torch.allclose(x.grad, chk_x.grad) + + +def test_detect_checkpointing_recomputing(): + logs = [] + + class Detect(nn.Module): + def forward(self, input): + logs.append((is_checkpointing(), is_recomputing())) + return input + + model = Detect() + input = torch.rand(1, requires_grad=True) + + output = checkpoint(model, input) + output.backward() + + assert logs == [(True, False), (False, True)] + + +def test_detect_checkpointing_recomputing_without_checkpoint(): + logs = [] + + class Detect(nn.Module): + def forward(self, input): + logs.append((is_checkpointing(), is_recomputing())) + return input + + model = Detect() + input = torch.rand(1, requires_grad=True) + + output = model(input) + output.backward() + + assert logs == [(False, False)] + + +def test_non_grad_output(): + class ForkNonGrad(nn.Module): + def forward(self, input): + return (input * 2, torch.rand(1)) + + model = ForkNonGrad() + input = torch.rand(1, requires_grad=True) + + output = checkpoint(model, input) + output[0].backward() diff --git a/test/distributed/_pipeline/sync/test_copy.py b/test/distributed/_pipeline/sync/test_copy.py new file mode 100644 index 000000000000..1655fabb59f4 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_copy.py @@ -0,0 +1,68 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+import pytest +import torch + +from torch.distributed._pipeline.sync.copy import Copy, Wait +from torch.distributed._pipeline.sync.stream import CPUStream, current_stream, get_device, is_cuda, new_stream, use_stream + +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") + + +def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None): + device = get_device(prev_stream) + + with use_stream(prev_stream): + if is_cuda(prev_stream): + cuda_sleep(0.5) + x = torch.ones(100, device=device, requires_grad=True) + + (y,) = Copy.apply(prev_stream, next_stream, x) + (y,) = Wait.apply(prev_stream, next_stream, x) + + with use_stream(next_stream): + assert torch.allclose(y.sum(), torch.tensor(100.0, device=device)) + y.norm().backward() + with use_stream(prev_stream): + assert torch.allclose(x.grad.sum(), torch.tensor(10.0, device=device)) + + +def test_copy_wait_cpu_cpu(): + prev_stream = CPUStream + next_stream = CPUStream + _test_copy_wait(prev_stream, next_stream) + + +@skip_if_no_cuda +def test_copy_wait_cpu_cuda(cuda_sleep): + prev_stream = CPUStream + next_stream = current_stream(torch.device("cuda")) + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +@skip_if_no_cuda +def test_copy_wait_cuda_cpu(cuda_sleep): + prev_stream = current_stream(torch.device("cuda")) + next_stream = CPUStream + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +@skip_if_no_cuda +def test_copy_wait_cuda_cuda(cuda_sleep): + prev_stream = current_stream(torch.device("cuda")) + next_stream = new_stream(torch.device("cuda")) + _test_copy_wait(prev_stream, next_stream, cuda_sleep) + + +def test_wait_multiple_tensors(): + a = torch.rand(1, requires_grad=True) + b = torch.rand(1, requires_grad=True) + + a, b = Wait.apply(CPUStream, CPUStream, a, b) + + assert a.grad_fn is b.grad_fn + assert a.grad_fn.__class__ is Wait._backward_cls diff --git a/test/distributed/_pipeline/sync/test_deferred_batch_norm.py b/test/distributed/_pipeline/sync/test_deferred_batch_norm.py new file mode 100644 index 000000000000..1691e9fcc252 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_deferred_batch_norm.py @@ -0,0 +1,192 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from copy import deepcopy +from itertools import chain + +import pytest +import torch +from torch import nn, optim + +from torch.distributed._pipeline.sync.batchnorm import DeferredBatchNorm + +CHUNKS = 4 + + +def tilt_dist(input): + # Tilt variance by channel. + rgb = input.transpose(0, 1) + rgb[0] *= 1 + rgb[1] *= 10 + rgb[2] *= 100 + + # Tilt mean by single batch. 
+ for i, single in enumerate(input): + single += 2 ** i + + return input + + +def chunked_forward(model, input, chunks=CHUNKS): + output_chunks = [] + + for chunk in input.chunk(chunks): + output_chunks.append(model(chunk)) + + return torch.cat(output_chunks) + + +@pytest.mark.parametrize("chunks", [1, 4]) +@pytest.mark.parametrize("input_requires_grad", [True, False]) +def test_transparency(chunks, input_requires_grad): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks) + + input1 = torch.rand(16, 3, 224, 224) + input1 = tilt_dist(input1) + input2 = input1.clone() + input1.requires_grad = input_requires_grad + input2.requires_grad = input_requires_grad + + output1 = chunked_forward(bn, input1, chunks=chunks) + output2 = chunked_forward(dbn, input2, chunks=chunks) + + assert torch.allclose(output1, output2, atol=1e-4) + + output1.mean().backward() + output2.mean().backward() + + assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4) + + if input_requires_grad: + assert input1.grad is not None + assert input2.grad is not None + assert torch.allclose(input1.grad, input2.grad, atol=1e-4) + + +@pytest.mark.parametrize("momentum", [0.1, None]) +def test_running_stats(momentum): + bn = nn.BatchNorm2d(3, momentum=momentum) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + bn(input) + chunked_forward(dbn, input) + + assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4) + assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4) + + +def test_convert_deferred_batch_norm(): + bn = nn.BatchNorm2d(3, track_running_stats=False) + bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS) + assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False + + dbn = DeferredBatchNorm(3, chunks=CHUNKS) + dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS) + assert dbn is dbn_again + + dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1) + assert dbn is not dbn_again # because of different chunks + + +def test_eval(): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + bn(input) + chunked_forward(dbn, input) + + bn.eval() + dbn.eval() + + assert torch.allclose(bn(input), dbn(input), atol=1e-4) + + +def test_optimize(): + bn = nn.BatchNorm2d(3) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0) + + for i in range(5): + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + # train + y = bn(input) + a = y.sum() + a.backward() + + y = chunked_forward(dbn, input) + b = y.sum() + b.backward() + + opt.step() + + # eval + bn.eval() + dbn.eval() + + with torch.no_grad(): + assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10 ** i)) + + +def test_conv_bn(): + bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3)) + dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + + opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1) + + # 1st step + a = bn(input) + b = chunked_forward(dbn, input) + + # Outputs are different. (per-mini-batch vs. 
per-micro-batch) + assert not torch.allclose(a, b) + + a.sum().backward() + b.sum().backward() + opt.step() + opt.zero_grad() + + # Conv layers are also trained differently because of their different outputs. + assert not torch.allclose(bn[0].weight, dbn[0].weight) + + # But BNs track identical running stats. + assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) + assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) + + # 2nd step + a = bn(input) + b = chunked_forward(dbn, input) + a.sum().backward() + b.sum().backward() + + # BNs can't track identical running stats due to the different conv layers. + assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4) + assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e3) + + +def test_input_requiring_grad(): + dbn = DeferredBatchNorm(3, chunks=CHUNKS) + + input = torch.rand(16, 3, 224, 224) + input = tilt_dist(input) + input.requires_grad = True + + chunked_forward(dbn, input) + + assert not dbn.sum.requires_grad + assert dbn.sum.grad_fn is None diff --git a/test/distributed/_pipeline/sync/test_dependency.py b/test/distributed/_pipeline/sync/test_dependency.py new file mode 100644 index 000000000000..8dd30ca4c2f6 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_dependency.py @@ -0,0 +1,144 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import weakref + +import pytest +import torch + +from torch.distributed._pipeline.sync.dependency import Fork, Join, fork, join + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def test_fork_join(): + logs = [] + + class Log(torch.autograd.Function): + @staticmethod + def forward(ctx, number, tensor): + ctx.number = number + return tensor.detach() + + @staticmethod + def backward(ctx, grad): + logs.append(ctx.number) + return None, grad + + a = torch.rand(1, device="cpu", requires_grad=True) + b = torch.rand(1, device="cuda", requires_grad=True) + + a = Log.apply(1, a) + + a, phony = fork(a) + b = join(a, phony) + + b = Log.apply(2, b) + b = b.to("cpu") + + (a + b).backward() + + assert logs == [2, 1] + + +def test_fork_join_enable_grad(): + x = torch.rand(1, requires_grad=True) + + with torch.enable_grad(): + x2, p = fork(x) + + assert p.requires_grad + assert x2 is not x + x = x2 + + assert x.requires_grad + assert p.requires_grad + assert x.grad_fn.__class__ is Fork._backward_cls + assert p.grad_fn.__class__ is Fork._backward_cls + + with torch.enable_grad(): + x2 = join(x, p) + + assert x2 is not x + x = x2 + + assert x.requires_grad + assert x.grad_fn.__class__ is Join._backward_cls + + +def test_fork_join_no_grad(monkeypatch): + def do_not_apply(*args): + raise AssertionError("Function.apply called") + + monkeypatch.setattr("torch.autograd.Function.apply", do_not_apply) + + x = torch.rand(1, requires_grad=True) + + with torch.no_grad(): + x2, p = fork(x) + + assert not p.requires_grad + assert x2 is x + x = x2 + + with torch.no_grad(): + x2 = join(x, p) + + assert x2 is x + x = x2 + + +def test_fork_leak(): + leak = None + + class F(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad): + nonlocal leak + leak = weakref.ref(ctx) + return grad + + x = torch.rand(1, requires_grad=True) + x = F.apply(x) + x, phony = fork(x) + x = 
join(x, phony) + + x.backward() + del x, phony + + assert leak() is None + + +def test_join_when_fork_not_requires_grad(): + x = torch.rand(2, 1) + a, b = x.chunk(2) + + assert not a.requires_grad + a, p = fork(a) + assert not a.requires_grad + assert not p.requires_grad + + assert not b.requires_grad + b = join(b, p) + assert not b.requires_grad + + +def test_join_when_fork_requires_grad(): + x = torch.rand(2, 1) + a, b = x.chunk(2) + + a.requires_grad_() + assert a.requires_grad + a, p = fork(a) + assert a.requires_grad + assert p.requires_grad + + assert not b.requires_grad + b = join(b, p) + assert b.requires_grad diff --git a/test/distributed/_pipeline/sync/test_inplace.py b/test/distributed/_pipeline/sync/test_inplace.py new file mode 100644 index 000000000000..185ad8706054 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_inplace.py @@ -0,0 +1,71 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +from torch import nn + +from torch.distributed._pipeline.sync import Pipe + + +def test_inplace_on_requires_grad(): + model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"], checkpoint="always") + + x = torch.rand(1) + y = model(x) + + message = r"a leaf Variable that requires grad .* used in an in-place operation." + with pytest.raises(RuntimeError, match=message): + y.backward() + + +@pytest.mark.xfail(strict=True) +def test_inplace_on_not_requires_grad(): + # In-place operation on a tensor not requiring grad doesn't cause a + # RuntimeError. Currently, we cannot detect this case. + model = nn.Sequential(nn.ReLU(inplace=True)) + model = Pipe(model, [1], devices=["cpu"], checkpoint="always") + + x = torch.rand(1) + y = model(x) + del model + + message = r"a leaf Variable that requires grad .* used in an in-place operation." + with pytest.raises(RuntimeError, match=message): + y.backward() + + +@pytest.mark.xfail(strict=True) +def test_inplace_incorrect_grad(): + class M(nn.Module): + def forward(self, foo_bar): + # 'foo' requires grad but 'bar' does not. In-place operation on + # 'bar' won't cause a RuntimeError. + foo, bar = foo_bar + + # add_(1) is not idempotent, in contrast to relu_(). If it is + # executed multiple times, it will accumulates each difference onto + # 'bar'. + bar.add_(1) + + # 'bar' is still captured by checkpointing. 'foo' will get + # incorrect grad. + return foo * bar + + model = nn.Sequential(M()) + model = Pipe(model, [1], devices=["cpu"], checkpoint="always") + + foo = torch.tensor([1.0], requires_grad=True) + bar = torch.tensor([1.0]) + + output = model((foo, bar)) + del model + output.backward() + + # The gradient of 'foo' should be 2, but it is 3 actually because + # bar.add_(1) was executed twice due to checkpointing. + assert foo.grad.item() == 2.0 diff --git a/test/distributed/_pipeline/sync/test_microbatch.py b/test/distributed/_pipeline/sync/test_microbatch.py new file mode 100644 index 000000000000..08ee7b546e80 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_microbatch.py @@ -0,0 +1,138 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
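+#
+# A minimal sketch of the micro-batch semantics exercised below, inferred only
+# from the assertions in this file: Batch wraps either a single tensor
+# ("atomic") or a tuple of tensors, scatter() splits a mini-batch into `chunks`
+# micro-batches along dim 0, and gather() concatenates the results back.
+#
+#   batch = Batch(torch.zeros(4, 1))              # atomic batch
+#   micro = scatter(torch.zeros(4, 1), chunks=2)  # two Batches of size (2, 1)
+#   whole = gather(micro)                         # a (4, 1) tensor again
+#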
+import pytest +import torch +import torch.cuda + +from torch.distributed._pipeline.sync.microbatch import Batch, check, gather, scatter + + +def test_batch_atomic(): + x = torch.tensor(42) + b = Batch(x) + + assert b.atomic + + assert b.tensor is x + with pytest.raises(AttributeError): + b.tensors + + assert list(b) == [x] + assert len(b) == 1 + assert b[0] is x + + +def test_batch_non_atomic(): + x, y = torch.tensor(42), torch.tensor(21) + b = Batch((x, y)) + + assert not b.atomic + + with pytest.raises(AttributeError): + b.tensor + assert b.tensors == (x, y) + + assert list(b) == [x, y] + assert len(b) == 2 + assert b[0] is x + assert b[1] is y + + +def test_batch_call(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + def f(x): + return x + + assert a.call(f).atomic + assert not b.call(f).atomic + + +def test_batch_setitem_by_index(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + a[0] = torch.tensor(0) + b[0] = torch.tensor(0) + + assert a.atomic + assert a[0].item() == 0 + + assert not b.atomic + assert len(b) == 2 + assert b[0].item() == 0 + assert b[1].item() == 21 + + +def test_batch_setitem_by_slice(): + a = Batch(torch.tensor(42)) + b = Batch((torch.tensor(42), torch.tensor(21))) + + a[:] = (torch.tensor(0),) + b[:] = (torch.tensor(0),) + + assert a.atomic + assert a[0].item() == 0 + + assert not b.atomic + assert len(b) == 1 + assert b[0].item() == 0 + + +def test_check(): + check(torch.tensor(42)) + check((torch.tensor(4), torch.tensor(2))) + + with pytest.raises(TypeError): + check(42) + + with pytest.raises(TypeError): + check("str") + + with pytest.raises(TypeError): + check((torch.tensor(4), 2)) + + +def test_gather_tensors(): + a = torch.zeros(1, 1) + b = torch.zeros(1, 1) + + ab = gather([Batch(a), Batch(b)]) + + assert ab.size() == (2, 1) + + +def test_gather_tuples(): + a = (torch.zeros(1, 1), torch.zeros(2, 2)) + b = (torch.zeros(1, 1), torch.zeros(2, 2)) + + ab = gather([Batch(a), Batch(b)]) + + assert isinstance(ab, tuple) + assert ab[0].size() == (2, 1) + assert ab[1].size() == (4, 2) + + +def test_scatter_tensor(): + ab = torch.zeros(2, 1) + + a, b = scatter(ab, chunks=2) + + assert a.tensor.size() == (1, 1) + assert b.tensor.size() == (1, 1) + + +def test_scatter_tuple(): + ab = (torch.zeros(2, 1), torch.zeros(4, 2)) + + a, b = scatter(ab, chunks=2) + + assert a.tensors[0].size() == (1, 1) + assert b.tensors[0].size() == (1, 1) + assert a.tensors[1].size() == (2, 2) + assert b.tensors[1].size() == (2, 2) diff --git a/test/distributed/_pipeline/sync/test_phony.py b/test/distributed/_pipeline/sync/test_phony.py new file mode 100644 index 000000000000..b54a28b9585e --- /dev/null +++ b/test/distributed/_pipeline/sync/test_phony.py @@ -0,0 +1,50 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
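+#
+# Context for the tests below (inferred from their assertions): a "phony" is a
+# zero-sized tensor used as a placeholder to express dependencies in autograd
+# (see test_dependency.py) without carrying data. get_phony() caches one
+# instance per (device, requires_grad) pair:
+#
+#   p = get_phony(torch.device("cpu"), requires_grad=True)
+#   p.size()                                                  # torch.Size([0])
+#   get_phony(torch.device("cpu"), requires_grad=True) is p   # True (cached)
+#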
+import torch + +from torch.distributed._pipeline.sync.phony import get_phony + + +def test_phony_size(): + p = get_phony(torch.device("cpu"), requires_grad=False) + assert p.size() == (0,) + + +def test_phony_requires_grad(): + p1 = get_phony(torch.device("cpu"), requires_grad=True) + p2 = get_phony(torch.device("cpu"), requires_grad=False) + assert p1.requires_grad + assert not p2.requires_grad + + +def test_cached_phony(): + p1 = get_phony(torch.device("cpu"), requires_grad=True) + p2 = get_phony(torch.device("cpu"), requires_grad=True) + assert p1 is p2 + + p3 = get_phony(torch.device("cpu"), requires_grad=False) + p4 = get_phony(torch.device("cpu"), requires_grad=False) + assert p3 is p4 + + assert p1 is not p3 + + +def test_phony_in_autograd_function(): + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() + + x = torch.rand(1, requires_grad=True) + + p1 = Phonify.apply(x) + p2 = get_phony(torch.device("cpu"), requires_grad=True) + + assert p1 is not p2 + assert p1.grad_fn is not None + assert p2.grad_fn is None diff --git a/test/distributed/_pipeline/sync/test_pipe.py b/test/distributed/_pipeline/sync/test_pipe.py new file mode 100644 index 000000000000..d7915733adc0 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_pipe.py @@ -0,0 +1,608 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from collections import OrderedDict +from copy import deepcopy +import time + +import pytest +import torch +from torch import nn + +from torch.distributed._pipeline.sync import Pipe + +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") + + +def test_parameters(): + model = nn.Sequential(nn.Linear(1, 1)) + pipe = Pipe(model, balance=[1], devices=["cpu"], chunks=1) + assert list(pipe.parameters()) != [] + + +def test_public_attrs(): + class MyString: + def __init__(self, value): + self.value = value + + def __str__(self): + return self.value + + model = nn.Sequential(nn.Linear(1, 1)) + pipe = Pipe(model, balance=(1,), devices=("cpu",), chunks=42.000, checkpoint=MyString("always")) + + assert pipe.balance == [1] + assert pipe.devices == [torch.device("cpu")] + assert pipe.chunks == 42 + assert isinstance(pipe.chunks, int) + assert pipe.checkpoint == "always" + assert isinstance(pipe.checkpoint, str) + + +@pytest.mark.parametrize("balance", [[2], [1, 1]]) +def test_sequential_like(balance): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model, balance, devices=["cpu", "cpu"]) + + assert len(model) == 2 + assert list(model) == [a, b] + + assert model[0] is a + assert model[1] is b + with pytest.raises(IndexError): + _ = model[2] + + assert model[-1] is b + assert model[-2] is a + + +def test_balance_wrong_length(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + + with pytest.raises(ValueError): + Pipe(model, balance=[1]) + + with pytest.raises(ValueError): + Pipe(model, balance=[3]) + + +def test_balance_less_than_1(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + + with pytest.raises(ValueError): + Pipe(model, balance=[0, 2]) + + with pytest.raises(ValueError): + Pipe(model, balance=[-1, 3]) + + +def test_chunks_less_than_1(): + model = 
nn.Sequential(nn.Linear(1, 1)) + + with pytest.raises(ValueError): + Pipe(model, balance=[1], devices=["cpu"], chunks=0) + + with pytest.raises(ValueError): + Pipe(model, balance=[1], devices=["cpu"], chunks=-1) + + +def test_too_few_devices(): + model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)) + + with pytest.raises(IndexError): + # len(balance) > len(devices) + model = Pipe(model, balance=[1, 1, 1, 1], devices=["cpu"]) + + +def test_batch_size_indivisible(): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=4) + + with pytest.warns(None) as record: + model(torch.rand(7, 1)) + + # Indivisible batch size is legal. + assert not record + + +def test_batch_size_small(): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=4) + + with pytest.warns(None) as record: + model(torch.rand(2, 1)) + + # Batch size smaller than chunks is legal. + assert not record + + +def test_checkpoint_mode(): + def count_grad_fn(grad_fn, name, visited=None): + if visited is None: + visited = set() + if grad_fn in visited: + return 0 + visited.add(grad_fn) + + if grad_fn is None: + return 0 + if grad_fn.__class__.__name__ == name: + return 1 + + counter = 0 + for next_grad_fn, _ in grad_fn.next_functions: + counter += count_grad_fn(next_grad_fn, name, visited=visited) + return counter + + model = nn.Sequential(nn.Linear(1, 1)) + input = torch.rand(2, 1) + + always = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="always") + except_last = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="except_last") + never = Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="never") + + always_output = always(input) + except_last_output = except_last(input) + never_output = never(input) + + assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2 + assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1 + assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0 + + +def test_checkpoint_mode_invalid(): + model = nn.Sequential(nn.Linear(1, 1)) + + with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"): + Pipe(model, balance=[1], devices=["cpu"], chunks=2, checkpoint="INVALID_CHECKPOINT") + + +def test_checkpoint_mode_when_chunks_1(): + model = nn.Sequential(nn.Linear(1, 1)) + + # All checkpoint modes are fine. 
+ Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="except_last") + Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="always") + Pipe(model, balance=[1], devices=["cpu"], chunks=1, checkpoint="never") + + +def test_checkpoint_eval(): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) + input = torch.rand(2, 1) + + def find_grad_fn(grad_fn, name): + if grad_fn is None: + return False + if grad_fn.__class__.__name__ == name: + return True + for next_grad_fn, _ in grad_fn.next_functions: + if find_grad_fn(next_grad_fn, name): + return True + return False + + model.train() + train_output = model(input) + assert find_grad_fn(train_output.grad_fn, "CheckpointBackward") + assert find_grad_fn(train_output.grad_fn, "RecomputeBackward") + + model.eval() + eval_output = model(input) + assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward") + assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") + + +def test_checkpoint_non_float_input(): + class ForkNonFloat(nn.Module): + def forward(self, input): + return (input * 2, torch.tensor([False])) + + class JoinNonFloat(nn.Module): + def forward(self, input): + return input[0] * 2 + + model = nn.Sequential(ForkNonFloat(), JoinNonFloat()) + model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=1, checkpoint="always") + + input = torch.rand(1, requires_grad=True) + output = model(input) + output.backward() + + +def test_no_grad(): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) + input = torch.rand(2, 1) + + latent = None + + def hook(module, input, output): + _ = module + _ = input + + nonlocal latent + latent = output + + partition = model.partitions[0] + partition.register_forward_hook(hook) + + with torch.no_grad(): + model(input) + + assert latent.grad_fn is None + + +def test_exception(): + class ExpectedException(Exception): + pass + + class Raise(nn.Module): + def forward(self, *_): + raise ExpectedException() + + model = nn.Sequential(Raise()) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=1) + + with pytest.raises(ExpectedException): + model(torch.rand(1)) + + +def test_exception_early_stop_asap(): + """Even the first partitions have finished to process, the partition before + the failed partition should be killed as soon as possible. + """ + + class ExpectedException(Exception): + pass + + class Pass(nn.Module): + def forward(self, x): + return x + + counter = 0 + + class Counter(nn.Module): + def forward(self, x): + time.sleep(0.1) + + nonlocal counter + counter += 1 + + return x + + class Raise(nn.Module): + def forward(self, x): + raise ExpectedException() + + model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) + model = Pipe(model, [1, 1, 1, 1], devices=["cpu", "cpu", "cpu", "cpu"], chunks=3) + + with pytest.raises(ExpectedException): + model(torch.rand(3)) + + # If the early stop doesn't work, it would be 3 instead. 
+ assert counter == 2 + + +def test_input_pair(): + class Two(nn.Module): + def __init__(self): + super().__init__() + self.fc_a = nn.Linear(1, 1) + self.fc_b = nn.Linear(1, 1) + + def forward(self, a_and_b): + a, b = a_and_b + return (self.fc_a(a), self.fc_b(b)) + + model = nn.Sequential(Two()) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + b = torch.rand(10, 1, requires_grad=True) + + a_out, b_out = model((a, b)) + loss = (a_out + b_out).mean() + loss.backward() + + assert a.grad is not None + assert b.grad is not None + + +def test_input_singleton(): + class One(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(1, 1) + + def forward(self, only_a): + (a,) = only_a + return (self.fc(a),) + + model = nn.Sequential(One()) + model = Pipe(model, balance=[1], devices=["cpu"], chunks=2) + + a = torch.rand(10, 1, requires_grad=True) + + (a_out,) = model((a,)) + loss = a_out.mean() + loss.backward() + + assert all(p.grad is not None for p in model.parameters()) + assert a.grad is not None + + +def test_input_varargs(): + model = nn.Sequential(nn.Linear(1, 1)) + model = Pipe(model, balance=[1], devices=["cpu"]) + + a = torch.rand(1) + b = torch.rand(1) + + # TypeError: forward() takes 2 positional arguments but 3 were given + with pytest.raises(TypeError): + model(a, b) + + +def test_non_tensor(): + class NonTensor(nn.Module): + def forward(self, _): + return "hello" + + model = nn.Sequential(NonTensor()) + model = Pipe(model, balance=[1], devices=["cpu"]) + x = torch.rand(1) + + # TypeError: expected Tensor as element 0 in argument 0, but got str + with pytest.raises(TypeError): + model(x) + + # TypeError: expected Tensor to scatter, but got str + with pytest.raises(TypeError): + model("hello") + + +def test_non_tensor_tuple(): + class NonTensorTuple(nn.Module): + def forward(self, x): + return (x, "hello") + + model = nn.Sequential(NonTensorTuple()) + model = Pipe(model, balance=[1], devices=["cpu"]) + x = torch.rand(1) + + # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1 + with pytest.raises(TypeError): + model(x) + + # TypeError: expected Tensor to scatter, but got str + with pytest.raises(TypeError): + model((x, "hello")) + + +@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +def test_deferred_batch_norm(checkpoint): + bn = nn.BatchNorm2d(3) + pipe_bn = deepcopy(bn) + pipe = Pipe( + nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=2, checkpoint=checkpoint, deferred_batch_norm=True + ) + + x = torch.rand(4, 3, 10, 10) + pipe(x).mean().backward() + bn(x).mean().backward() + + assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) + assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4) + + +@pytest.mark.parametrize("checkpoint", ["never", "always"]) +def test_deferred_batch_norm_params(checkpoint): + bn = nn.BatchNorm2d(3) + pipe_bn = deepcopy(bn) + pipe = Pipe( + nn.Sequential(pipe_bn), balance=[1], devices=["cpu"], chunks=1, checkpoint=checkpoint, deferred_batch_norm=True + ) + + x = torch.rand(4, 3, 10, 10) + pipe(x).mean().backward() + bn(x).mean().backward() + + assert pipe[0].weight.grad is not None + assert pipe[0].bias.grad is not None + + assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4) + assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4) + + +def test_devices(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + c = nn.Linear(1, 1) + + # There are extra two 
devices. + devices = ["cpu", "cpu", "cpu", "cpu", "cpu"] + + model = nn.Sequential(a, b, c) + model = Pipe(model, [1, 1, 1], devices=devices) + + cpu = torch.device("cpu") + # Extra devices must be discarded. + assert model.devices == [cpu, cpu, cpu] + + +def test_partitions(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) + + assert isinstance(model.partitions, nn.ModuleList) + assert isinstance(model.partitions[0], nn.Sequential) + assert isinstance(model.partitions[1], nn.Sequential) + + assert "partitions.0.0.weight" in model.state_dict() + + +def test_deny_moving(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(a, b) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) + + # Moving is denied. + with pytest.raises(TypeError): + model.cuda() + + with pytest.raises(TypeError): + model.cpu() + + with pytest.raises(TypeError): + model.to(torch.device("cuda")) + + with pytest.raises(TypeError): + model.to(0) + + with pytest.raises(TypeError): + model.to("cuda") + + with pytest.raises(TypeError): + model.to(device=0) + + with pytest.raises(TypeError): + model.to(torch.rand(1)) + + with pytest.raises(TypeError): + model.to(tensor=torch.rand(1)) + + # Casting is allowed. + model.half() + model.to(torch.double) + model.to(dtype=torch.float) + + +def test_empty_module(): + # Empty sequential module is not illegal. + model = nn.Sequential() + model = Pipe(model, []) + + assert model(torch.tensor(42)) == torch.tensor(42) + assert model((torch.tensor(42),)) == (torch.tensor(42),) + + # But only tensor or tensors is legal in Pipe. + with pytest.raises(TypeError): + model(42) + + +def test_named_children(): + a = nn.Linear(1, 1) + b = nn.Linear(1, 1) + + model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) + model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) + + names = set(n for n, _ in model.named_modules()) + assert "partitions.0.a" in names + assert "partitions.1.b" in names + + # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires + # several methods in its namespace. 
+ with pytest.raises(AttributeError): + model.a + + +def test_recommend_auto_balance(): + with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): + # balance is required + Pipe(nn.Sequential()) + + with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): + # module and sum of balance have differen length (module: 0, sum of balance: 1) + Pipe(nn.Sequential(), [1]) + + with pytest.raises(ValueError, match="torch.distributed._pipeline.sync.balance"): + # module and sum of balance have different length (module: 2, sum of balance: 1) + Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]) + + +def test_verify_module_non_sequential(): + with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"): + Pipe(nn.Module(), [1]) + + +def test_verify_module_duplicate_children(): + conv = nn.Conv2d(3, 3, 1) + model = nn.Sequential(conv, conv) + + with pytest.raises(ValueError, match="module with duplicate children is not supported"): + Pipe(model, [1, 1]) + + +@skip_if_no_cuda +def test_verify_module_duplicate_parameters_on_distinct_devices(): + class Surrogate(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + conv = nn.Conv2d(3, 3, 1) + model = nn.Sequential(Surrogate(conv), Surrogate(conv)) + + with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"): + Pipe(model, [1, 1], devices=["cpu", "cuda"]) + + +def test_verify_module_duplicate_parameters_on_same_device(): + class Surrogate(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + conv = nn.Conv2d(3, 3, 1) + model = nn.Sequential(Surrogate(conv), Surrogate(conv)) + + Pipe(model, [1, 1], devices=["cpu", "cpu"]) + + +def test_forward_lockstep(): + timeline = [] + + class DelayedLog(nn.Module): + def __init__(self, j, seconds): + super().__init__() + self.i = 0 + self.j = j + self.seconds = seconds + + def forward(self, x): + time.sleep(self.seconds) + + timeline.append((self.i, self.j)) + self.i += 1 + + return x + + model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1)) + model = Pipe(model, balance=[1, 1], devices=["cpu", "cpu"], chunks=3) + model(torch.rand(3, 1)) + + # Expected timeline: (Logs are recorded at !) + # + # Partition #0: 0! 1! 2! + # Partition #1: 000! 111! 222! + # + assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)] diff --git a/test/distributed/_pipeline/sync/test_pipeline.py b/test/distributed/_pipeline/sync/test_pipeline.py new file mode 100644 index 000000000000..ee902b70cc48 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_pipeline.py @@ -0,0 +1,29 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
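+#
+# A worked example of the schedule asserted below: clock_cycles(m, n) yields,
+# for each clock tick k, every (micro-batch i, partition j) pair with
+# i + j == k, i.e. the tasks that may run concurrently in lock-step. For
+# 3 micro-batches over 2 partitions:
+#
+#   list(clock_cycles(3, 2))
+#   # [[(0, 0)], [(1, 0), (0, 1)], [(2, 0), (1, 1)], [(2, 1)]]
+#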
+from torch.distributed._pipeline.sync.pipeline import clock_cycles + + +def test_clock_cycles(): + assert list(clock_cycles(1, 1)) == [[(0, 0)]] + assert list(clock_cycles(1, 3)) == [[(0, 0)], [(0, 1)], [(0, 2)]] + assert list(clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] + + assert list(clock_cycles(3, 3)) == [ # noqa + [(0, 0)], + [(1, 0), (0, 1)], + [(2, 0), (1, 1), (0, 2)], + [(2, 1), (1, 2)], + [(2, 2)], + ] + + assert list(clock_cycles(4, 2)) == [ # noqa + [(0, 0)], + [(1, 0), (0, 1)], + [(2, 0), (1, 1)], + [(3, 0), (2, 1)], + [(3, 1)], + ] diff --git a/test/distributed/_pipeline/sync/test_stream.py b/test/distributed/_pipeline/sync/test_stream.py new file mode 100644 index 000000000000..46e68aaca305 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_stream.py @@ -0,0 +1,188 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch + +from torch.distributed._pipeline.sync.stream import ( + CPUStream, + current_stream, + default_stream, + get_device, + is_cuda, + new_stream, + record_stream, + use_device, + use_stream, + wait_stream, +) + +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") + + +class TestNewStream: + def test_new_stream_cpu(self): + stream = new_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_new_stream_cuda(self): + stream = new_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream != torch.cuda.default_stream() + + +class TestCurrentStream: + def test_current_stream_cpu(self): + stream = current_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_current_stream_cuda(self): + stream = current_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream == torch.cuda.current_stream() + + +class TestDefaultStream: + def test_default_stream_cpu(self): + stream = default_stream(torch.device("cpu")) + assert stream is CPUStream + + @skip_if_no_cuda + def test_default_stream_cuda(self): + stream = default_stream(torch.device("cuda")) + assert isinstance(stream, torch.cuda.Stream) + assert stream == torch.cuda.default_stream() + + +class TestUseDevice: + def test_use_device_cpu(self): + with use_device(torch.device("cpu")): + pass + + @skip_if_no_cuda + def test_use_device_cuda(self): + with use_device(torch.device("cuda")): + pass + + +class TestUseStream: + def test_use_stream_cpu(self): + with use_stream(CPUStream): + pass + + @skip_if_no_cuda + def test_use_stream_cuda(self): + stream = new_stream(torch.device("cuda")) + with use_stream(stream): + assert current_stream(torch.device("cuda")) == stream + + +class TestGetDevice: + def test_get_device_cpu(self): + assert get_device(CPUStream).type == "cpu" + + @skip_if_no_cuda + def test_get_device_cuda(self): + stream = current_stream(torch.device("cuda")) + assert get_device(stream).type == "cuda" + + +class TestWaitStream: + def _test_wait_stream(self, source, target, cuda_sleep=None): + with use_stream(target): + if is_cuda(target): + cuda_sleep(0.5) + x = torch.ones(100, 100, device=get_device(target)) + + wait_stream(source, target) + + with use_stream(source): + assert x.sum().item() == 10000 + + def test_wait_stream_cpu_cpu(self): + source = CPUStream + target = CPUStream + self._test_wait_stream(source, target) 
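+
+    # In the CUDA variants below, _test_wait_stream first enqueues a long
+    # kernel (cuda_sleep) on `target` before producing `x`, so the final
+    # assertion only holds if wait_stream() really orders `source` after the
+    # pending work on `target`.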
+ + @skip_if_no_cuda + def test_wait_stream_cpu_cuda(self, cuda_sleep): + source = CPUStream + target = new_stream(torch.device("cuda")) + self._test_wait_stream(source, target, cuda_sleep) + + @skip_if_no_cuda + def test_wait_stream_cuda_cpu(self, cuda_sleep): + source = new_stream(torch.device("cuda")) + target = CPUStream + self._test_wait_stream(source, target, cuda_sleep) + + @skip_if_no_cuda + def test_wait_stream_cuda_cuda(self, cuda_sleep): + source = current_stream(torch.device("cuda")) + target = new_stream(torch.device("cuda")) + self._test_wait_stream(source, target, cuda_sleep) + + +class TestRecordStream: + def test_record_stream_cpu(self): + # It should silently ignore CPU tensors. + x = torch.rand(1, device=torch.device("cpu")) + record_stream(x, CPUStream) + + @skip_if_no_cuda + def test_record_stream_cuda(self, cuda_sleep): + # This test detects unexpected block reallocation. For reliable test, + # the stream to allocate tensors is isolated. The allocator will not + # reuse free blocks which were allocated from another stream. + stream_alloc = new_stream(torch.device("cuda")) + with torch.cuda.stream(stream_alloc): + x = torch.rand(1, device=torch.device("cuda")) + + stream = new_stream(torch.device("cuda")) + record_stream(x, stream) + with use_stream(stream): + cuda_sleep(0.5) + + # 'x' is deleted at Python's perspective. But the block of 'x' is still + # required for 'stream'. 'y' shouldn't be allocated to the block. + data_ptr = x.data_ptr() + del x + stream_alloc.synchronize() + with torch.cuda.stream(stream_alloc): + y = torch.rand(1, device=torch.device("cuda")) + assert y.data_ptr() != data_ptr + + # Pause Python until 'stream' finishes tasks queued. Now the block of + # 'x' is free to be reallocated. + wait_stream(CPUStream, stream) + with torch.cuda.stream(stream_alloc): + z = torch.rand(1, device=torch.device("cuda")) + assert z.data_ptr() == data_ptr + + @skip_if_no_cuda + def test_record_stream_shifted_view(self, cuda_sleep): + # Issue: https://github.com/pytorch/pytorch/issues/27366 + stream_alloc = new_stream(torch.device("cuda")) + with torch.cuda.stream(stream_alloc): + x = torch.rand(2, device=torch.device("cuda")) + + y = x[1:] + assert y.data_ptr() > x.data_ptr() + + stream = new_stream(torch.device("cuda")) + with use_stream(stream): + cuda_sleep(0.5) + record_stream(y, stream) + + data_ptr = x.data_ptr() + del x, y + + stream_alloc.synchronize() + with torch.cuda.stream(stream_alloc): + z = torch.rand(2, device=torch.device("cuda")) + assert z.data_ptr() != data_ptr diff --git a/test/distributed/_pipeline/sync/test_transparency.py b/test/distributed/_pipeline/sync/test_transparency.py new file mode 100644 index 000000000000..88d9c83b9a07 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_transparency.py @@ -0,0 +1,43 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+import torch +from torch import nn + +from torch.distributed._pipeline.sync import Pipe + + +def test_simple_linears(): + def sum_grad(parameters): + return sum([p.grad.sum() for p in parameters if p.grad is not None]) + + def zero_grad(parameters): + for p in parameters: + p.grad = None + + inputs = torch.rand(8, 1) + model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),) + + # Without Pipe + outputs = model(inputs) + loss = outputs.mean() + loss.backward() + + grad_without_pipe = sum_grad(model.parameters()) + + zero_grad(model.parameters()) + + # With Pipe + model = Pipe(model, [2, 2], devices=["cpu", "cpu"], chunks=4) + + outputs = model(inputs) + loss = outputs.mean() + loss.backward() + + grad_with_pipe = sum_grad(model.parameters()) + + # Both grads should be identical. + assert torch.allclose(grad_with_pipe, grad_without_pipe) diff --git a/test/distributed/_pipeline/sync/test_worker.py b/test/distributed/_pipeline/sync/test_worker.py new file mode 100644 index 000000000000..0247a71ba4a8 --- /dev/null +++ b/test/distributed/_pipeline/sync/test_worker.py @@ -0,0 +1,163 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import threading +import time + +import pytest +import torch + +from torch.distributed._pipeline.sync.microbatch import Batch +from torch.distributed._pipeline.sync.stream import CPUStream +from torch.distributed._pipeline.sync.worker import Task, spawn_workers + + +class fake_device: + """A test double for :class:`torch.device`. Every fake device is different + with each other. + """ + + type = "fake" + index = None + + +def test_join_running_workers(): + count = 0 + + def counter(): + nonlocal count + time.sleep(0.1) + count += 1 + return Batch(()) + + with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues): + + def call_in_worker(i, f): + task = Task(CPUStream, compute=f, finalize=None) + in_queues[i].put(task) + + for i in range(10): + call_in_worker(i, counter) + + # There's no nondeterminism because 'spawn_workers' joins all running + # workers. + assert count == 10 + + +def test_join_running_workers_with_exception(): + class ExpectedException(Exception): + pass + + count = 0 + + def counter(): + nonlocal count + time.sleep(0.1) + count += 1 + return Batch(()) + + with pytest.raises(ExpectedException): + with spawn_workers([fake_device() for _ in range(10)]) as (in_queues, out_queues): + + def call_in_worker(i, f): + task = Task(CPUStream, compute=f, finalize=None) + in_queues[i].put(task) + + for i in range(10): + call_in_worker(i, counter) + + raise ExpectedException + + # There's no nondeterminism because only 1 task can be placed in input + # queues. 
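+    # Each of the 10 workers received exactly one task before the exception
+    # propagated, and spawn_workers joins all running workers on exit, so all
+    # 10 increments are still observed.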
+ assert count == 10 + + +def test_compute_multithreading(): + """Task.compute should be executed on multiple threads.""" + thread_ids = set() + + def log_thread_id(): + thread_id = threading.current_thread().ident + thread_ids.add(thread_id) + return Batch(()) + + with spawn_workers([fake_device() for _ in range(2)]) as (in_queues, out_queues): + for i in range(2): + t = Task(CPUStream, compute=log_thread_id, finalize=None) + in_queues[i].put(t) + for i in range(2): + out_queues[i].get() + + assert len(thread_ids) == 2 + + +def test_compute_success(): + """Task.compute returns (True, (task, batch)) on success.""" + + def _42(): + return Batch(torch.tensor(42)) + + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + t = Task(CPUStream, compute=_42, finalize=None) + in_queues[0].put(t) + ok, (task, batch) = out_queues[0].get() + + assert ok + assert task is t + assert isinstance(batch, Batch) + assert batch[0].item() == 42 + + +def test_compute_exception(): + """Task.compute returns (False, exc_info) on failure.""" + + def zero_div(): + 0 / 0 + + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + t = Task(CPUStream, compute=zero_div, finalize=None) + in_queues[0].put(t) + ok, exc_info = out_queues[0].get() + + assert not ok + assert isinstance(exc_info, tuple) + assert issubclass(exc_info[0], ZeroDivisionError) + + +@pytest.mark.parametrize("grad_mode", [True, False]) +def test_grad_mode(grad_mode): + def detect_grad_enabled(): + x = torch.rand(1, requires_grad=torch.is_grad_enabled()) + return Batch(x) + + with torch.set_grad_enabled(grad_mode): + with spawn_workers([torch.device("cpu")]) as (in_queues, out_queues): + task = Task(CPUStream, compute=detect_grad_enabled, finalize=None) + in_queues[0].put(task) + + ok, (_, batch) = out_queues[0].get() + + assert ok + assert batch[0].requires_grad == grad_mode + + +def test_worker_per_device(): + cpu = torch.device("cpu") + cpu0 = torch.device("cpu", index=0) + fake1 = fake_device() + fake2 = fake_device() + + with spawn_workers([cpu, cpu, cpu0, fake1, fake2]) as (in_queues, out_queues): + assert len(in_queues) == len(out_queues) == 5 + + # 0: cpu, 1: cpu, 2: cpu0 + assert in_queues[0] is in_queues[1] is in_queues[2] + assert out_queues[0] is out_queues[1] is out_queues[2] + + # 3: fake1, 4: fake2 + assert in_queues[3] is not in_queues[4] + assert out_queues[3] is not out_queues[4] diff --git a/test/run_test.py b/test/run_test.py index f6c4fc9c2858..7c4b413b5a8f 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import argparse +import copy from datetime import datetime import importlib import modulefinder @@ -95,6 +96,54 @@ TESTS = [ 'test_fx_experimental', 'test_functional_autograd_benchmark', 'test_package', + 'distributed/_pipeline/sync/skip/test_api', + 'distributed/_pipeline/sync/skip/test_gpipe', + 'distributed/_pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/_pipeline/sync/skip/test_leak', + 'distributed/_pipeline/sync/skip/test_portal', + 'distributed/_pipeline/sync/skip/test_stash_pop', + 'distributed/_pipeline/sync/skip/test_tracker', + 'distributed/_pipeline/sync/skip/test_verify_skippables', + 'distributed/_pipeline/sync/test_balance', + 'distributed/_pipeline/sync/test_bugs', + 'distributed/_pipeline/sync/test_checkpoint', + 'distributed/_pipeline/sync/test_copy', + 'distributed/_pipeline/sync/test_deferred_batch_norm', + 'distributed/_pipeline/sync/test_dependency', + 'distributed/_pipeline/sync/test_inplace', + 
'distributed/_pipeline/sync/test_microbatch', + 'distributed/_pipeline/sync/test_phony', + 'distributed/_pipeline/sync/test_pipe', + 'distributed/_pipeline/sync/test_pipeline', + 'distributed/_pipeline/sync/test_stream', + 'distributed/_pipeline/sync/test_transparency', + 'distributed/_pipeline/sync/test_worker', +] + +# Tests need to be run with pytest. +USE_PYTEST_LIST = [ + 'distributed/_pipeline/sync/skip/test_api', + 'distributed/_pipeline/sync/skip/test_gpipe', + 'distributed/_pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/_pipeline/sync/skip/test_leak', + 'distributed/_pipeline/sync/skip/test_portal', + 'distributed/_pipeline/sync/skip/test_stash_pop', + 'distributed/_pipeline/sync/skip/test_tracker', + 'distributed/_pipeline/sync/skip/test_verify_skippables', + 'distributed/_pipeline/sync/test_balance', + 'distributed/_pipeline/sync/test_bugs', + 'distributed/_pipeline/sync/test_checkpoint', + 'distributed/_pipeline/sync/test_copy', + 'distributed/_pipeline/sync/test_deferred_batch_norm', + 'distributed/_pipeline/sync/test_dependency', + 'distributed/_pipeline/sync/test_inplace', + 'distributed/_pipeline/sync/test_microbatch', + 'distributed/_pipeline/sync/test_phony', + 'distributed/_pipeline/sync/test_pipe', + 'distributed/_pipeline/sync/test_pipeline', + 'distributed/_pipeline/sync/test_stream', + 'distributed/_pipeline/sync/test_transparency', + 'distributed/_pipeline/sync/test_worker', ] WINDOWS_BLOCKLIST = [ @@ -170,6 +219,28 @@ SLOW_TESTS = [ 'test_quantization', 'test_determination', 'test_futures', + 'distributed/_pipeline/sync/skip/test_api', + 'distributed/_pipeline/sync/skip/test_gpipe', + 'distributed/_pipeline/sync/skip/test_inspect_skip_layout', + 'distributed/_pipeline/sync/skip/test_leak', + 'distributed/_pipeline/sync/skip/test_portal', + 'distributed/_pipeline/sync/skip/test_stash_pop', + 'distributed/_pipeline/sync/skip/test_tracker', + 'distributed/_pipeline/sync/skip/test_verify_skippables', + 'distributed/_pipeline/sync/test_balance', + 'distributed/_pipeline/sync/test_bugs', + 'distributed/_pipeline/sync/test_checkpoint', + 'distributed/_pipeline/sync/test_copy', + 'distributed/_pipeline/sync/test_deferred_batch_norm', + 'distributed/_pipeline/sync/test_dependency', + 'distributed/_pipeline/sync/test_inplace', + 'distributed/_pipeline/sync/test_microbatch', + 'distributed/_pipeline/sync/test_phony', + 'distributed/_pipeline/sync/test_pipe', + 'distributed/_pipeline/sync/test_pipeline', + 'distributed/_pipeline/sync/test_stream', + 'distributed/_pipeline/sync/test_transparency', + 'distributed/_pipeline/sync/test_worker', ] _DEP_MODULES_CACHE: Dict[str, set] = {} @@ -762,12 +833,15 @@ def main(): failure_messages = [] try: for test in selected_tests: - err_message = run_test_module(test, test_directory, options) + options_clone = copy.deepcopy(options) + if test in USE_PYTEST_LIST: + options_clone.pytest = True + err_message = run_test_module(test, test_directory, options_clone) if err_message is None: continue has_failed = True failure_messages.append(err_message) - if not options.continue_through_error: + if not options_clone.continue_through_error: raise RuntimeError(err_message) print_to_stderr(err_message) finally: diff --git a/torch/distributed/_pipeline/__init__.py b/torch/distributed/_pipeline/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/distributed/_pipeline/sync/LICENSE b/torch/distributed/_pipeline/sync/LICENSE new file mode 100644 index 000000000000..e52be240fdc9 --- /dev/null +++ 
b/torch/distributed/_pipeline/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright 2019-2020 Kakao Brain + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/torch/distributed/_pipeline/sync/__init__.py b/torch/distributed/_pipeline/sync/__init__.py new file mode 100644 index 000000000000..ca3c2a8823ad --- /dev/null +++ b/torch/distributed/_pipeline/sync/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""A Pipe implementation in PyTorch.""" +from .checkpoint import is_checkpointing, is_recomputing +from .pipe import Pipe + +__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] diff --git a/torch/distributed/_pipeline/sync/balance/__init__.py b/torch/distributed/_pipeline/sync/balance/__init__.py new file mode 100644 index 000000000000..15aa53bc1a2c --- /dev/null +++ b/torch/distributed/_pipeline/sync/balance/__init__.py @@ -0,0 +1,164 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""A helper to roughly balance a sequential module. + +Usage:: + + import torch + from torch.distributed._pipeline.sync import Pipe + from torch.distributed._pipeline.sync.balance import balance_by_time + + sample = torch.empty(128, 3, 224, 224) + balance = balance_by_time(torch.cuda.device_count(), model, sample) + + pipe = Pipe(model, balance, chunks=8) + +""" +from typing import List, Tuple, Union + +import torch +from torch import Tensor +import torch.nn as nn + +from . import blockpartition +from .profile import profile_sizes, profile_times + +__all__ = ["balance_by_time", "balance_by_size"] + + +Device = Union[torch.device, int, str] + +Tensors = Tuple[Tensor, ...] 
+TensorOrTensors = Union[Tensor, Tensors] + + +def balance_cost(cost: List[int], partitions: int) -> List[int]: + partitioned = blockpartition.solve(cost, partitions) + return [len(p) for p in partitioned] + + +def balance_by_time( + partitions: int, + module: nn.Sequential, + sample: TensorOrTensors, + *, + timeout: float = 1.0, + device: Device = torch.device("cuda"), +) -> List[int]: + """Naive automatic balancing by elapsed time per layer. + :: + + sample = torch.empty(128, 3, 224, 224) + balance = balance_by_time(torch.cuda.device_count(), model, sample) + pipe = Pipe(model, balance, chunks=8) + + Args: + partitions (int): + intended number of partitions + module (torch.nn.Sequential): + sequential module to be partitioned + sample (torch.Tensor): + example input with arbitrary batch size + + Keyword Args: + timeout (float): + profiling iterates again if the timeout (in second) is not exceeded + (default: ``1.0``) + device ('cpu' or 'cuda' device): + CPU or CUDA device where each layer is profiled (default: the + current CUDA device) + + Returns: + A list of number of layers in each partition. Use it for the `balance` + parameter of :class:`~torchpipe.Pipe`. + + .. note:: + `module` and `sample` must be placed on the same device. + + """ + times = profile_times(module, sample, timeout, torch.device(device)) + return balance_cost(times, partitions) + + +def balance_by_size( + partitions: int, + module: nn.Sequential, + input: TensorOrTensors, + *, + chunks: int = 1, + param_scale: float = 2.0, + device: Device = torch.device("cuda"), +) -> List[int]: + """Naive automatic balancing by CUDA memory usage per layer. + + During training, required memory for parameters depends on which optimizer + is used. Optimizers may use buffers for each parameter to track + optimization statistics internally, such as momentum buffer in SGD. + + To get more reliable size based balance, you should specify `param_scale` + with regard to your optimizer. The default `param_scale` is 2 instead of 1 + due to gradient accumulation which is necessary for every optimizer. + + Follow this guide to choose correct `param_scale` for typical optimizers: + + ========= ============= ========================================= + Optimizer `param_scale` Internal State + ========= ============= ========================================= + SGD 2--3 (momentum_buffer) + Adam 4--5 exp_avg, exp_avg_sq, (max_exp_avg_sq) + Adadelta 4 square_avg, acc_delta + Adagrad 3 sum + RMSprop 3--5 square_avg, (momentum_buffer), (grad_avg) + ========= ============= ========================================= + + Here's a simple example with the Adam optimizer:: + + balance = balance_by_size( + torch.cuda.device_count(), + model, + + # Same size with mini-batch to train + torch.empty(1024, 3, 224, 224), + + # Number of micro-batches to train with Pipe + chunks=8, + + # 4 for Adam + param_scale=4.0, + ) + + pipe = Pipe(model, balance, chunks=8) + adam = Adam(pipe.parameters()) + + Args: + partitions (int): + intended number of partitions + module (torch.nn.Sequential): + sequential module to be partitioned + input (torch.Tensor): + example mini-batch with the same size to train + + Keyword Args: + chunks (int): + number of micro-batches will be used to train (default: ``1``) + param_scale (float): + how many copies of parameters would be allocated for training. It + depends on optimizer. See the above guide. 
(default: ``2.0``) + device ('cuda' device): + CUDA device where each layer is profiled (default: the current CUDA + device) + + Returns: + A list of number of layers in each partition. Use it for the `balance` + parameter of :class:`~torchpipe.Pipe`. + + .. note:: + `module` and `input` must be placed on the same CUDA device. + + """ + sizes = profile_sizes(module, input, chunks, param_scale, torch.device(device)) + return balance_cost(sizes, partitions) diff --git a/torch/distributed/_pipeline/sync/balance/blockpartition.py b/torch/distributed/_pipeline/sync/balance/blockpartition.py new file mode 100644 index 000000000000..7afe782f6ac8 --- /dev/null +++ b/torch/distributed/_pipeline/sync/balance/blockpartition.py @@ -0,0 +1,95 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Implements "Block Partitions of Sequences" by Imre Bárány et al. + +Paper: https://arxiv.org/pdf/1308.2452.pdf + +""" +from typing import Iterator, List, Tuple + +__all__ = ["solve"] + + +def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]: + """Splits a sequence into several partitions to minimize variance for each + partition. + + The result might not be optimal. However, it can be done only in O(kn³), + where k is the number of partitions and n is the length of the sequence. + + """ + if partitions < 1: + raise ValueError(f"partitions must be a positive integer ({partitions} < 1)") + + n = len(sequence) + if n < partitions: + raise ValueError(f"sequence is shorter than intended partitions ({n} < {partitions})") + + # Normalize the sequence in [0, 1]. + minimum = min(sequence) + maximum = max(sequence) - minimum + + normal_sequence: List[float] + if maximum == 0: + normal_sequence = [0 for _ in sequence] + else: + normal_sequence = [(x - minimum) / maximum for x in sequence] + + splits = [n // partitions * (x + 1) for x in range(partitions - 1)] + [n] + + def block_size(i: int) -> float: + start = splits[i - 1] if i > 0 else 0 + stop = splits[i] + return sum(normal_sequence[start:stop]) + + def leaderboard() -> Iterator[Tuple[float, int]]: + return ((block_size(i), i) for i in range(partitions)) + + while True: + """ + (1) Fix p ∈ [k] with M(P) = bp. So Bp is a maximal block of P. + """ + # max_size: M(P) + max_size, p = max(leaderboard()) + + while True: + """ + (2) If M(P) ≤ m(P) + 1, then stop. + """ + # min_size: m(P) + min_size, q = min(leaderboard()) + + if max_size <= min_size + 1: + return [sequence[i:j] for i, j in zip([0] + splits[:-1], splits)] + + """ + (3) If M(P) > m(P) + 1, then let m(P) = bq for the q ∈ [k] which is + closest to p (ties broken arbitrarily). Thus Bq is a minimal block + of P. Let Bh be the block next to Bq between Bp and Bq. (Note that + Bh is a non-empty block: if it were, then m(P) = 0 and we should + have chosen Bh instead of Bq.) + """ + if p < q: + """ + So either p < q and then h = q−1 and we define P ∗ by moving + the last element from Bh = Bq−1 to Bq, + """ + h = q - 1 + splits[h] -= 1 + else: + """ + or q < p, and then h = q + 1 and P ∗ is obtained by moving the + first element of Bh = Bq+1 to Bq. + """ + h = q + 1 + splits[q] += 1 + + """ + Set P = P ∗ . If p = h, then go to (1), else go to (2). 
+ """ + if p == h: + break diff --git a/torch/distributed/_pipeline/sync/balance/profile.py b/torch/distributed/_pipeline/sync/balance/profile.py new file mode 100644 index 000000000000..737dda60f6fa --- /dev/null +++ b/torch/distributed/_pipeline/sync/balance/profile.py @@ -0,0 +1,114 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Per-layer profilers.""" +import copy +import time +from typing import Generator, List, Tuple, Union + +import torch +from torch import Tensor +import torch.nn as nn + +from ..microbatch import Batch + +__all__: List[str] = [] + + +Device = Union[torch.device, int, str] + +Tensors = Tuple[Tensor, ...] +TensorOrTensors = Union[Tensor, Tensors] + + +def layerwise_sandbox(module: nn.Sequential, device: torch.device,) -> Generator[nn.Module, None, None]: + """Copies layers for ease to profile. It doesn't modify the given + module. + """ + for layer in module: + layer_copy = copy.deepcopy(layer) + layer_copy.to(device) + layer_copy.train() + yield layer_copy + + +def detach(batch: Batch) -> None: + """Detaches from autograd graph.""" + for i, x in enumerate(batch): + batch[i] = x.detach().requires_grad_(x.requires_grad) + + +def profile_times(module: nn.Sequential, sample: TensorOrTensors, timeout: float, device: torch.device,) -> List[int]: + """Profiles elapsed times per layer.""" + if any(p.grad is not None for p in module.parameters()): + raise ValueError("some parameter already has gradient") + + _batch = Batch(sample) + for i, x in enumerate(_batch): + _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad) + + time_bufs: List[List[float]] = [[] for _ in module] + begun_at = time.time() + + while time.time() - begun_at < timeout: + batch = _batch + + for i, layer in enumerate(layerwise_sandbox(module, device)): + detach(batch) + + if device.type == "cuda": + torch.cuda.synchronize(device) + tick = time.time() + + # Forward + batch = batch.call(layer) + + # Backward + backward_tensors = tuple(y for y in batch if y.requires_grad) + if backward_tensors: + torch.autograd.backward(backward_tensors, backward_tensors) + + if device.type == "cuda": + torch.cuda.synchronize(device) + tock = time.time() + + time_bufs[i].append(tock - tick) + + us = 1_000_000 + return [sum(int(t * us) for t in buf) for buf in time_bufs] + + +def profile_sizes( + module: nn.Sequential, input: TensorOrTensors, chunks: int, param_scale: float, device: torch.device, +) -> List[int]: + """Profiles CUDA memory usage per layer.""" + if device.type != "cuda": + raise ValueError("size profiler supports only CUDA device") + + batch = Batch(input) + sizes: List[int] = [] + + latent_scale = batch[0].size(0) / chunks + for i, x in enumerate(batch): + batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad) + + for layer in layerwise_sandbox(module, device): + detach(batch) + + # Detect memory usage at forward. + memory_before = torch.cuda.memory_allocated(device) + batch = batch.call(layer) + memory_after = torch.cuda.memory_allocated(device) + latent_size = memory_after - memory_before + + # Analyze size of parameters. + param_size = sum(p.storage().size() * p.storage().element_size() for p in layer.parameters()) + + # Combine size of parameters and activations with normalize scales. 
+ size = latent_size * latent_scale + param_size * param_scale + sizes.append(int(size)) + + return sizes diff --git a/torch/distributed/_pipeline/sync/balance/py.typed b/torch/distributed/_pipeline/sync/balance/py.typed new file mode 100644 index 000000000000..ab03724cafbf --- /dev/null +++ b/torch/distributed/_pipeline/sync/balance/py.typed @@ -0,0 +1,6 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/_pipeline/sync/batchnorm.py b/torch/distributed/_pipeline/sync/batchnorm.py new file mode 100644 index 000000000000..487c3d096d98 --- /dev/null +++ b/torch/distributed/_pipeline/sync/batchnorm.py @@ -0,0 +1,159 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Tracks the running statistics per mini-batch instead of micro-batch.""" +from typing import Optional, TypeVar, cast + +import torch +from torch import Tensor, nn +import torch.nn.functional as F +from torch.nn.modules.batchnorm import _BatchNorm + +from .checkpoint import is_recomputing + +__all__ = ["DeferredBatchNorm"] + + +TModule = TypeVar("TModule", bound=nn.Module) + + +class DeferredBatchNorm(_BatchNorm): + """A BatchNorm layer tracks multiple micro-batches to update running + statistics per mini-batch. + """ + + sum: Tensor + sum_squares: Tensor + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + chunks: int = 1, + ) -> None: + super().__init__(num_features, eps, momentum, affine, track_running_stats=True) + + self.register_buffer("sum", torch.zeros_like(self.running_mean)) + self.register_buffer("sum_squares", torch.zeros_like(self.running_var)) + + self.counter = 0 + self.tracked = 0 + self.chunks = chunks + + def _check_input_dim(self, input: Tensor) -> None: + # It's the typical _check_input_dim() implementation in PyTorch. + if input.dim() <= 2: + raise ValueError("expected at least 3D input (got %dD input)" % input.dim()) + + def _track(self, input: Tensor) -> bool: + """Tracks statistics of a micro-batch.""" + # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d. + dim = [0] + dim.extend(range(2, input.dim())) + + with torch.no_grad(): + self.sum += input.sum(dim) + self.sum_squares += (input ** 2).sum(dim) + + size = input.size().numel() // input.size(1) + self.counter += size + self.tracked += 1 + + return self.tracked == self.chunks + + def _commit(self) -> None: + """Updates the running statistics of a mini-batch.""" + exponential_average_factor = 0.0 + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + mean = self.sum / self.counter + var = self.sum_squares / self.counter - mean ** 2 + + # Calculate the exponential moving average here. 
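+        # running_stat = (1 - m) * running_stat + m * minibatch_stat, where m
+        # is `momentum`, or 1 / num_batches_tracked when momentum is None
+        # (cumulative moving average).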
+ m = exponential_average_factor + + self.running_mean *= 1 - m + self.running_mean += mean * m + + self.running_var *= 1 - m + self.running_var += var * m + + self.sum.zero_() + self.sum_squares.zero_() + self.counter = 0 + self.tracked = 0 + + def forward(self, input: Tensor) -> Tensor: # type: ignore + if not self.training: + # Don't train parameters on the evaluation mode. + return F.batch_norm( + input, + running_mean=self.running_mean, + running_var=self.running_var, + weight=self.weight, + bias=self.bias, + training=False, + momentum=0.0, + eps=self.eps, + ) + + if not is_recomputing(): + # Track a micro-batch on the training mode + # but not under a recomputation. + tracked_enough = self._track(input) + + # Update the running statistics for a mini-batch + # if it has tracked enough micro-batches. + if tracked_enough: + self._commit() + + # Normalize a micro-batch and train the parameters. + return F.batch_norm( + input, + running_mean=None, + running_var=None, + weight=self.weight, + bias=self.bias, + training=True, + momentum=0.0, + eps=self.eps, + ) + + @classmethod + def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule: + """Converts a :class:`nn.BatchNorm` or underlying + :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`:: + + from torchvision.models.resnet import resnet101 + from torchpipe.batchnorm import DeferredBatchNorm + model = resnet101() + model = DeferredBatchNorm.convert_deferred_batch_norm(model) + + """ + if isinstance(module, DeferredBatchNorm) and module.chunks is chunks: + return cast(TModule, module) + + module_output: nn.Module = module + + if isinstance(module, _BatchNorm) and module.track_running_stats: + module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks) + if module.affine: + module_output.register_parameter("weight", module.weight) + module_output.register_parameter("bias", module.bias) + module_output.register_buffer("running_mean", module.running_mean) + module_output.register_buffer("running_var", module.running_var) + module_output.register_buffer("num_batches_tracked", module.num_batches_tracked) + + for name, child in module.named_children(): + module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks)) + + return cast(TModule, module_output) diff --git a/torch/distributed/_pipeline/sync/checkpoint.py b/torch/distributed/_pipeline/sync/checkpoint.py new file mode 100644 index 000000000000..08e95e2d18fa --- /dev/null +++ b/torch/distributed/_pipeline/sync/checkpoint.py @@ -0,0 +1,317 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Checkpointing with preceding recomputation. + +PyTorch already provides the official checkpointing utilities in +:mod:`torch.utils.checkpoint`. The official checkpointing combines +recomputation and recursive backpropagation into one autograd function named +``CheckpointFunction``. Hence, the recomputation can be started only when the +gradients arrive to the function. In Pipe, the recomputation needs to precede +the gradient arrival to minimize the GPU idle time. + +We solve this problem by introducing separate autograd functions named +:class:`Recompute` and :class:`Checkpoint`. Each function represents +recomputation and recursive backpropagation, respectively. 
We can manipulate +the control flow in aspect of both the autograd engine and CUDA with a pair of +the functions. + +Specifically, we place CUDA stream synchronization between :class:`Recompute` +and :class:`Checkpoint` to delay only :class:`Checkpoint` until the gradient is +copied entirely. + +""" +from collections import deque +from contextlib import contextmanager +import threading +from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union + +import torch +from torch import ByteTensor, Tensor +import torch.autograd + +from .dependency import fork, join +from .microbatch import Batch +from .phony import get_phony + +__all__ = ["is_checkpointing", "is_recomputing"] + + +Tensors = Tuple[Tensor, ...] +TensorOrTensors = Union[Tensor, Tensors] + +# Types for shared memory between Checkpoint and Recompute. +Recomputed = Tuple[TensorOrTensors, Tensors] # (output, input_leaf) +RNGStates = Tuple[ByteTensor, Optional[ByteTensor]] # (cpu_rng_state, gpu_rng_state) + + +if TYPE_CHECKING: + from typing_extensions import Protocol +else: + Protocol = object + + +# Protocol with __call__ instead of Callable can be used as an attribute type. +# See: https://github.com/python/mypy/issues/708#issuecomment-561735949 +class Function(Protocol): + def __call__(self, input: TensorOrTensors) -> TensorOrTensors: + ... + + +def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors: + """Makes a checkpoint with a simple interface like + :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug + :class:`Checkpoint` and :class:`Recompute` without boilerplate. + """ + batch = Batch(input) + + chk = Checkpointing(function, batch) + batch = chk.checkpoint() + chk.recompute(batch) + + return batch.tensor_or_tensors + + +class Checkpointing: + """Generates a pair of :class:`Checkpoint` and :class:`Recompute`.""" + + def __init__(self, function: Function, batch: Batch) -> None: + self.function = function + self.batch = batch + + # Shared memory between Checkpoint and Recompute. 1-length deque is + # used for mutability and length limitation. + self.recomputed: Deque[Recomputed] = deque(maxlen=1) + self.rng_states: Deque[RNGStates] = deque(maxlen=1) + + def checkpoint(self) -> Batch: + """Returns a batch applied by :class:`Checkpoint`.""" + input_atomic = self.batch.atomic + input = tuple(self.batch) + + # Use a phony which requires grad to ensure that Checkpoint can be + # tracked by the autograd engine even when none of the input tensors + # require grad. + phony = get_phony(self.batch[0].device, requires_grad=True) + + output = Checkpoint.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input) + + # Gradients are only supported for float Tensors. + if isinstance(output, tuple): + output = tuple([x if x.is_floating_point() else x.detach() for x in output]) + + return Batch(output) + + def recompute(self, batch: Batch) -> None: + """Applies :class:`Recompute` to the batch in place.""" + input_atomic = self.batch.atomic + input = tuple(self.batch) + + # batch[0] is always requiring grad, because it has been passed + # checkpoint with a phony requiring grad. 
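+        # fork() and join() thread a phony through Recompute so that the
+        # autograd engine schedules the recomputation during backpropagation,
+        # ahead of Checkpoint's backward.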
+ batch[0], phony = fork(batch[0]) + phony = Recompute.apply(phony, self.recomputed, self.rng_states, self.function, input_atomic, *input) + batch[0] = join(batch[0], phony) + + +class ThreadLocal(threading.local): + def __init__(self) -> None: + self.is_checkpointing = False + self.is_recomputing = False + + +thread_local = ThreadLocal() + + +@contextmanager +def enable_checkpointing() -> Generator[None, None, None]: + """Makes :func:`is_checkpointing` return :data:`True` within a context.""" + orig = thread_local.is_checkpointing + thread_local.is_checkpointing = True + try: + yield + finally: + thread_local.is_checkpointing = orig + + +@contextmanager +def enable_recomputing() -> Generator[None, None, None]: + """Makes :func:`is_recomputing` return :data:`True` within a context.""" + orig = thread_local.is_recomputing + thread_local.is_recomputing = True + try: + yield + finally: + thread_local.is_recomputing = orig + + +def is_checkpointing() -> bool: + """Whether the current forward propagation is under checkpointing. + + Returns: + bool: :data:`True` if it's under checkpointing. + + """ + return thread_local.is_checkpointing + + +def is_recomputing() -> bool: + """Whether the current forward propagation is under checkpoint + recomputation. Use this to prevent duplicated side-effects at forward + propagation:: + + class Counter(nn.Module): + def __init__(self): + super().__init__() + self.counter = 0 + + def forward(self, input): + if not is_recomputing(): + self.counter += 1 + return input + + Returns: + bool: :data:`True` if it's under checkpoint recomputation. + + .. seealso:: :ref:`Detecting Recomputation` + + """ + return thread_local.is_recomputing + + +class Context: + """The common interface between the :class:`Checkpoint` and + :class:`Recompute` context. + """ + + recomputed: Deque[Recomputed] + rng_states: Deque[RNGStates] + function: Function + input_atomic: bool + + saved_tensors: Tuple[Tensor, ...] + + def save_for_backward(self, *tensors: Tensor) -> None: # pragma: no cover + pass + + +def save_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> None: + """:meth:`Checkpoint.forward` captures the current PyTorch's random number + generator states at CPU and GPU to reuse in :meth:`Recompute.backward`. + + .. seealso:: :ref:`Referential Transparency` + + """ + cpu_rng_state = torch.get_rng_state() + + gpu_rng_state: Optional[ByteTensor] + if device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state(device) + else: + gpu_rng_state = None + + rng_states.append((cpu_rng_state, gpu_rng_state)) + + +@contextmanager +def restore_rng_states(device: torch.device, rng_states: Deque[RNGStates],) -> Generator[None, None, None]: + """:meth:`Recompute.backward` restores the random number generator states + captured by :func:`save_rng_states` within its context. + + .. 
seealso:: :ref:`Referential Transparency` + + """ + cpu_rng_state, gpu_rng_state = rng_states.pop() + + gpu_devices: List[torch.device] = [] + if device.type == "cuda": + gpu_devices.append(device) + + with torch.random.fork_rng(gpu_devices): + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + torch.cuda.set_rng_state(gpu_rng_state, device) + yield + + +class Checkpoint(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx: Context, + phony: Tensor, + recomputed: Deque[Recomputed], + rng_states: Deque[RNGStates], + function: Function, + input_atomic: bool, + *input: Tensor, + ) -> TensorOrTensors: + ctx.recomputed = recomputed + ctx.rng_states = rng_states + + save_rng_states(input[0].device, ctx.rng_states) + + ctx.function = function + ctx.input_atomic = input_atomic + ctx.save_for_backward(*input) + + with torch.no_grad(), enable_checkpointing(): + output = function(input[0] if input_atomic else input) + + return output + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: # pragma: no cover + output, input_leaf = ctx.recomputed.pop() + + if isinstance(output, tuple): + tensors = output + else: + tensors = (output,) + if any(y.requires_grad for y in tensors): + tensors = tuple([x for x in tensors if x.requires_grad]) + torch.autograd.backward(tensors, grad_output) + + grad_input: List[Optional[Tensor]] = [None, None, None, None, None] + grad_input.extend(x.grad for x in input_leaf) + return tuple(grad_input) + + +class Recompute(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx: Context, + phony: Tensor, + recomputed: Deque[Recomputed], + rng_states: Deque[RNGStates], + function: Function, + input_atomic: bool, + *input: Tensor, + ) -> Tensor: + ctx.recomputed = recomputed + ctx.rng_states = rng_states + + ctx.function = function + ctx.input_atomic = input_atomic + ctx.save_for_backward(*input) + + return phony + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]: # pragma: no cover + input = ctx.saved_tensors + input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input) + + with restore_rng_states(input[0].device, ctx.rng_states): + with torch.enable_grad(), enable_recomputing(): + output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf) + + ctx.recomputed.append((output, input_leaf)) + + grad_input: List[None] = [None, None, None, None, None] + grad_input.extend(None for _ in ctx.saved_tensors) + return tuple(grad_input) diff --git a/torch/distributed/_pipeline/sync/copy.py b/torch/distributed/_pipeline/sync/copy.py new file mode 100644 index 000000000000..3d330f59eeee --- /dev/null +++ b/torch/distributed/_pipeline/sync/copy.py @@ -0,0 +1,104 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Autograd functions for stream-aware CUDA copy. It is used to overlap copy +and computation on the same GPU. +""" +from collections import deque +from typing import Deque, List, Optional, Tuple + +import torch +from torch import Tensor + +from .stream import AbstractStream, current_stream, get_device, record_stream, use_stream, wait_stream + +__all__: List[str] = [] + + +Tensors = Tuple[Tensor, ...] + + +# Common interface between :class:`Copy` and :class:`Wait`. 
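+# Both functions record 'prev_stream' and 'next_stream' on the context in
+# forward() and read them back in backward().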
+class Context: + prev_stream: AbstractStream + next_stream: AbstractStream + + +class Copy(torch.autograd.Function): + """Copies tensors on specific streams.""" + + @staticmethod + # type: ignore + def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors: + ctx.prev_stream = prev_stream + ctx.next_stream = next_stream + + output = [] + output_stream = current_stream(get_device(next_stream)) + + with use_stream(prev_stream), use_stream(next_stream): + for x in input: + y = x.to(get_device(next_stream), non_blocking=True) + output.append(y) + + # 'prev_stream' is not where 'x' has been allocated. + record_stream(x, prev_stream) + # 'y' has been allocated on 'next_stream'. + # It might be used on the current stream captured as 'output_stream'. + record_stream(y, output_stream) + + return tuple(output) + + @staticmethod + def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]: + prev_stream = ctx.prev_stream + next_stream = ctx.next_stream + + grad_input: Deque[Tensor] = deque(maxlen=len(grad_output)) + input_stream = current_stream(get_device(prev_stream)) + + with use_stream(prev_stream), use_stream(next_stream): + for x in reversed(grad_output): + y = x.to(get_device(prev_stream), non_blocking=True) + grad_input.appendleft(y) + + # 'next_stream' is not where 'x' has been allocated. + record_stream(x, next_stream) + # 'y' has been allocated on 'prev_stream'. + # It might be used on the current stream captured as 'input_stream'. + record_stream(y, input_stream) + + grad_streams: Tuple[Optional[Tensor], ...] = (None, None) + return grad_streams + tuple(grad_input) + + +class Wait(torch.autograd.Function): + """Synchronizes a stream to another stream. + + Place it just before you want to start an operation on the next stream, + provided that all operations on the previous stream are done. + + """ + + @staticmethod + # type: ignore + def forward(ctx: Context, prev_stream: AbstractStream, next_stream: AbstractStream, *input: Tensor,) -> Tensors: + ctx.prev_stream = prev_stream + ctx.next_stream = next_stream + + wait_stream(next_stream, prev_stream) + + return tuple(x.detach() for x in input) + + @staticmethod + def backward(ctx: Context, *grad_input: Tensor,) -> Tuple[Optional[Tensor], ...]: + prev_stream = ctx.prev_stream + next_stream = ctx.next_stream + + wait_stream(prev_stream, next_stream) + + grad_streams: Tuple[Optional[Tensor], ...] = (None, None) + return grad_streams + grad_input diff --git a/torch/distributed/_pipeline/sync/dependency.py b/torch/distributed/_pipeline/sync/dependency.py new file mode 100644 index 000000000000..aeebc11aeeba --- /dev/null +++ b/torch/distributed/_pipeline/sync/dependency.py @@ -0,0 +1,54 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+"""Arbitrary dependency between two autograd lanes.""" +from typing import List, Tuple + +import torch +from torch import Tensor + +from .phony import get_phony + +__all__: List[str] = [] + + +def fork(input: Tensor) -> Tuple[Tensor, Tensor]: + """Branches out from an autograd lane of the given tensor.""" + if torch.is_grad_enabled() and input.requires_grad: + input, phony = Fork.apply(input) + else: + phony = get_phony(input.device, requires_grad=False) + + return input, phony + + +class Fork(torch.autograd.Function): + @staticmethod + def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore + phony = get_phony(input.device, requires_grad=False) + return input.detach(), phony.detach() + + @staticmethod + def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore + return grad_input + + +def join(input: Tensor, phony: Tensor) -> Tensor: + """Merges two autograd lanes.""" + if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad): + input = Join.apply(input, phony) + + return input + + +class Join(torch.autograd.Function): + @staticmethod + def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore + return input.detach() + + @staticmethod + def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore + return grad_input, None diff --git a/torch/distributed/_pipeline/sync/microbatch.py b/torch/distributed/_pipeline/sync/microbatch.py new file mode 100644 index 000000000000..d38cb6d3b85c --- /dev/null +++ b/torch/distributed/_pipeline/sync/microbatch.py @@ -0,0 +1,185 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Manipulation of micro-batches.""" +import typing +from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast + +import torch +from torch import Tensor +import torch.cuda.comm + +__all__: List[str] = [] + + +Tensors = Tuple[Tensor, ...] +TensorOrTensors = Union[Tensor, Tensors] +Function = Callable[[TensorOrTensors], TensorOrTensors] + + +class Batch: + """An abstraction of an atomic tensor or a tuple of tensors. This + eliminates every boilerplate code to classify an atomic tensor or a tuple + of tensors. + :: + + x = generate_tensor_or_tensors() + x = Batch(x) + + # in-place update + x[0] = F.apply(x[0]) + x[:] = F.apply(*x) + + # f(x) if x is a tensor. + # f(*x) if x is a tuple of tensors. + # y is also a batch. + y = x.call(f) + + """ + + def __init__(self, value: TensorOrTensors) -> None: + self.value = value + self.atomic = torch.is_tensor(value) + + @property + def tensor(self) -> Tensor: + """Retrieves the underlying tensor.""" + if not self.atomic: + raise AttributeError("not atomic batch") + return cast(Tensor, self.value) + + @property + def tensors(self) -> Tensors: + """Retrieves the underlying tensors.""" + if self.atomic: + raise AttributeError("batch is atomic") + return cast(Tensors, self.value) + + @property + def tensor_or_tensors(self) -> TensorOrTensors: + """Retrieves the underlying tensor or tensors regardless of type.""" + return self.value + + def call(self, function: Function) -> "Batch": + """Calls a function by the underlying tensor or tensors. It also wraps + the output with :class:`Batch`. 
+ """ + return Batch(function(self.value)) + + def __repr__(self) -> str: + return f"Batch[atomic={self.atomic!r}]({self.value!r})" + + def __iter__(self) -> Iterator[Tensor]: + if self.atomic: + yield self.tensor + else: + yield from self.tensors + + def __len__(self) -> int: + return 1 if self.atomic else len(self.tensors) + + def __getitem__(self, index: int) -> Tensor: + if not self.atomic: + return self.tensors[index] + + if index != 0: + raise IndexError("atomic batch allows index 0 only") + + return self.tensor + + # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload". + @typing.overload + def __setitem__(self, index: int, value: Tensor) -> None: + ... + + @typing.overload + def __setitem__(self, index: slice, value: Tensors) -> None: + ... + + def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None: + if isinstance(index, int): + value = cast(Tensor, value) + self._setitem_by_index(index, value) + else: + value = cast(Tensors, value) + self._setitem_by_slice(index, value) + + def _setitem_by_index(self, index: int, value: Tensor) -> None: + if not self.atomic: + i = index + self.value = self.value[:i] + (value,) + self.value[i + 1 :] + return + + if index != 0: + raise IndexError("atomic batch allows index 0 only") + + self.value = value + + def _setitem_by_slice(self, index: slice, value: Tensors) -> None: + if not (index.start is index.stop is index.step is None): + raise NotImplementedError("only slice [:] supported") + + if not self.atomic: + self.value = value + return + + if len(value) != 1: + raise IndexError("atomic batch cannot be replaced with multiple tensors") + + self.value = value[0] + + +def check(input: TensorOrTensors) -> None: + """Checks whether the input is a tensor or tensors. + + Raises: + TypeError: input is not a tensor or tensors. + + """ + if isinstance(input, tuple): + for x in input: + check(x) + return + + if not isinstance(input, Tensor): + raise TypeError(f"expected Tensor, but got {input.__class__.__name__}") + + +def scatter(input: TensorOrTensors, chunks: int) -> List[Batch]: + """Splits an input mini-batch into multiple micro-batches.""" + inputs: Iterable[TensorOrTensors] + + if isinstance(input, Tensor): + inputs = input.chunk(chunks) + else: + rotated: List[Tensors] = [] + + for tensor in input: + tensors = tensor.chunk(chunks) + rotated.append(cast(Tensors, tensors)) + + inputs = zip(*rotated) + + return [Batch(x) for x in inputs] + + +def gather(outputs: List[Batch]) -> TensorOrTensors: + """Concatenates output micro-batches into a mini-batch.""" + output: TensorOrTensors + + if outputs[0].atomic: + tensors = tuple(b.tensor for b in outputs) + output = torch.cat(tensors) + else: + rotated = [b.tensors for b in outputs] + output_buf = [] + + for tensors in zip(*rotated): + output_buf.append(torch.cat(tensors)) + + output = tuple(output_buf) + + return output diff --git a/torch/distributed/_pipeline/sync/phony.py b/torch/distributed/_pipeline/sync/phony.py new file mode 100644 index 000000000000..5e89ff0efd27 --- /dev/null +++ b/torch/distributed/_pipeline/sync/phony.py @@ -0,0 +1,49 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+"""Provides phony for arbitrary dependency in a autograd graph.""" +from typing import Dict, List, Tuple + +import torch +from torch import Tensor + +from .stream import default_stream, use_stream + +__all__: List[str] = [] + + +_phonies: Dict[Tuple[torch.device, bool], Tensor] = {} + + +def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor: + """Gets a phony. Phony is tensor without space. It is useful to make + arbitrary dependency in a autograd graph because it doesn't require any + gradient accumulation. + + .. note:: + + Phonies for each device are cached. If an autograd function gets a phony + internally, the phony must be detached to be returned. Otherwise, the + autograd engine will mutate the cached phony in-place:: + + class Phonify(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + phony = get_phony(input.device, requires_grad=False) + return phony.detach() # detach() is necessary. + + """ + key = (device, requires_grad) + + try: + phony = _phonies[key] + except KeyError: + with use_stream(default_stream(device)): + phony = torch.empty(0, device=device, requires_grad=requires_grad) + + _phonies[key] = phony + + return phony diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py new file mode 100644 index 000000000000..500b15b72771 --- /dev/null +++ b/torch/distributed/_pipeline/sync/pipe.py @@ -0,0 +1,394 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The Pipe interface.""" +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast + +import torch +from torch import Tensor, nn +import torch.autograd +import torch.cuda + +from . import microbatch +from .batchnorm import DeferredBatchNorm +from .pipeline import Pipeline +from .skip.layout import inspect_skip_layout +from .skip.skippable import verify_skippables +from .stream import AbstractStream, new_stream + +__all__ = ["Pipe"] + + +Device = Union[torch.device, int, str] +Devices = Union[Iterable[Device], List[Device]] + +Tensors = Tuple[Tensor, ...] +TensorOrTensors = Union[Tensor, Tensors] + +if TYPE_CHECKING: + Module = nn.Module[TensorOrTensors] + NamedModules = OrderedDict[str, Module] +else: + Module = nn.Module + NamedModules = OrderedDict + + +def recommend_auto_balance(message: str) -> str: + """Expands a message with recommendation to :mod:`torchpipe.balance`.""" + return f"""{message} + +If your model is still under development, its optimal balance would change +frequently. In this case, we highly recommend 'torch.distributed._pipeline.sync.balance' for +naive automatic balancing: + + from torch.distributed._pipeline.sync import Pipe + from torch.distributed._pipeline.sync.balance import balance_by_time + + partitions = torch.cuda.device_count() + sample = torch.empty(...) + balance = balance_by_time(partitions, model, sample) + + model = Pipe(model, balance, ...) 
+""" + + +def verify_module(module: nn.Sequential) -> None: + if not isinstance(module, nn.Sequential): + raise TypeError("module must be nn.Sequential to be partitioned") + + named_children = list(module.named_children()) + if len(named_children) != len(module): + raise ValueError("module with duplicate children is not supported") + + +def verify_splitting( + module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int], devices: List[torch.device] +) -> None: + num_parameters = len(list(module.parameters())) + num_child_parameters = sum(len(list(child.parameters())) for child in module.children()) + if num_parameters == num_child_parameters: + return + + for i in range(len(partitions)): + for j in range(i + 1, len(partitions)): + parti = partitions[i] + partj = partitions[j] + if devices[i] == devices[j]: + continue + for p in parti.parameters(): + for q in partj.parameters(): + if p is q: + raise ValueError("module with duplicate parameters on distinct devices is not supported") + + +class BalanceError(ValueError): + pass + + +def split_module( + module: nn.Sequential, balance: Iterable[int], devices: List[torch.device], +) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]: + """Splits a module into multiple partitions. + + Returns: + A tuple of (partitions, balance, devices). + + Partitions are represented as a :class:`~torch.nn.ModuleList` whose + item is a partition. All layers in a partition are placed in the + same device. + + Raises: + BalanceError: + wrong balance + IndexError: + the number of devices is fewer than the number of partitions. + + """ + balance = list(balance) + + if len(module) != sum(balance): + raise BalanceError( + "module and sum of balance have different length " + f"(module: {len(module)}, sum of balance: {sum(balance)})" + ) + + if any(x <= 0 for x in balance): + raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})") + + if len(balance) > len(devices): + raise IndexError( + "too few devices to hold given partitions " f"(devices: {len(devices)}, partitions: {len(balance)})" + ) + + j = 0 + partitions = [] + layers: NamedModules = OrderedDict() + + for name, layer in module.named_children(): + layers[name] = layer + + if len(layers) == balance[j]: + # Group buffered layers as a partition. + partition = nn.Sequential(layers) + + device = devices[j] + partition.to(device) + + partitions.append(partition) + + # Prepare for the next partition. + layers.clear() + j += 1 + + partitions = cast(List[nn.Sequential], nn.ModuleList(partitions)) + del devices[j:] + + return partitions, balance, devices + + +MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement") + + +class Pipe(Module): + """Wraps an arbitrary :class:`nn.Sequential ` module + to train on Pipe_. If the module requires lots of memory, Pipe will be + very efficient. + :: + + model = nn.Sequential(a, b, c, d) + model = Pipe(model, balance=[1, 1, 1, 1], chunks=8) + output = model(input) + + .. _Pipe: https://arxiv.org/abs/1811.06965 + + Pipe combines pipeline parallelism with checkpointing to reduce peak + memory required to train while minimizing device under-utilization. + + You should determine the balance when defining a :class:`Pipe` module, as + balancing will not be done automatically. The module will be partitioned + into multiple devices according to the given balance. You may rely on + heuristics to find your own optimal configuration. 
+ + Args: + module (torch.nn.Sequential): + sequential module to be parallelized + balance (ints): + list of number of layers in each partition + + Keyword Args: + devices (iterable of devices): + devices to use (default: all CUDA devices) + chunks (int): + number of micro-batches (default: ``1``) + checkpoint (str): + when to enable checkpointing, one of ``'always'``, + ``'except_last'``, or ``'never'`` (default: ``'except_last'``) + deferred_batch_norm (bool): + whether to use deferred BatchNorm moving statistics (default: + :data:`False`, see :ref:`Deferred Batch Normalization` for more + details) + + Raises: + TypeError: + the module is not a :class:`nn.Sequential `. + ValueError: + invalid arguments, or wrong balance + IndexError: + the number of devices is fewer than the number of partitions. + + """ + + #: The number of layers in each partition. + balance: List[int] = [] + # ^^ + # The default value [] required for Sphinx's autoattribute. + + #: The devices mapped to each partition. + #: + #: ``devices[-1]`` refers to the device of the last partition, which means + #: it is the output device. Probably, you need to use it to transfer the + #: target to calculate the loss without a device mismatch + #: :exc:`RuntimeError`. For example:: + #: + #: out_device = pipe.devices[-1] + #: + #: for input, target in loader: + #: target = target.to(out_device, non_blocking=True) + #: output = pipe(input) + #: loss = F.cross_entropy(output, target) + #: + devices: List[torch.device] = [] + + #: The number of micro-batches. + chunks: int = 1 + + #: The checkpoint mode to determine when to enable checkpointing. It is one + #: of ``'always'``, ``'except_last'``, or ``'never'``. + checkpoint: str = "except_last" + + def __init__( + self, + module: nn.Sequential, + balance: Optional[Iterable[int]] = None, + *, + devices: Optional[Devices] = None, + chunks: int = chunks, + checkpoint: str = checkpoint, + deferred_batch_norm: bool = False, + ) -> None: + super().__init__() + + chunks = int(chunks) + checkpoint = str(checkpoint) + + if balance is None: + raise ValueError(recommend_auto_balance("balance is required")) + if chunks <= 0: + raise ValueError("number of chunks must be positive integer") + if checkpoint not in ["always", "except_last", "never"]: + raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'") + + verify_module(module) + + # Verify if the underlying skippable modules satisfy integrity. The + # integrity can be verified before forward() because it is static. + verify_skippables(module) + + self.chunks = chunks + self.checkpoint = checkpoint + + if deferred_batch_norm: + module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks) + + if devices is None: + devices = range(torch.cuda.device_count()) + devices = [torch.device(d) for d in devices] + devices = cast(List[torch.device], devices) + + try: + self.partitions, self.balance, self.devices = split_module(module, balance, devices) + except BalanceError as exc: + raise ValueError(recommend_auto_balance(str(exc))) + + verify_splitting(module, self.partitions, self.balance, self.devices) + + self._copy_streams: List[List[AbstractStream]] = [] + self._skip_layout = inspect_skip_layout(self.partitions) + + # Separate CUDA streams for copy. + copy_streams = self._ensure_copy_streams() + + # The micro-batch index where the checkpointing stops. 
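+        # 'always' -> chunks, 'except_last' -> chunks - 1, 'never' -> 0.
+        # Only micro-batches with index below this value are checkpointed.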
+ checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] + + self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop) + + def __len__(self) -> int: + """Counts the length of the underlying sequential module.""" + return sum(len(p) for p in self.partitions) + + def __getitem__(self, index: int) -> nn.Module: + """Gets a layer in the underlying sequential module.""" + partitions = self.partitions + if index < 0: + partitions = partitions[::-1] + + for partition in partitions: + try: + return partition[index] + except IndexError: + pass + + shift = len(partition) + + if index < 0: + index += shift + else: + index -= shift + + raise IndexError + + def __iter__(self) -> Iterable[nn.Module]: + """Iterates over children of the underlying sequential module.""" + for partition in self.partitions: + yield from partition + + # Pipe should manage the device of each partition. + # Deny cuda(), cpu(), and to() with device, by TypeError. + def cuda(self, device: Optional[Device] = None) -> "Pipe": + raise MOVING_DENIED + + def cpu(self) -> "Pipe": + raise MOVING_DENIED + + def to(self, *args: Any, **kwargs: Any) -> "Pipe": + # Deny these usages: + # + # - to(device[, dtype, non_blocking]) + # - to(tensor[, non_blocking]) + # + # But allow this: + # + # - to(dtype[, non_blocking]) + # + if "device" in kwargs or "tensor" in kwargs: + raise MOVING_DENIED + + if args: + if isinstance(args[0], (torch.device, int, str)): + raise MOVING_DENIED + if torch.is_tensor(args[0]): + raise MOVING_DENIED + + return super().to(*args, **kwargs) + + def _ensure_copy_streams(self) -> List[List[AbstractStream]]: + """Ensures that :class:`Pipe` caches CUDA streams for copy. + + It's worth to cache CUDA streams although PyTorch already manages a + pool of pre-allocated CUDA streams, because it may reduce GPU memory + fragementation when the number of micro-batches is small. + + """ + if not self._copy_streams: + for device in self.devices: + self._copy_streams.append([new_stream(device) for _ in range(self.chunks)]) + + return self._copy_streams + + def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore + """:class:`Pipe` is a fairly transparent module wrapper. It doesn't + modify the input and output signature of the underlying module. But + there's type restriction. Input and output have to be a + :class:`~torch.Tensor` or a tuple of tensors. This restriction is + applied at partition boundaries too. + + Args: + input (torch.Tensor or tensors): input mini-batch + + Returns: + tensor or tensors: output mini-batch + + Raises: + TypeError: input is not a tensor or tensors. + + """ + microbatch.check(input) + + if not self.devices: + # Empty sequential module is not illegal. + return input + + # Divide a mini-batch into micro-batches. + batches = microbatch.scatter(input, self.chunks) + + # Run pipeline parallelism. + self.pipeline.run(batches) + + # Merge the micro-batches into one mini-batch. + output = microbatch.gather(batches) + return output diff --git a/torch/distributed/_pipeline/sync/pipeline.py b/torch/distributed/_pipeline/sync/pipeline.py new file mode 100644 index 000000000000..86c8dfddebeb --- /dev/null +++ b/torch/distributed/_pipeline/sync/pipeline.py @@ -0,0 +1,257 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The pipeline parallelism of Pipe.""" +from queue import Queue +from types import TracebackType +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast + +import torch +from torch import Tensor, nn +from torch.autograd.profiler import record_function + +from .checkpoint import Checkpointing +from .copy import Copy, Wait +from .dependency import fork, join +from .microbatch import Batch +from .skip.layout import SkipLayout +from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker +from .stream import AbstractStream, current_stream, use_device +from .worker import Task, create_workers, join_workers + +__all__: List[str] = [] + + +Tensors = Tuple[Tensor, ...] +TensorOrTensors = Union[Tensor, Tensors] + +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] + +# Queue is generic only in stubs. +# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime +if TYPE_CHECKING: + InQueue = Queue[Optional["Task"]] + OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] +else: + InQueue = Queue + OutQueue = Queue + + +def depend(fork_from: Batch, join_to: Batch) -> None: + fork_from[0], phony = fork(fork_from[0]) + join_to[0] = join(join_to[0], phony) + + +def copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: + batch[:] = Copy.apply(prev_stream, next_stream, *batch) + # Gradients are only supported for float Tensors. + batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch]) + + +def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None: + batch[:] = Wait.apply(prev_stream, next_stream, *batch) + # Gradients are only supported for float Tensors. + batch[:] = tuple([x if x.is_floating_point() else x.detach() for x in batch]) + + +def clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]: + """Generates schedules for each clock cycle.""" + # m: number of micro-batches + # n: number of partitions + # i: index of micro-batch + # j: index of partition + # k: clock number + # + # k (i,j) (i,j) (i,j) + # - ----- ----- ----- + # 0 (0,0) + # 1 (1,0) (0,1) + # 2 (2,0) (1,1) (0,2) + # 3 (2,1) (1,2) + # 4 (2,2) + for k in range(m + n - 1): + yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] + + +class Pipeline: + """The pipeline parallelism for Pipe.""" + + def __init__( + self, + partitions: List[nn.Sequential], + devices: List[torch.device], + copy_streams: List[List[AbstractStream]], + skip_layout: SkipLayout, + checkpoint_stop: int, + ) -> None: + self.partitions = partitions + self.devices = devices + self.copy_streams = copy_streams + self.skip_layout = skip_layout + self.checkpoint_stop = checkpoint_stop + (self.in_queues, self.out_queues) = create_workers(devices) + + def __del__(self) -> None: + join_workers(self.in_queues, self.out_queues) + + def run(self, batches: List[Batch]) -> None: + """Runs pipeline parallelism. + + It modifies the given batches in place. 
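+
+        Micro-batches are scheduled diagonally by :func:`clock_cycles`, so
+        partition ``j`` handles micro-batch ``i`` at clock ``i + j``.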
+ + """ + partitions = self.partitions + devices = self.devices + skip_layout = self.skip_layout + + m = len(batches) + n = len(partitions) + + skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches] + + for schedule in clock_cycles(m, n): + self.fence(batches, schedule, skip_trackers) + self.compute(batches, schedule, skip_trackers) + + def fence( + self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], + ) -> None: + """Copies micro-batches after computation for the previous + micro-batches. + """ + copy_streams = self.copy_streams + skip_layout = self.skip_layout + + for i, j in schedule: + # Ensure that batches[i-1] is executed after batches[i] in + # backpropagation by an explicit dependency. + if i != 0 and j != 0: + depend(batches[i - 1], batches[i]) + + next_stream = copy_streams[j][i] + + for prev_j, ns, name in skip_layout.copy_policy(j): + prev_stream = copy_streams[prev_j][i] + skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name) + + if j != 0: + prev_stream = copy_streams[j - 1][i] + copy(batches[i], prev_stream, next_stream) + + def compute( + self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], + ) -> None: + """Runs tasks with synchronization to copy streams.""" + partitions = self.partitions + devices = self.devices + copy_streams = self.copy_streams + checkpoint_stop = self.checkpoint_stop + + # Disable checkpointing if in eval mode. + if not self.partitions[0].training: + checkpoint_stop = 0 + + n = len(partitions) + streams = [current_stream(d) for d in devices] + exc_info: Optional[ExcInfo] = None + + # With checkpointing, the autograd graph looks like this diagram: + # ┌─────┸──────┐ + # │ Copy │ + # └─────┰──────┘ (fence) + # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─ + # ┃ (compute) + # ┌─────┸──────┐ + # │ Wait │ [1] Synchronize the current stream with the copy stream. + # └─────┰──────┘ + # ┌─────┸──────┐ + # │ Checkpoint │ [2] Compute a partition within checkpointing. + # └─────┰──────┘ + # ┌─────┸──────┐ + # │ Wait │ [3] Synchronize the copy stream with the current stream. + # └─────┰──────┘ + # ┠ ─ ─ ─ ┐ + # ┃ ┌─────┴─────┐ + # ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation. + # ┃ └─────┬─────┘ + # ┠ ─ ─ ─ ┘ + # ┃ + # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─ + # ┌─────┸──────┐ (fence) + # │ Copy │ + # └─────┰──────┘ + for i, j in schedule: + batch = batches[i] + partition = partitions[j] + + # Synchronize with the copied input. ([1] in the diagram) + if j != 0: + wait(batch, copy_streams[j][i], streams[j]) + + # Determine whether checkpointing or not. 
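+            # Micro-batches with index below 'checkpoint_stop' run under
+            # checkpointing; the rest run the partition directly.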
+ checkpoint = i < checkpoint_stop + if checkpoint: + + def function( + input: TensorOrTensors, + partition: nn.Sequential = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> TensorOrTensors: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return partition(input) + + chk = Checkpointing(function, batch) + task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) + del function, chk + + else: + + def compute( + batch: Batch = batch, + partition: nn.Sequential = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> Batch: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return batch.call(partition) + + task = Task(streams[j], compute=compute, finalize=None) + del compute + + # Compute tasks in parallel. ([2] in the diagram) + self.in_queues[j].put(task) + + for i, j in schedule: + ok, payload = self.out_queues[j].get() + + # Hold the first exception. + if exc_info is not None: + continue + elif not ok: + exc_info = cast(ExcInfo, payload) + continue + + task, batch = cast(Tuple[Task, Batch], payload) + + # The copy stream synchronizes to copy the output. ([3] in the + # diagram) + if j != n - 1: + wait(batch, streams[j], copy_streams[j][i]) + + # Finalize tasks. If checkpointing is enabled, here the + # recomputation is scheduled at backpropagation. ([4] in the + # diagram) + with use_device(devices[j]): + task.finalize(batch) + + batches[i] = batch + + # Fail at the first exception. + if exc_info is not None: + raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) diff --git a/torch/distributed/_pipeline/sync/py.typed b/torch/distributed/_pipeline/sync/py.typed new file mode 100644 index 000000000000..ab03724cafbf --- /dev/null +++ b/torch/distributed/_pipeline/sync/py.typed @@ -0,0 +1,6 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torch/distributed/_pipeline/sync/skip/__init__.py b/torch/distributed/_pipeline/sync/skip/__init__.py new file mode 100644 index 000000000000..bdcb913867a7 --- /dev/null +++ b/torch/distributed/_pipeline/sync/skip/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Supports efficiency with skip connections.""" +from .namespace import Namespace +from .skippable import pop, skippable, stash, verify_skippables + +__all__ = ["skippable", "stash", "pop", "verify_skippables", "Namespace"] diff --git a/torch/distributed/_pipeline/sync/skip/layout.py b/torch/distributed/_pipeline/sync/skip/layout.py new file mode 100644 index 000000000000..bff417bfbd65 --- /dev/null +++ b/torch/distributed/_pipeline/sync/skip/layout.py @@ -0,0 +1,86 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+"""Static skip connection layout of ``@skippable`` modules.""" +from typing import Dict, Iterable, List, Tuple + +from torch import nn + +from .namespace import Namespace + +__all__: List[str] = [] + + +class SkipLayout: + """Represents a skip connection layout across partitions.""" + + # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...} + by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]] + + # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...] + by_partition: List[List[Tuple[int, Namespace, str]]] + + def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None: + # The skip routes are already indexed by 'ns, name'. + self.by_ns_name = skip_routes + + # Index skip routes by partition number 'j'. + self.by_partition = [[] for _ in range(num_partitions)] + + for (ns, name), (prev_j, next_j) in skip_routes.items(): + self.by_partition[next_j].append((prev_j, ns, name)) + + for p in self.by_partition: + p.sort() + + def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]: + """Generates skip routes for the given destination partition number. + The skip routes are sorted by source partition number in ascending + order. + + Yields: + Each tuple of (source partition number, namespace, name). + + """ + for prev_j, ns, name in self.by_partition[next_j]: + if prev_j == next_j: + # This skip tensor will be popped at the same partition where + # it is stashed. In this case, copy is not required. + continue + + yield (prev_j, ns, name) + + def requires_copy(self, ns: Namespace, name: str) -> bool: + """Whether the given namespace and name requires partition-to-partition + copy or not. + """ + prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1)) + return prev_j != next_j + + +def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout: + """Inspects the skip connection layout in the given partitions.""" + # NOTE(sublee): Hide circular import inside this subroutine. Circular + # import is not ideal but placing this logic near to SkipLayout may + # increase cohesion of code. + from .skippable import Skippable + + skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {} + stashed_at: Dict[Tuple[Namespace, str], int] = {} + + for j, partition in enumerate(partitions): + for layer in partition: + if not isinstance(layer, Skippable): + continue + + for ns, name in layer.stashable(): + stashed_at[(ns, name)] = j + + for ns, name in layer.poppable(): + prev_j = stashed_at.pop((ns, name)) + skip_routes[(ns, name)] = (prev_j, j) + + return SkipLayout(len(partitions), skip_routes) diff --git a/torch/distributed/_pipeline/sync/skip/namespace.py b/torch/distributed/_pipeline/sync/skip/namespace.py new file mode 100644 index 000000000000..d2a8de92588e --- /dev/null +++ b/torch/distributed/_pipeline/sync/skip/namespace.py @@ -0,0 +1,50 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Provides isolated namespace of skip tensors.""" +import abc +from functools import total_ordering +from typing import Any +import uuid + +__all__ = ["Namespace"] + + +@total_ordering +class Namespace(metaclass=abc.ABCMeta): + """Namespace for isolating skip tensors used by :meth:`isolate() + `. 
+ """ + + __slots__ = ("id",) + + def __init__(self) -> None: + self.id = uuid.uuid4() + + def __repr__(self) -> str: + return f"" + + def __hash__(self) -> int: + return hash(self.id) + + # Namespaces should support ordering, since SkipLayout will sort tuples + # including a namespace. But actual order between namespaces is not + # important. That's why they are ordered by version 4 UUID which generates + # random numbers. + def __lt__(self, other: Any) -> bool: + if isinstance(other, Namespace): + return self.id < other.id + return False + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Namespace): + return self.id == other.id + return False + + +# 'None' is the default namespace, +# which means that 'isinstance(None, Namespace)' is 'True'. +Namespace.register(type(None)) diff --git a/torch/distributed/_pipeline/sync/skip/portal.py b/torch/distributed/_pipeline/sync/skip/portal.py new file mode 100644 index 000000000000..6b3bbb3fb761 --- /dev/null +++ b/torch/distributed/_pipeline/sync/skip/portal.py @@ -0,0 +1,231 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Portal keeps a tensor in the pocket plane. The tensor becomes hidden to the +autograd engine. The shared context of three functions (:class:`PortalBlue`, +:class:`PortalOrange`, and :class:`PortalCopy`) out of the computation graph is +one of the most important feature of :mod:`torchpipe.skip`. + +The metaphor is inspired by Portal™ from Valve. + +""" +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + +from ..copy import Context as CopyContext +from ..copy import Copy +from ..phony import get_phony +from ..stream import AbstractStream, get_device + +__all__: List[str] = [] + + +class Portal: + """A portal for a tensor.""" + + def __init__(self, tensor: Optional[Tensor], tensor_life: int) -> None: + self.put_tensor(tensor, tensor_life) + self.grad: Optional[Tensor] = None + + def blue(self) -> Tensor: + """Creates a :class:`PortalBlue` which hides the underlying tensor from + the autograd engine. + + Join the returning phony to the main lane of the autograd graph to + assure the correct backpropagation:: + + PortalBlue --+ + | + ---------- Join -- + + """ + tensor = self.use_tensor() + + if tensor is None: + return get_phony(torch.device("cpu"), requires_grad=False) + + return PortalBlue.apply(self, tensor) + + def orange(self, phony: Tensor) -> Optional[Tensor]: + """Creates a :class:`PortalOrange` which retrieves the hidden tensor + without losing ability of backpropagation. + + Give a phony forked from the main lane of an autograd graph:: + + +-- PortalOrange --+ + | | + -- Fork --------- f(a, b) -- + + """ + self.check_tensor_life() + + if self.tensor is None: + return self.use_tensor() + + return PortalOrange.apply(self, phony) + + def copy(self, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor,) -> Tensor: + """Copies the hidden tensor by a :class:`PortalCopy`. 
+ + Give a phony and use the returning phony to keep backpropagation:: + + +-- PortalCopy --+ + | | + -- Fork ---------- Join -- + + """ + if self.tensor is None: + return get_phony(torch.device("cpu"), requires_grad=False) + + return PortalCopy.apply(self, prev_stream, next_stream, phony) + + def check_tensor_life(self) -> None: + if self.tensor_life <= 0: + raise RuntimeError("tensor in portal has been removed") + + def put_tensor(self, tensor: Optional[Tensor], tensor_life: int) -> None: + """Stores a tensor into this portal.""" + # [Life of Tensor through Portal] + # + # The tensor can be retrieved by use_tensor() up to 'tensor_life' + # times. When the life becomes 0, the tensor will be deleted for + # deallocation in CUDA memory. + # + # The below events participate in a tensor through a portal. + # Note that [x] denotes the events which call use_tensor(): + # + # 1. [x] blue() + # 2. [ ] PortalBlue.forward + # 3. [ ] copy() + # 4. [ ] PortalCopy.forward + # 5. [ ] orange() + # 6. [x] PortalOrange.forward + # - - - - - - - - - - - - - - - - - - - - - - - - - - - + # 7. [ ] orange() (recomputed) + # 8. [x] PortalOrange.forward (recomputed) + # 9. [ ] PortalOrange.backward + # 10. [ ] PortalCopy.backward + # 11. [x] blue() (recomputed) + # 12. [ ] PortalBlue.forward (recomputed) + # 13. [ ] PortalBlue.backward + # + self.tensor_life = tensor_life + + if tensor_life > 0: + self.tensor = tensor + else: + self.tensor = None + + def use_tensor(self) -> Optional[Tensor]: + """Retrieves the underlying tensor and decreases the tensor life. When + the life becomes 0, it the tensor will be removed. + """ + self.check_tensor_life() + + tensor = self.tensor + + self.tensor_life -= 1 + + if self.tensor_life <= 0: + self.tensor = None + + return tensor + + def put_grad(self, grad: Tensor) -> None: + """Stores a gradient into this portal.""" + self.grad = grad + + def use_grad(self) -> Tensor: + """Retrieves and removes the underlying gradient. The gradient is + always ephemeral. + """ + if self.grad is None: + raise RuntimeError("grad in portal has been removed or never set") + + grad = self.grad + self.grad = None + return grad + + +# Common interface between :class:`PortalBlue`, :class:`PortalOrange`, and +# :class:`PortalCopy`. +class Context(CopyContext): + portal: Portal + + +class PortalBlue(torch.autograd.Function): + """Hides a tensor from the autograd engine by a :class:`Portal`.""" + + @staticmethod + # type: ignore + def forward( + ctx: Context, + portal: Portal, + # This tensor must be retrieved by portal.use_tensor(). + tensor: Tensor, + ) -> Tensor: + ctx.portal = portal + + phony = get_phony(tensor.device, requires_grad=False) + return phony.detach() + + @staticmethod + # type: ignore + def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, Tensor]: + # The paired PortalOrange should keep the gradient. + grad = ctx.portal.use_grad() + return None, grad + + +class PortalOrange(torch.autograd.Function): + """Retrieves the hidden tensor from a :class:`Portal`.""" + + @staticmethod + # type: ignore + def forward(ctx: Context, portal: Portal, phony: Tensor) -> Tensor: + ctx.portal = portal + + tensor = portal.use_tensor() + assert tensor is not None + + return tensor.detach() + + @staticmethod + def backward(ctx: Context, grad: Tensor) -> Tuple[None, None]: # type: ignore + # The paired PortalBlue will use the gradient. + ctx.portal.put_grad(grad) + return None, None + + +class PortalCopy(torch.autograd.Function): + """Copies the hidden tensor in a :class:`Portal`. 
It replaces the hidden + tensor with copied one. + """ + + @staticmethod + # type: ignore + def forward( + ctx: Context, portal: Portal, prev_stream: AbstractStream, next_stream: AbstractStream, phony: Tensor, + ) -> Tensor: + ctx.portal = portal + + assert portal.tensor is not None + (portal.tensor,) = Copy.forward(ctx, prev_stream, next_stream, portal.tensor) + + phony = get_phony(get_device(next_stream), requires_grad=False) + return phony.detach() + + @staticmethod + # type: ignore + def backward(ctx: Context, grad_phony: Tensor,) -> Tuple[None, None, None, None]: + portal = ctx.portal + + assert portal.grad is not None + _, _, portal.grad = Copy.backward(ctx, portal.grad) + + return None, None, None, None diff --git a/torch/distributed/_pipeline/sync/skip/skippable.py b/torch/distributed/_pipeline/sync/skip/skippable.py new file mode 100644 index 000000000000..b5d07ff9c7a0 --- /dev/null +++ b/torch/distributed/_pipeline/sync/skip/skippable.py @@ -0,0 +1,439 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""The user interface to define skip connections.""" +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + FrozenSet, + Generator, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from torch import Tensor, nn + +from ..microbatch import Batch +from .namespace import Namespace +from .tracker import current_skip_tracker + +__all__ = ["skippable", "stash", "pop", "verify_skippables"] + + +Tensors = Tuple[Tensor, ...] +TensorOrTensors = Union[Tensor, Tensors] + +StashPop = Union["stash", "pop"] +StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors] +if TYPE_CHECKING: + SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]] +else: + SkippableModule = nn.Module + +T = TypeVar("T", bound="Skippable") + + +class Skippable(nn.Module): + """The base class for skippable modules. + + Do not use this class directly. Define a subclass by :func:`skippable` + instead. + + """ + + module_cls: ClassVar[Type[SkippableModule]] + stashable_names: ClassVar[FrozenSet[str]] + poppable_names: ClassVar[FrozenSet[str]] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + self.module = self.module_cls(*args, **kwargs) # type: ignore + self.namespaces: Dict[str, Namespace] = {} + + def __repr__(self) -> str: + return f"@skippable({self.module})" + + def namespaced(self, name: str) -> Tuple[Namespace, str]: + """Prepends namespace for the given skip name.""" + ns = self.namespaces.get(name) + ns = cast(Namespace, ns) + return (ns, name) + + def stashable(self) -> Iterable[Tuple[Namespace, str]]: + """Iterates over namespaced skip names to be stashed.""" + for name in self.stashable_names: + yield self.namespaced(name) + + def poppable(self) -> Iterable[Tuple[Namespace, str]]: + """Iterates over namespaced skip names to be popped.""" + for name in self.poppable_names: + yield self.namespaced(name) + + def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T: + r"""Isolates a specified subset or the whole set of skip tensors into a + namespace. In a single sequential module, skip tensors with the same + name are not allowed unless they are isolated by different namespaces. + + Here's an example using the same name for skip tensors twice. 
Each pair + of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1`` + and ``ns2``. There is no conflict anymore:: + + ns1 = Namespace() + ns2 = Namespace() + + model = nn.Sequential( + Layer1().isolate(ns1), + Layer1().isolate(ns2), + Layer2(), + Layer3().isolate(ns2), + Layer3().isolate(ns1), + ) + + When `only` parameter is omitted, all skip tensors are isolated. You + can isolate a subset of skip tensors by passing `only` parameter:: + + ns_alice = Namespace() + ns_bob = Namespace() + + model = nn.Sequential( + ... + StashStashPop().isolate(ns_alice, only=['alice']) \ + .isolate(ns_bob, only=['bob']), + ... + ) + + Args: + ns (Namespace): + namespace for isolation + + Keyword Args: + only (iterable of strs): + names of specific skip tensors to be isolated (omit this option + to isolate all skip tensors declared in this module) + + Returns: + this module itself + + """ + names: Iterable[str] + + if only is None: + names = self.stashable_names | self.poppable_names + else: + names = set(only) + + for name in names: + self.namespaces[name] = ns + + return self + + def dispatch( + self, + input: TensorOrTensors, + handle_stash: Callable[[str, Optional[Tensor]], None], + handle_pop: Callable[[str], Optional[Tensor]], + ) -> TensorOrTensors: + """Dispatches :class:`stash` or :class:`pop` commands generated by the + module's ``forward()``. + """ + generator = self.module(input) + + if not isinstance(generator, Generator): + # The underlying module returned output without any yield. + output = generator + return output + + try: + op = next(generator) + + while True: + if isinstance(op, stash): + handle_stash(op.name, op.tensor) + op = next(generator) + continue + + if isinstance(op, pop): + tensor = handle_pop(op.name) + op = generator.send(tensor) + continue + + raise TypeError("%r is not a command from @skippable" % op) + + except StopIteration as stop: + output = stop.args[0] + return output + + def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore + """Performs the forward propagation. :class:`stash` or :class:`pop` + commands will be handled by portals silently. The portals won't be + exposed to users. + + Raises: + RuntimeError: + illegal 'stash' or 'pop' is found. + + """ + skip_tracker = current_skip_tracker() + stashed_tensors: Dict[str, Optional[Tensor]] = {} + + # Load skip tensors that might be popped. + poppable_tensors = {} + batch = Batch(input) + for ns, name in self.poppable(): + try: + poppable_tensors[name] = skip_tracker.load(batch, ns, name) + except KeyError: + raise RuntimeError(f"'{name}' has not been stashed") + input = batch.tensor_or_tensors + + # Handle skip commands. + def handle_stash(name: str, tensor: Optional[Tensor]) -> None: + if name not in self.stashable_names: + raise RuntimeError(f"'{name}' has not been declared as stashable") + stashed_tensors[name] = tensor + + def handle_pop(name: str) -> Optional[Tensor]: + if name not in self.poppable_names: + raise RuntimeError(f"'{name}' has not been declared as poppable") + return poppable_tensors.pop(name) + + output = self.dispatch(input, handle_stash, handle_pop) + + # All declared skips must be stashed or popped. 
+ not_stashed = self.stashable_names - stashed_tensors.keys() + if not_stashed: + comma_names = ", ".join("'%s'" % n for n in not_stashed) + raise RuntimeError(f"{comma_names} must be stashed but have not") + + not_popped = poppable_tensors.keys() + if not_popped: + comma_names = ", ".join("'%s'" % n for n in not_popped) + raise RuntimeError(f"{comma_names} must be popped but have not") + + # Save stashed skip tensors. + batch = Batch(output) + for ns, name in self.stashable(): + tensor = stashed_tensors[name] + skip_tracker.save(batch, ns, name, tensor) + output = batch.tensor_or_tensors + + return output + + +# TODO(sublee): Move to above of Skippable class for better read flow. +def skippable( + stash: Iterable[str] = (), pop: Iterable[str] = (), +) -> Callable[[Type[SkippableModule]], Type[Skippable]]: + """The decorator to define a :class:`nn.Module ` with skip + connections. Decorated modules are called "skippable". This functionality + works perfectly fine even when the module is not wrapped by + :class:`~torchpipe.Pipe`. + + Each skip tensor is managed by its name. Before manipulating skip tensors, + a skippable module must statically declare the names for skip tensors by + `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be + stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield + pop(name)``. + + Here is an example with three layers. A skip tensor named "1to3" is stashed + and popped at the first and last layer, respectively:: + + @skippable(stash=['1to3']) + class Layer1(nn.Module): + def forward(self, input): + yield stash('1to3', input) + return f1(input) + + class Layer2(nn.Module): + def forward(self, input): + return f2(input) + + @skippable(pop=['1to3']) + class Layer3(nn.Module): + def forward(self, input): + skip_1to3 = yield pop('1to3') + return f3(input) + skip_1to3 + + model = nn.Sequential(Layer1(), Layer2(), Layer3()) + + One skippable module can stash or pop multiple skip tensors:: + + @skippable(stash=['alice', 'bob'], pop=['carol']) + class StashStashPop(nn.Module): + def forward(self, input): + yield stash('alice', f_alice(input)) + yield stash('bob', f_bob(input)) + carol = yield pop('carol') + return input + carol + + Every skip tensor must be associated with exactly one pair of `stash` and + `pop`. :class:`~torchpipe.Pipe` checks this restriction automatically + when wrapping a module. You can also check the restriction by + :func:`~torchpipe.skip.verify_skippables` without + :class:`~torchpipe.Pipe`. + + .. note:: + + :func:`@skippable ` changes the type of the wrapped class. + But currently (mypy v0.740), mypy could not understand class decorators + yet (`#3135 `_). + + There are two workarounds: + + 1. Naively ignore type errors by ``# type: ignore``. + 2. Use ``skippable()()`` as a function instead of a decorator. + + .. seealso:: :ref:`Long Skip Connections` + + """ + stashable_names = frozenset(stash) + poppable_names = frozenset(pop) + + def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]: + name = module_cls.__name__ + bases = (Skippable,) + attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names} + return type(name, bases, attrs) + + return extend_skippable + + +class stash: + """The command to stash a skip tensor. 
+ + :: + + def forward(self, input): + yield stash('name', input) + return f(input) + + Args: + name (str): name of skip tensor + input (torch.Tensor or None): tensor to pass to the skip connection + + """ + + __slots__ = ("name", "tensor") + + def __init__(self, name: str, tensor: Optional[Tensor]) -> None: + self.name = name + self.tensor = tensor + + +class pop: + """The command to pop a skip tensor. + + :: + + def forward(self, input): + skip = yield pop('name') + return f(input) + skip + + Args: + name (str): name of skip tensor + + Returns: + the skip tensor previously stashed by another layer under the same name + + """ + + __slots__ = ("name",) + + def __init__(self, name: str) -> None: + self.name = name + + +def verify_skippables(module: nn.Sequential) -> None: + """Verifies if the underlying skippable modules satisfy integrity. + + Every skip tensor must have only one pair of `stash` and `pop`. If there + are one or more unmatched pairs, it will raise :exc:`TypeError` with the + detailed messages. + + Here are a few failure cases. :func:`verify_skippables` will report failure + for these cases:: + + # Layer1 stashes "1to3". + # Layer3 pops "1to3". + + nn.Sequential(Layer1(), Layer2()) + # └──── ? + + nn.Sequential(Layer2(), Layer3()) + # ? ────┘ + + nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3()) + # └───────────────────┘ ^^^^^^ + + nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3()) + # ^^^^^^ └───────────────────┘ + + To use the same name for multiple skip tensors, they must be isolated by + different namespaces. See :meth:`isolate() + `. + + Raises: + TypeError: + one or more pairs of `stash` and `pop` are not matched. + + """ + stashed: Set[Tuple[Namespace, str]] = set() + popped: Set[Tuple[Namespace, str]] = set() + msgs: List[str] = [] + + for layer_name, layer in module.named_children(): + if not isinstance(layer, Skippable): + continue + + for name in layer.stashable_names & layer.poppable_names: + msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable" + msgs.append(msg) + + for ns, name in layer.stashable(): + if name in layer.poppable_names: + continue + + if (ns, name) in stashed: + msg = f"'{layer_name}' redeclared '{name}' as stashable " "but not isolated by namespace" + msgs.append(msg) + continue + + stashed.add((ns, name)) + + for ns, name in layer.poppable(): + if name in layer.stashable_names: + continue + + if (ns, name) in popped: + msg = f"'{layer_name}' redeclared '{name}' as poppable " "but not isolated by namespace" + msgs.append(msg) + continue + + if (ns, name) not in stashed: + msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed" + msgs.append(msg) + continue + + popped.add((ns, name)) + + for (_, name) in stashed - popped: + msg = f"no module declared '{name}' as poppable but stashed" + msgs.append(msg) + + if msgs: + raise TypeError( + "one or more pairs of stash and pop do not match:\n\n%s" "" % "\n".join("* %s" % x for x in msgs) + ) diff --git a/torch/distributed/_pipeline/sync/skip/tracker.py b/torch/distributed/_pipeline/sync/skip/tracker.py new file mode 100644 index 000000000000..397158c21dbf --- /dev/null +++ b/torch/distributed/_pipeline/sync/skip/tracker.py @@ -0,0 +1,177 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
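For quick reference, here is a minimal, self-contained sketch of the skip API defined in ``skippable.py`` above: declaring ``stash``/``pop`` names, isolating a reused name with namespaces, and letting ``verify_skippables`` flag an unmatched pair. It is an illustration rather than part of the patch, and it assumes these names are re-exported from ``torch.distributed._pipeline.sync.skip``, as the accompanying tests import them::

    from torch import nn
    from torch.distributed._pipeline.sync.skip import (
        Namespace, pop, skippable, stash, verify_skippables,
    )

    @skippable(stash=["skip"])
    class Stash(nn.Module):
        def forward(self, x):
            yield stash("skip", x)
            return x  # noqa: B901

    @skippable(pop=["skip"])
    class Pop(nn.Module):
        def forward(self, x):
            skip = yield pop("skip")
            return x + skip

    # "skip" is stashed but never popped, so verification fails.
    try:
        verify_skippables(nn.Sequential(Stash(), nn.ReLU()))
    except TypeError as exc:
        print(exc)  # reports that 'skip' was stashed but never declared poppable

    # Reusing the same skip name twice requires isolation by namespaces.
    ns1, ns2 = Namespace(), Namespace()
    model = nn.Sequential(
        Stash().isolate(ns1), Pop().isolate(ns1),
        Stash().isolate(ns2), Pop().isolate(ns2),
    )
    verify_skippables(model)  # passes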
+"""Tracks skip tensors on a thread.""" +from contextlib import contextmanager +import threading +from typing import Dict, Generator, List, Optional, Tuple + +from torch import Tensor + +from ..checkpoint import is_checkpointing +from ..dependency import fork, join +from ..microbatch import Batch +from ..stream import AbstractStream +from .layout import SkipLayout +from .namespace import Namespace +from .portal import Portal + +__all__: List[str] = [] + + +class SkipTracker: + """Tracks saved skip tensors. + + It will update the given micro-batch in place. This is because when it + manipulates the underlying skip tensors, the current micro-batch also has + to be connected with the skip tensors. + + One thread has one skip tracker. Call :func:`current_skip_tracker` to get + the skip tracker on the current thread. + + """ + + def __init__(self) -> None: + self.tensors: Dict[Tuple[Namespace, str], Optional[Tensor]] = {} + + def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: + self.tensors[(ns, name)] = tensor + + def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: + return self.tensors.pop((ns, name)) + + def copy( + self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, + ) -> None: + raise TypeError("copy is not supported for non-portal skip tensors") + + +class SkipTrackerThroughPotals(SkipTracker): + """Tracks saved skip tensors through portals. The skip tensors will be + hidden in portals so that the autograd engine does not need to track them. + + This tracker is only used when the training or evaluating module is wrapped + with :class:`torchpipe.Pipe`. + + """ + + def __init__(self, skip_layout: SkipLayout) -> None: + super().__init__() + self.skip_layout = skip_layout + self.portals: Dict[Tuple[Namespace, str], Portal] = {} + + def save(self, batch: Batch, ns: Namespace, name: str, tensor: Optional[Tensor]) -> None: + """Saves the stashed skip tensor in a portal. The portal is then + connected to the given micro-batch with :class:`Join`. + """ + if not self.skip_layout.requires_copy(ns, name): + super().save(batch, ns, name, tensor) + return + + # See [Tensor Life of Portal] at Portal.put_tensor() to understand the + # below tensor_life values. Here are the selected events which retrieve + # the tensor in portal: + # + # 1. [x] blue() + # ... + # 6. [x] PortalOrange.forward + # ... + # 8. [x] PortalOrange.forward (recomputed) + # ... + # 11. [x] blue() (recomputed) + # + if (ns, name) not in self.portals: + if is_checkpointing(): + # Under checkpointing, the tensor used by the first + # PortalOrange should be alive in the portal. This tensor will + # be used again by the second PortalOrange during the + # recomputation. + tensor_life = 3 # Delete at [8. PortalOrange.forward (recomputed)] + else: + tensor_life = 2 # Delete at [6. PortalOrange.forward] + + portal = Portal(tensor, tensor_life) + self.portals[(ns, name)] = portal + + else: + # Under recomputation, the portal already exists. + portal = self.portals[(ns, name)] + + # The existing tensor life already became 0. It should be reset as + # 1 to delete the tensor after the second PortalBlue immediately. + tensor_life = 1 # Delete at [11. blue() (recomputed)] + + portal.put_tensor(tensor, tensor_life) + + phony = portal.blue() + batch[0] = join(batch[0], phony) + + def load(self, batch: Batch, ns: Namespace, name: str) -> Optional[Tensor]: + """Loads a skip tensor from the corresponding portal to pop. 
The given + micro-batch is connected to the portal with :class:`Fork`. + """ + if not self.skip_layout.requires_copy(ns, name): + tensor = super().load(batch, ns, name) + return tensor + + portal = self.portals[(ns, name)] + batch[0], phony = fork(batch[0]) + tensor = portal.orange(phony) + return tensor + + def copy( + self, batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream, ns: Namespace, name: str, + ) -> None: + """Copies the skip tensor in the corresponding portal. The given + micro-batch and the portal will be tied with :class:`Fork` and + :class:`Join`. + """ + assert self.skip_layout.requires_copy(ns, name) + + batch[0], phony = fork(batch[0]) + + portal = self.portals[(ns, name)] + phony = portal.copy(prev_stream, next_stream, phony) + + batch[0] = join(batch[0], phony) + + +class ThreadLocal(threading.local): + def __init__(self) -> None: + self.skip_tracker: Optional[SkipTracker] = None + + +thread_local = ThreadLocal() + + +@contextmanager +def use_skip_tracker(skip_tracker: SkipTracker) -> Generator[None, None, None]: + """Registers the given skip tracker on the current thread within a + context:: + + with use_skip_tracker(my_skip_tracker): + ... + + """ + orig = thread_local.skip_tracker + + thread_local.skip_tracker = skip_tracker + + try: + yield + finally: + thread_local.skip_tracker = orig + + +def current_skip_tracker() -> SkipTracker: + """Gets the skip tracker on the current thread.""" + skip_tracker = thread_local.skip_tracker + + if skip_tracker is None: + skip_tracker = SkipTracker() + thread_local.skip_tracker = skip_tracker + + return skip_tracker diff --git a/torch/distributed/_pipeline/sync/stream.py b/torch/distributed/_pipeline/sync/stream.py new file mode 100644 index 000000000000..0de4496808a0 --- /dev/null +++ b/torch/distributed/_pipeline/sync/stream.py @@ -0,0 +1,117 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +"""Utilities for eliminating boilerplate code to handle abstract streams with +CPU device. +""" +from contextlib import contextmanager +from typing import Generator, List, Union, cast + +import torch + +__all__: List[str] = [] + + +class CPUStreamType: + pass + + +# The placeholder on place of streams for the CPU device instead of CUDA. +CPUStream = CPUStreamType() + +# It represents both CUDA streams and the CPU stream. 
+AbstractStream = Union[torch.cuda.Stream, CPUStreamType] + + +def new_stream(device: torch.device) -> AbstractStream: + """Creates a new stream for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.Stream(device) + + +def current_stream(device: torch.device) -> AbstractStream: + """:func:`torch.cuda.current_stream` for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.current_stream(device) + + +def default_stream(device: torch.device) -> AbstractStream: + """:func:`torch.cuda.default_stream` for either CPU or CUDA device.""" + if device.type != "cuda": + return CPUStream + return torch.cuda.default_stream(device) + + +@contextmanager +def use_device(device: torch.device) -> Generator[None, None, None]: + """:func:`torch.cuda.device` for either CPU or CUDA device.""" + if device.type != "cuda": + yield + return + + with torch.cuda.device(device): + yield + + +@contextmanager +def use_stream(stream: AbstractStream) -> Generator[None, None, None]: + """:func:`torch.cuda.stream` for either CPU or CUDA stream.""" + if not is_cuda(stream): + yield + return + + with torch.cuda.stream(as_cuda(stream)): + yield + + +def get_device(stream: AbstractStream) -> torch.device: + """Gets the device from CPU or CUDA stream.""" + if is_cuda(stream): + return as_cuda(stream).device + return torch.device("cpu") + + +def wait_stream(source: AbstractStream, target: AbstractStream) -> None: + """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It + makes the source stream wait until the target stream completes work queued. + """ + if is_cuda(target): + if is_cuda(source): + # A CUDA stream waits another CUDA stream. + as_cuda(source).wait_stream(as_cuda(target)) + else: + # CPU waits a CUDA stream. + as_cuda(target).synchronize() + + # If the target is CPU, synchronization is not required. + + +def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None: + """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.""" + if is_cuda(stream): + # NOTE(sublee): record_stream() on a shifted view tensor throws + # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely + # protect the tensor against unexpected reallocation, here we use a + # temporal tensor associated with the same storage without shifting as + # a workaround. + # + # Issue: https://github.com/pytorch/pytorch/issues/27366 + # + tensor = tensor.new_empty([0]).set_(tensor.storage()) + + tensor.record_stream(as_cuda(stream)) + + +def is_cuda(stream: AbstractStream) -> bool: + """Returns ``True`` if the given stream is a valid CUDA stream.""" + return stream is not CPUStream + + +def as_cuda(stream: AbstractStream) -> torch.cuda.Stream: + """Casts the given stream as :class:`torch.cuda.Stream`.""" + return cast(torch.cuda.Stream, stream) diff --git a/torch/distributed/_pipeline/sync/worker.py b/torch/distributed/_pipeline/sync/worker.py new file mode 100644 index 000000000000..81a588071c2e --- /dev/null +++ b/torch/distributed/_pipeline/sync/worker.py @@ -0,0 +1,151 @@ +# Copyright 2019 Kakao Brain +# +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
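As a minimal sketch (not part of the patch), these helpers let device-agnostic code drive an extra copy stream the same way on CPU and CUDA; on CPU every call simply degrades to a no-op::

    import torch
    from torch.distributed._pipeline.sync.stream import (
        current_stream, get_device, new_stream, record_stream, use_stream, wait_stream,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.ones(4, device=device)                  # produced on the current stream
    copy_stream = new_stream(device)                  # CPUStream on CPU, torch.cuda.Stream on CUDA

    wait_stream(copy_stream, current_stream(device))  # copy_stream waits until x is ready
    with use_stream(copy_stream):                     # no-op context on CPU
        y = x.to(get_device(copy_stream), copy=True)  # this copy runs on copy_stream

    # Prevent x's memory from being reused before copy_stream consumes it (no-op on CPU).
    record_stream(x, copy_stream)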
+"""Multithreading in pipeline parallelism.""" +from contextlib import contextmanager +from queue import Queue +import sys +from threading import Thread +from types import TracebackType +from typing import TYPE_CHECKING, Callable, Dict, Generator, List, Optional, Tuple, Type, Union, cast + +import torch + +from .microbatch import Batch +from .stream import AbstractStream, use_device, use_stream + +__all__: List[str] = [] + + +ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] + +# Queue is generic only in stubs. +# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime +if TYPE_CHECKING: + InQueue = Queue[Optional["Task"]] + OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]] +else: + InQueue = Queue + OutQueue = Queue + + +class Task: + """A task represents how to compute a micro-batch on a partition. + + It consists of two parts: :meth:`compute` and :meth:`finalize`. + :meth:`compute` should be executed in worker threads concurrently. + :meth:`finalize` should be executed after when worker threads complete to + execute :meth:`compute`. + + :meth:`compute` might be boosted by worker threads. Because it produces + several CUDA API calls by user code. In PyTorch, parallel CUDA API calls + are not serialized through GIL. So more than one CUDA API call can be + produced at the same time. + + """ + + def __init__( + self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]], + ) -> None: + self.stream = stream + self._compute = compute + self._finalize = finalize + self._grad_enabled = torch.is_grad_enabled() + + def compute(self) -> Batch: + with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): + return self._compute() + + def finalize(self, batch: Batch) -> None: + if self._finalize is None: + return + with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): + self._finalize(batch) + + +def worker(in_queue: InQueue, out_queue: OutQueue, device: torch.device) -> None: + """The main loop of a worker thread.""" + with use_device(device): + while True: + task = in_queue.get() + + if task is None: + break + + try: + batch = task.compute() + except Exception: + exc_info = cast(ExcInfo, sys.exc_info()) + out_queue.put((False, exc_info)) + continue + + out_queue.put((True, (task, batch))) + + done = (False, None) + out_queue.put(done) + + +def create_workers(devices: List[torch.device],) -> Tuple[List[InQueue], List[OutQueue]]: + """Spawns worker threads. A worker thread is bound to a device.""" + in_queues: List[InQueue] = [] + out_queues: List[OutQueue] = [] + + # Spawn workers. 
+ workers: Dict[torch.device, Tuple[InQueue, OutQueue]] = {} + + def normalize_device(device: torch.device) -> torch.device: + if device.type == "cuda" and device.index is None: + return torch.device("cuda", index=torch.cuda.current_device()) + + if device.type == "cpu" and device.index is not None: + return torch.device("cpu") + + return device + + for device in devices: + device = normalize_device(device) + + try: + in_queue, out_queue = workers[device] + except KeyError: + in_queue = Queue() + out_queue = Queue() + workers[device] = (in_queue, out_queue) + + t = Thread(target=worker, args=(in_queue, out_queue, device), daemon=True,) + t.start() + + in_queues.append(in_queue) + out_queues.append(out_queue) + + return (in_queues, out_queues) + + +def join_workers(in_queues: List[InQueue], out_queues: List[OutQueue]) -> None: + # Close workers. + for in_queue in set(in_queues): + in_queue.put(None) + + # Join running workers. + running = set(out_queues) + while running: + out_queue = running.pop() + ok, payload = out_queue.get() + + done = (False, None) + if (ok, payload) == done: + continue + + running.add(out_queue) + + +@contextmanager +def spawn_workers(devices: List[torch.device],) -> Generator[Tuple[List[InQueue], List[OutQueue]], None, None]: + try: + (in_queues, out_queues) = create_workers(devices) + yield (in_queues, out_queues) + finally: + join_workers(in_queues, out_queues)
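Finally, a minimal sketch (not part of the patch) of how a caller can drive these workers directly: wrap each partition's computation in a ``Task``, feed it through ``in_queues``, and collect ``(ok, payload)`` results from ``out_queues``. It uses only names introduced by this patch, and it assumes ``Batch`` from ``microbatch.py`` can be constructed directly from a tensor, as the imports above suggest::

    import torch

    from torch.distributed._pipeline.sync.microbatch import Batch
    from torch.distributed._pipeline.sync.stream import current_stream
    from torch.distributed._pipeline.sync.worker import Task, spawn_workers

    devices = [torch.device("cpu"), torch.device("cpu")]
    inputs = [torch.randn(4), torch.randn(4)]

    with spawn_workers(devices) as (in_queues, out_queues):
        # Enqueue one task per device; compute() runs on the worker thread.
        for device, in_queue, x in zip(devices, in_queues, inputs):
            task = Task(
                current_stream(device),
                compute=lambda x=x: Batch(x * 2),
                finalize=None,
            )
            in_queue.put(task)

        # Collect results; a failed task ships its exc_info instead of a batch.
        for out_queue in out_queues:
            ok, payload = out_queue.get()
            if not ok:
                raise payload[1].with_traceback(payload[2])
            task, batch = payload
            task.finalize(batch)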