mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Action based on https://github.com/pytorch/pytorch/issues/66232 cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang Pull Request resolved: https://github.com/pytorch/pytorch/pull/66797 Reviewed By: gchanan Differential Revision: D31761389 Pulled By: janeyx99 fbshipit-source-id: c27c9ab4acec1eb71d5edd4538cd113b770dfc6c
822 lines
23 KiB
Python
822 lines
23 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
# Copyright 2019 Kakao Brain
|
|
#
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
from collections import OrderedDict
|
|
from copy import deepcopy
|
|
import time
|
|
|
|
import pytest
|
|
import random
|
|
import torch
|
|
from torch import nn
|
|
from torch import Tensor
|
|
|
|
from torch.distributed.pipeline.sync import Pipe, NoChunk, WithDevice
|
|
from torch.distributed.pipeline.sync.pipe import PipeSequential
|
|
|
|
# Reusable marker: skip a test when no CUDA device is available.
skip_if_no_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="cuda required",
)
|
|
|
|
|
|
def test_pipe_without_rpc():
    """Constructing a Pipe before the RPC framework is initialized must fail."""
    seq = nn.Sequential(nn.Linear(1, 1))
    with pytest.raises(RuntimeError, match='Please initialize RPC framework'):
        Pipe(seq, chunks=1)
|
|
|
|
def test_parameters(setup_rpc):
    """Pipe must expose the parameters of the wrapped module."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=1)
    params = list(pipe.parameters())
    assert params != []
|
|
|
|
|
|
def test_public_attrs(setup_rpc):
    """Pipe normalizes its public attributes to plain Python types.

    ``chunks`` is passed as a float and ``checkpoint`` as a str-convertible
    object; the test expects them back as ``int`` and ``str``.
    """

    class MyString:
        def __init__(self, value):
            self.value = value

        def __str__(self):
            return self.value

    pipe = Pipe(
        nn.Sequential(nn.Linear(1, 1)),
        chunks=42.000,
        checkpoint=MyString("always"),
    )

    assert pipe.devices == [torch.device("cpu")]

    assert pipe.chunks == 42
    assert isinstance(pipe.chunks, int)

    assert pipe.checkpoint == "always"
    assert isinstance(pipe.checkpoint, str)
|
|
|
|
|
|
def test_sequential_like(setup_rpc):
    """Pipe supports len(), iteration and indexing like nn.Sequential."""
    first = nn.Linear(1, 1)
    second = nn.Linear(1, 1)
    pipe = Pipe(nn.Sequential(first, second))

    assert len(pipe) == 2
    assert list(pipe) == [first, second]

    # Non-negative indices, including one past the end.
    assert pipe[0] is first
    assert pipe[1] is second
    with pytest.raises(IndexError):
        _ = pipe[2]

    # Negative indices count from the end.
    assert pipe[-1] is second
    assert pipe[-2] is first
|
|
|
|
def test_chunks_less_than_1(setup_rpc):
    """Non-positive ``chunks`` values are rejected with ValueError."""
    seq = nn.Sequential(nn.Linear(1, 1))
    for bad_chunks in (0, -1):
        with pytest.raises(ValueError):
            Pipe(seq, chunks=bad_chunks)
|
|
|
|
def test_batch_size_indivisible(setup_rpc):
    """A batch size not divisible by ``chunks`` is legal and emits no warning."""
    # ``pytest.warns(None)`` was deprecated in pytest 7 and removed in
    # pytest 8, so warnings are recorded with the standard library instead.
    import warnings

    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, chunks=4)

    with warnings.catch_warnings(record=True) as record:
        # Record every warning, even ones that default filters would hide
        # or deduplicate (this mirrors what pytest's recorder did).
        warnings.simplefilter("always")
        model(torch.rand(7, 1))

    # Indivisible batch size is legal.
    assert not record
|
|
|
|
|
|
def test_batch_size_small(setup_rpc):
    """A batch smaller than ``chunks`` is legal and emits no warning."""
    # ``pytest.warns(None)`` was deprecated in pytest 7 and removed in
    # pytest 8, so warnings are recorded with the standard library instead.
    import warnings

    model = nn.Sequential(nn.Linear(1, 1))
    model = Pipe(model, chunks=4)

    with warnings.catch_warnings(record=True) as record:
        # Record every warning, even ones that default filters would hide
        # or deduplicate (this mirrors what pytest's recorder did).
        warnings.simplefilter("always")
        model(torch.rand(2, 1))

    # Batch size smaller than chunks is legal.
    assert not record
|
|
|
|
|
|
def test_checkpoint_mode(setup_rpc):
    """Each checkpoint mode yields the expected number of CheckpointBackward nodes.

    With two micro-batches, "always" is expected to produce 2 such nodes,
    "except_last" 1, and "never" 0.
    """

    def count_grad_fn(grad_fn, name, visited=None):
        # Count autograd nodes whose class is ``name`` that are reachable from
        # ``grad_fn``. Each node is visited at most once, and the search does
        # not descend past a matching node.
        seen = set() if visited is None else visited
        if grad_fn in seen:
            return 0
        seen.add(grad_fn)

        if grad_fn is None:
            return 0
        if grad_fn.__class__.__name__ == name:
            return 1

        return sum(
            count_grad_fn(child, name, visited=seen)
            for child, _ in grad_fn.next_functions
        )

    model = nn.Sequential(nn.Linear(1, 1))
    sample = torch.rand(2, 1)

    # Build all pipes first, then run them in the same order.
    modes = ("always", "except_last", "never")
    pipes = {mode: Pipe(model, chunks=2, checkpoint=mode) for mode in modes}
    outputs = {mode: pipes[mode](sample) for mode in modes}

    expected_counts = {"always": 2, "except_last": 1, "never": 0}
    for mode, expected in expected_counts.items():
        grad_fn = outputs[mode].local_value().grad_fn
        assert count_grad_fn(grad_fn, "CheckpointBackward") == expected
|
|
|
|
|
|
def test_checkpoint_mode_invalid(setup_rpc):
    """An unknown checkpoint mode is rejected with a descriptive ValueError."""
    seq = nn.Sequential(nn.Linear(1, 1))
    expected_msg = "checkpoint is not one of 'always', 'except_last', or 'never'"
    with pytest.raises(ValueError, match=expected_msg):
        Pipe(seq, chunks=2, checkpoint="INVALID_CHECKPOINT")
|
|
|
|
|
|
def test_checkpoint_mode_when_chunks_1(setup_rpc):
    """Every checkpoint mode is accepted when there is a single micro-batch."""
    seq = nn.Sequential(nn.Linear(1, 1))
    for mode in ("except_last", "always", "never"):
        Pipe(seq, chunks=1, checkpoint=mode)
|
|
|
|
|
|
def test_checkpoint_eval(setup_rpc):
    """Checkpoint nodes appear in training mode and are absent in eval mode."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=2)
    sample = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
        # Depth-first search of the autograd graph for a node named ``name``.
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        return any(find_grad_fn(child, name) for child, _ in grad_fn.next_functions)

    pipe.train()
    train_grad_fn = pipe(sample).local_value().grad_fn
    assert find_grad_fn(train_grad_fn, "CheckpointBackward")
    assert find_grad_fn(train_grad_fn, "RecomputeBackward")

    pipe.eval()
    eval_grad_fn = pipe(sample).local_value().grad_fn
    assert not find_grad_fn(eval_grad_fn, "CheckpointBackward")
    assert not find_grad_fn(eval_grad_fn, "RecomputeBackward")
|
|
|
|
|
|
def test_checkpoint_non_float_input(setup_rpc):
    """Checkpointing tolerates a non-float (bool) tensor passed between stages."""

    class ForkNonFloat(nn.Module):
        def forward(self, input):
            # A float tensor plus a bool tensor, which cannot carry gradients.
            return (input * 2, torch.tensor([False]))

    class JoinNonFloat(nn.Module):
        def forward(self, input, non_float):
            return input * 2

    stages = nn.Sequential(ForkNonFloat(), JoinNonFloat())
    pipe = Pipe(stages, chunks=1, checkpoint="always")

    x = torch.rand(1, requires_grad=True)
    # Backward is invoked directly on the value returned by the pipe.
    pipe(x).backward()
|
|
|
|
|
|
def test_no_grad(setup_rpc):
    """Under torch.no_grad() the partition output has no autograd history."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1)), chunks=2)
    sample = torch.rand(2, 1)

    # Holds the most recent output seen by the forward hook.
    latest = {"output": None}

    def record_output(module, inputs, output):
        latest["output"] = output

    pipe.partitions[0].register_forward_hook(record_output)

    with torch.no_grad():
        pipe(sample)

    assert latest["output"].grad_fn is None
|
|
|
|
|
|
def test_exception(setup_rpc):
    """An exception raised inside a partition propagates out of the pipe."""

    class ExpectedException(Exception):
        pass

    class Raise(nn.Module):
        def forward(self, *_):
            raise ExpectedException()

    pipe = Pipe(nn.Sequential(Raise()), chunks=1)
    with pytest.raises(ExpectedException):
        pipe(torch.rand(1))
|
|
|
|
|
|
def test_exception_early_stop_asap(setup_rpc):
    """When a partition fails, the partitions before it stop as soon as possible.

    Even if the earlier partitions have already finished their work, the
    partition right before the failing one must not go on to process the
    remaining micro-batches.
    """

    class ExpectedException(Exception):
        pass

    class Pass(nn.Module):
        def forward(self, x):
            return x

    calls = 0

    class Counter(nn.Module):
        def forward(self, x):
            nonlocal calls
            time.sleep(0.1)
            calls += 1
            return x

    class Raise(nn.Module):
        def forward(self, x):
            raise ExpectedException()

    pipe = Pipe(nn.Sequential(Pass(), Pass(), Counter(), Raise()), chunks=3)

    with pytest.raises(ExpectedException):
        pipe(torch.rand(3))

    # Without early stopping, Counter would have seen all 3 micro-batches.
    assert calls == 2
|
|
|
|
|
|
def test_nested_input(setup_rpc):
    """Nested tuples or lists of tensors are rejected as pipe input."""

    class NestedInput(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, inp):
            return inp

    pipe = Pipe(nn.Sequential(NestedInput()), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    # TypeError: expected Tensor, but got tuple (then: but got list)
    for nested in ((a, (a, b)), (a, [a, b])):
        with pytest.raises(TypeError):
            pipe(nested).local_value()
|
|
|
|
|
|
def test_input_pair(setup_rpc):
    """A two-input module receives gradients for both of its inputs."""

    class Two(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, a, b):
            return (self.fc_a(a), self.fc_b(b))

    pipe = Pipe(nn.Sequential(Two()), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    a_out, b_out = pipe(a, b).local_value()
    (a_out + b_out).mean().backward()

    assert a.grad is not None
    assert b.grad is not None
|
|
|
|
def test_multi_sequence_input(setup_rpc):
    """Passing several list arguments instead of tensors raises TypeError."""

    class MultiSeq(nn.Module):
        def forward(self, tup1, tup2):
            return tup1, tup2

    pipe = Pipe(nn.Sequential(MultiSeq()))
    first = [torch.rand(10), torch.rand(10)]
    second = [torch.rand(10), torch.rand(10)]
    with pytest.raises(TypeError):
        pipe(first, second)
|
|
|
|
def test_input_singleton(setup_rpc):
    """A module returning a one-element tuple works and propagates gradients."""

    class One(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)

        def forward(self, a):
            return (self.fc(a),)

    pipe = Pipe(nn.Sequential(One()), chunks=2)

    a = torch.rand(10, 1, requires_grad=True)

    (a_out,) = pipe(a).local_value()
    a_out.mean().backward()

    assert all(param.grad is not None for param in pipe.parameters())
    assert a.grad is not None
|
|
|
|
|
|
def test_input_varargs(setup_rpc):
    """Extra positional inputs to a single-input module raise TypeError."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1)))
    a, b = torch.rand(1), torch.rand(1)

    # TypeError: forward() takes 2 positional arguments but 3 were given
    with pytest.raises(TypeError):
        pipe(a, b)
|
|
|
|
|
|
def test_non_tensor(setup_rpc):
    """Both a non-tensor output and a non-tensor input raise TypeError."""

    class NonTensor(nn.Module):
        def forward(self, _):
            return "hello"

    pipe = Pipe(nn.Sequential(NonTensor()))
    x = torch.rand(1)

    # The module returns a non-tensor.
    with pytest.raises(TypeError):
        pipe(x)

    # The pipe is called with a non-tensor.
    with pytest.raises(TypeError):
        pipe("hello")
|
|
|
|
|
|
def test_non_tensor_sequence(setup_rpc):
    """Mixed tensor/non-tensor sequences passed as one argument are rejected,
    as are calls that contain no tensor at all."""

    class NonTensorTuple(nn.Module):
        def forward(self, x):
            return (x, "hello")

    class NonTensorArgs(nn.Module):
        def forward(self, x: str, y: bool):
            return x, y

    pipe = Pipe(nn.Sequential(NonTensorTuple()))
    x = torch.rand(1)

    # A tuple, then a list, mixing a tensor with a string.
    for packed in ((x, "hello"), [x, "hello"]):
        with pytest.raises(TypeError):
            pipe(packed)

    pipe = Pipe(nn.Sequential(NonTensorArgs()))

    with pytest.raises(TypeError):
        # Need at least one Tensor.
        pipe("hello", True)
|
|
|
|
|
|
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_valid_non_tensor(checkpoint, setup_rpc):
    """Non-tensor values may flow through a pipe alongside tensors.

    Non-tensor outputs come back as a list with one entry per micro-batch,
    while tensor outputs are compared against the full-batch result.
    """

    class NonTensor1(nn.Module):
        def forward(self, a: int, b: Tensor, c: bool, d: Tensor):
            out = b + a if c else b * a
            if d is not None:
                out += d
            return out, c, a, b, "hello", d

    class NonTensor2(nn.Module):
        def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor):
            out = a * c if b else a + c
            out += d
            combined = d if f is None else d + f
            return c, out, a, combined, b, e, f

    n_chunks = 5
    pipe = Pipe(nn.Sequential(NonTensor1(), NonTensor2()), chunks=n_chunks, checkpoint=checkpoint)
    a = random.randint(0, 10)
    b = torch.rand(10, 10)
    c = random.randint(0, 1) == 0
    d = torch.rand(10, 10)

    out = pipe(a, b, c, d).local_value()
    assert 7 == len(out)
    assert [a] * n_chunks == out[0]
    if c:
        assert torch.allclose(((b + a + d) * a) + b, out[1])
        assert torch.allclose(b + a + d, out[2])
    else:
        assert torch.allclose(((b * a) + d + a) + b, out[1])
        assert torch.allclose(b * a + d, out[2])
    assert torch.allclose(b + d, out[3])
    assert [c] * n_chunks == out[4]
    assert ["hello"] * n_chunks == out[5]
    assert torch.allclose(d, out[6])

    # One of the tensor arguments may be None.
    out = pipe(a, b, c, None).local_value()
    assert 7 == len(out)
    assert [a] * n_chunks == out[0]
    if c:
        assert torch.allclose(((b + a) * a) + b, out[1])
        assert torch.allclose(b + a, out[2])
    else:
        assert torch.allclose(((b * a) + a) + b, out[1])
        assert torch.allclose(b * a, out[2])
    assert torch.allclose(b, out[3])
    assert [c] * n_chunks == out[4]
    assert ["hello"] * n_chunks == out[5]
    assert [None] * n_chunks == out[6]

    # Need at least one tensor.
    with pytest.raises(TypeError):
        pipe(a, None, c, None)
|
|
|
|
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_no_tensor_output(checkpoint, setup_rpc):
    """A pipe in which a stage emits no tensors raises TypeError.

    ``Model1`` receives a tensor but returns only non-tensor values, so the
    values handed to ``Model2`` contain no tensor at all.
    """

    class Model1(nn.Module):
        def forward(self, a: int, b: Tensor, c: bool):
            return a, c, "hello"

    class Model2(nn.Module):
        def forward(self, a: int, b: bool, c: str):
            return a, c, b

    # Forward the parametrized mode so that every checkpoint setting is
    # actually exercised (it used to be silently ignored).
    model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5, checkpoint=checkpoint)
    a = random.randint(0, 10)
    b = torch.rand(10, 10)
    c = random.randint(0, 1) == 0

    # Need at least one tensor across partitions too.
    with pytest.raises(TypeError):
        model(a, b, c).local_value()
|
|
|
|
|
|
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_uneven_batch_size(checkpoint, setup_rpc):
    """Tensors that split into the same number of chunks are accepted, even if
    that is fewer than ``chunks``; tensors splitting differently are rejected."""

    class Model(nn.Module):
        def forward(self, a: Tensor, b: int, c: Tensor):
            return a, b, c

    def make_pipe():
        return Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)

    # 3 rows and 6 rows both split into 3 chunks when chunks=5.
    pipe = make_pipe()
    a = torch.rand(3, 10)
    b = random.randint(0, 10)
    c = torch.rand(6, 10)
    out = pipe(a, b, c).local_value()
    assert torch.allclose(a, out[0])
    assert [b] * 3 == out[1]  # 3 chunks
    assert torch.allclose(c, out[2])

    # Two tensors producing uneven chunks would fail.
    pipe = make_pipe()
    a = torch.rand(3, 10)
    b = random.randint(0, 10)
    c = torch.rand(4, 10)

    with pytest.raises(RuntimeError, match='Found different number of chunks'):
        pipe(a, b, c)
|
|
|
|
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_no_chunk(checkpoint, setup_rpc):
    """A tensor wrapped in NoChunk is replicated to every micro-batch."""

    class Model(nn.Module):
        def forward(self, a: Tensor, b: int, c: Tensor):
            return a, b, c

    n_chunks = 5
    pipe = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=n_chunks)
    a = torch.rand(10, 10)
    b = random.randint(0, 10)
    c = torch.rand(10, 10)

    out = pipe(a, b, NoChunk(c)).local_value()
    assert torch.allclose(a, out[0])
    assert [b] * n_chunks == out[1]
    # c is replicated by NoChunk, so the concatenated output holds one
    # copy of c per micro-batch.
    assert torch.allclose(torch.cat([c] * n_chunks), out[2])

    # NoChunk only accepts tensors.
    with pytest.raises(TypeError, match='NoChunk only supported for tensors'):
        NoChunk(b)
|
|
|
|
|
|
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_deferred_batch_norm(checkpoint, setup_rpc):
    """Deferred batch norm over micro-batches matches full-batch running stats."""
    reference_bn = nn.BatchNorm2d(3)
    pipe = Pipe(
        nn.Sequential(deepcopy(reference_bn)),
        chunks=2,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).local_value().mean().backward()
    reference_bn(x).mean().backward()

    for stat in ("running_mean", "running_var"):
        assert torch.allclose(getattr(pipe[0], stat), getattr(reference_bn, stat), atol=1e-4)
|
|
|
|
|
|
@pytest.mark.parametrize("checkpoint", ["never", "always"])
def test_deferred_batch_norm_params(checkpoint, setup_rpc):
    """Deferred batch norm produces the same affine-parameter gradients."""
    reference_bn = nn.BatchNorm2d(3)
    pipe = Pipe(
        nn.Sequential(deepcopy(reference_bn)),
        chunks=1,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).local_value().mean().backward()
    reference_bn(x).mean().backward()

    piped_bn = pipe[0]
    assert piped_bn.weight.grad is not None
    assert piped_bn.bias.grad is not None

    assert torch.allclose(piped_bn.weight.grad, reference_bn.weight.grad, atol=1e-4)
    assert torch.allclose(piped_bn.bias.grad, reference_bn.bias.grad, atol=1e-4)
|
|
|
|
|
|
def test_devices(setup_rpc):
    """A Pipe of three CPU layers reports one CPU device entry per layer."""
    layers = [nn.Linear(1, 1) for _ in range(3)]
    pipe = Pipe(nn.Sequential(*layers))

    cpu = torch.device("cpu")
    assert pipe.devices == [cpu] * 3
|
|
|
|
|
|
def test_partitions(setup_rpc):
    """Partitions are an nn.ModuleList of nn.Sequential, registered as ``partitions``."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)))

    assert isinstance(pipe.partitions, nn.ModuleList)
    for index in (0, 1):
        assert isinstance(pipe.partitions[index], nn.Sequential)

    assert "partitions.0.0.weight" in pipe.state_dict()
|
|
|
|
|
|
@skip_if_no_cuda
def test_merged_partitions(setup_rpc):
    """Consecutive modules on the same CUDA device are merged into one partition.

    ``a`` and both layers of ``b`` live on cuda:0 and are expected to form
    partition 0. ``c`` and ``d`` (CPU) are expected to form partitions 1 and 2.
    """
    a = nn.Linear(1, 1).to(0)
    b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0)
    c = nn.Linear(1, 1)
    d = nn.Linear(1, 2)

    model = nn.Sequential(a, b, c, d)
    model = Pipe(model)

    assert isinstance(model.partitions, nn.ModuleList)
    # Exactly three partitions; previously an extra trailing partition
    # would have gone unnoticed.
    assert len(model.partitions) == 3
    for partition in model.partitions:
        assert isinstance(partition, PipeSequential)
    assert list(model.partitions[0]) == [a, b[0], b[1]]
    assert list(model.partitions[1]) == [c]
    assert list(model.partitions[2]) == [d]
|
|
|
|
|
|
def test_deny_moving(setup_rpc):
    """Device moves raise TypeError; dtype casts remain allowed."""
    pipe = Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)))

    # Every form of moving the pipe to another device is denied.
    device_moves = [
        lambda: pipe.cuda(),
        lambda: pipe.cpu(),
        lambda: pipe.to(torch.device("cuda")),
        lambda: pipe.to(0),
        lambda: pipe.to("cuda"),
        lambda: pipe.to(device=0),
        lambda: pipe.to(torch.rand(1)),
        lambda: pipe.to(tensor=torch.rand(1)),
    ]
    for move in device_moves:
        with pytest.raises(TypeError):
            move()

    # Casting is allowed.
    pipe.half()
    pipe.to(torch.double)
    pipe.to(dtype=torch.float)
|
|
|
|
|
|
def test_empty_module(setup_rpc):
    """An empty nn.Sequential is legal; it passes a tensor through unchanged."""
    pipe = Pipe(nn.Sequential())

    assert pipe(torch.tensor(42)).local_value() == torch.tensor(42)

    # However, Pipe only accepts a tensor (or tensors) as input.
    with pytest.raises(TypeError):
        pipe(42)
|
|
|
|
|
|
def test_named_children(setup_rpc):
    """Children appear under ``partitions.<i>.<j>`` names and are not exposed
    as attributes of the Pipe."""
    a = nn.Linear(1, 1)
    b = nn.Linear(1, 1)
    pipe = Pipe(nn.Sequential(OrderedDict([("a", a), ("b", b)])))

    names = {name for name, _ in pipe.named_modules()}
    assert "partitions.0.0" in names
    assert "partitions.1.0" in names

    # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
    # several methods in its namespace.
    with pytest.raises(AttributeError):
        pipe.a
|
|
|
|
|
|
def test_verify_module_non_sequential(setup_rpc):
    """Only nn.Sequential modules can be wrapped by Pipe."""
    not_sequential = nn.Module()
    with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"):
        Pipe(not_sequential)
|
|
|
|
|
|
def test_verify_module_duplicate_children(setup_rpc):
    """Listing the same child module twice is rejected."""
    conv = nn.Conv2d(3, 3, 1)
    duplicated = nn.Sequential(conv, conv)

    with pytest.raises(ValueError, match="module with duplicate children is not supported"):
        Pipe(duplicated)
|
|
|
|
|
|
@skip_if_no_cuda
def test_verify_module_params_on_same_device(setup_rpc):
    """A single child whose parameters span CPU and CUDA is rejected."""

    class Surrogate(nn.Module):
        def __init__(self, param1, param2):
            super().__init__()
            self.param1 = param1
            self.param2 = param2

    cpu_conv = nn.Conv2d(3, 3, 1)
    gpu_conv = nn.Conv2d(3, 3, 1).cuda()
    model = nn.Sequential(Surrogate(cpu_conv, gpu_conv))

    expected_msg = (
        r'should have all parameters on a single device, please use .to\(\)'
        ' to place the module on a single device'
    )
    with pytest.raises(ValueError, match=expected_msg):
        Pipe(model)
|
|
|
|
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_verify_nested_modules(setup_rpc):
    """Nested nn.Sequential stages on different GPUs are accepted."""
    first_stage = nn.Sequential(
        nn.Linear(32, 16).cuda(0),
        nn.Linear(16, 8).cuda(0),
    )
    second_stage = nn.Sequential(
        nn.Linear(8, 4).cuda(1),
        nn.Linear(4, 2).cuda(1),
    )

    pipe = Pipe(nn.Sequential(first_stage, second_stage))
    result = pipe(torch.rand(10, 32).cuda(0)).local_value()
    assert result.device == torch.device("cuda:1")
    assert result.size() == torch.Size([10, 2])
|
|
|
|
def test_verify_module_duplicate_parameters_on_same_device(setup_rpc):
    """Two children sharing one submodule on the same device are allowed."""

    class Surrogate(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

    shared_conv = nn.Conv2d(3, 3, 1)
    # Constructing the Pipe must not raise.
    Pipe(nn.Sequential(Surrogate(shared_conv), Surrogate(shared_conv)))
|
|
|
|
|
|
def test_forward_lockstep(setup_rpc):
    """Two stages log their micro-batches in the expected interleaved order."""
    timeline = []

    class DelayedLog(nn.Module):
        def __init__(self, j, seconds):
            super().__init__()
            self.calls = 0          # number of forward calls seen so far
            self.stage = j          # position of this layer in the Sequential
            self.delay = seconds    # sleep before logging each call

        def forward(self, x):
            time.sleep(self.delay)
            timeline.append((self.calls, self.stage))
            self.calls += 1
            return x

    stages = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1))
    pipe = Pipe(stages, chunks=3)
    pipe(torch.rand(3, 1))

    # Expected timeline: (Logs are recorded at !)
    #
    # Partition #0: 0! 1! 2!
    # Partition #1:    000! 111! 222!
    #
    assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]
|
|
|
|
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@skip_if_no_cuda
def test_multiple_inputs(checkpoint, setup_rpc):
    """Multiple tensor inputs and outputs flow correctly between stages."""

    class Module1(nn.Module):
        def forward(self, a, b, c):
            return a + b + c, a * b * c

    class Module2(nn.Module):
        def forward(self, a, b):
            return a + b

    # NOTE(review): Module1/Module2 define no parameters, so .cuda(0) has
    # nothing to move; this may run entirely on CPU. Confirm that is intended.
    stages = nn.Sequential(Module1().cuda(0), Module2().cuda(0))
    pipe = Pipe(stages, chunks=2, checkpoint=checkpoint)

    t = torch.rand(10)
    out = pipe(t, t, t).local_value()
    assert torch.equal(out, (t + t + t) + (t * t * t))
|
|
|
|
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_inputs_wrong_device(setup_rpc):
    """Inputs that are not on the first partition's device are rejected.

    Pipe is expected to raise ValueError rather than move the inputs itself.
    """

    class Module1(nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(5))

        def forward(self, a, b):
            return a + b + self.param, b

    # Start the inputs on cuda:1 while the first partition lives on cuda:0,
    # and ensure Pipe rejects them instead of silently moving them.
    a = torch.rand(10).cuda(1)
    b = torch.rand(10).cuda(1)
    model = Pipe(nn.Sequential(Module1().cuda(0), Module1().cuda(1)), chunks=2)
    with pytest.raises(ValueError, match='All inputs should be on the same device as the first partition'):
        model(a, b)
|
|
|
|
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs")
def test_with_device_wrapper(setup_rpc):
    """WithDevice overrides the device Pipe assigns to a wrapped module."""
    fc1 = nn.Linear(16, 8).cuda(0)
    fc2 = nn.Linear(8, 4).cuda(1)
    dropout = nn.Dropout()

    cuda0 = torch.device('cuda:0')
    cuda1 = torch.device('cuda:1')

    def run(*modules):
        # Build a pipe from ``modules``, run one batch, and report the output
        # device together with the pipe's device list.
        pipe = Pipe(nn.Sequential(*modules), chunks=8)
        out_device = pipe(torch.rand(16, 16).cuda(0)).local_value().device
        return out_device, pipe.devices

    # The parameterless dropout is placed on cuda:1 explicitly.
    out_device, devices = run(fc1, fc2, WithDevice(dropout, 'cuda:1'))
    assert out_device == cuda1
    assert devices == [cuda0, cuda1]

    out_device, devices = run(fc1, WithDevice(dropout, 'cuda:1'))
    assert out_device == cuda1
    assert devices == [cuda0, cuda1]

    # Wrapping fc2 (created on cuda:1) with cuda:0 keeps the whole pipe on
    # cuda:0, and fc2's weights end up on cuda:0 as well.
    out_device, devices = run(fc1, WithDevice(fc2, 'cuda:0'))
    assert out_device == cuda0
    assert devices == [cuda0]
    assert fc2.weight.device == cuda0
|