pytorch/test/distributed/pipeline/sync/test_pipe.py
Jane Xu 34051d74da Add test owner to distributed files starting with test_ (#66797)
Summary:
Action based on https://github.com/pytorch/pytorch/issues/66232

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66797

Reviewed By: gchanan

Differential Revision: D31761389

Pulled By: janeyx99

fbshipit-source-id: c27c9ab4acec1eb71d5edd4538cd113b770dfc6c
2021-10-19 10:55:20 -07:00


# Owner(s): ["oncall: distributed"]
# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from copy import deepcopy
import time
import pytest
import random
import torch
from torch import nn
from torch import Tensor
from torch.distributed.pipeline.sync import Pipe, NoChunk, WithDevice
from torch.distributed.pipeline.sync.pipe import PipeSequential
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_pipe_without_rpc():
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(RuntimeError, match='Please initialize RPC framework'):
pipe = Pipe(model, chunks=1)
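# Note: the remaining tests take a `setup_rpc` fixture (defined in the suite's
# conftest), which presumably initializes and shuts down the RPC framework that
# Pipe requires. A minimal single-process setup would look roughly like the
# following sketch (illustrative only, not part of the original tests):
#
#   import torch.distributed.rpc as rpc
#   rpc.init_rpc("worker", rank=0, world_size=1)
#   ...  # build Pipe models and run forward/backward
#   rpc.shutdown()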
def test_parameters(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(model, chunks=1)
assert list(pipe.parameters()) != []
def test_public_attrs(setup_rpc):
class MyString:
def __init__(self, value):
self.value = value
def __str__(self):
return self.value
model = nn.Sequential(nn.Linear(1, 1))
pipe = Pipe(model, chunks=42.000, checkpoint=MyString("always"))
assert pipe.devices == [torch.device("cpu")]
assert pipe.chunks == 42
assert isinstance(pipe.chunks, int)
assert pipe.checkpoint == "always"
assert isinstance(pipe.checkpoint, str)
def test_sequential_like(setup_rpc):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model)
assert len(model) == 2
assert list(model) == [a, b]
assert model[0] is a
assert model[1] is b
with pytest.raises(IndexError):
_ = model[2]
assert model[-1] is b
assert model[-2] is a
def test_chunks_less_than_1(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError):
Pipe(model, chunks=0)
with pytest.raises(ValueError):
Pipe(model, chunks=-1)
def test_batch_size_indivisible(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, chunks=4)
with pytest.warns(None) as record:
model(torch.rand(7, 1))
# Indivisible batch size is legal.
assert not record
def test_batch_size_small(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, chunks=4)
with pytest.warns(None) as record:
model(torch.rand(2, 1))
# Batch size smaller than chunks is legal.
assert not record
def test_checkpoint_mode(setup_rpc):
def count_grad_fn(grad_fn, name, visited=None):
if visited is None:
visited = set()
if grad_fn in visited:
return 0
visited.add(grad_fn)
if grad_fn is None:
return 0
if grad_fn.__class__.__name__ == name:
return 1
counter = 0
for next_grad_fn, _ in grad_fn.next_functions:
counter += count_grad_fn(next_grad_fn, name, visited=visited)
return counter
model = nn.Sequential(nn.Linear(1, 1))
input = torch.rand(2, 1)
always = Pipe(model, chunks=2, checkpoint="always")
except_last = Pipe(model, chunks=2, checkpoint="except_last")
never = Pipe(model, chunks=2, checkpoint="never")
always_output = always(input)
except_last_output = except_last(input)
never_output = never(input)
assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2
assert count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") == 1
assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0
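# For reference: "always" recomputes activations for every micro-batch,
# "except_last" recomputes for all but the last micro-batch (the default),
# and "never" disables activation checkpointing. With chunks=2 this yields
# the 2 / 1 / 0 CheckpointBackward counts asserted above.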
def test_checkpoint_mode_invalid(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
Pipe(model, chunks=2, checkpoint="INVALID_CHECKPOINT")
def test_checkpoint_mode_when_chunks_1(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
# All checkpoint modes are fine.
Pipe(model, chunks=1, checkpoint="except_last")
Pipe(model, chunks=1, checkpoint="always")
Pipe(model, chunks=1, checkpoint="never")
def test_checkpoint_eval(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, chunks=2)
input = torch.rand(2, 1)
def find_grad_fn(grad_fn, name):
if grad_fn is None:
return False
if grad_fn.__class__.__name__ == name:
return True
for next_grad_fn, _ in grad_fn.next_functions:
if find_grad_fn(next_grad_fn, name):
return True
return False
model.train()
train_output = model(input)
assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward")
assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward")
model.eval()
eval_output = model(input)
assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward")
assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward")
def test_checkpoint_non_float_input(setup_rpc):
class ForkNonFloat(nn.Module):
def forward(self, input):
return (input * 2, torch.tensor([False]))
class JoinNonFloat(nn.Module):
def forward(self, input, non_float):
return input * 2
model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
model = Pipe(model, chunks=1, checkpoint="always")
input = torch.rand(1, requires_grad=True)
output = model(input)
output.backward()
def test_no_grad(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model, chunks=2)
input = torch.rand(2, 1)
latent = None
def hook(module, input, output):
_ = module
_ = input
nonlocal latent
latent = output
partition = model.partitions[0]
partition.register_forward_hook(hook)
with torch.no_grad():
model(input)
assert latent.grad_fn is None
def test_exception(setup_rpc):
class ExpectedException(Exception):
pass
class Raise(nn.Module):
def forward(self, *_):
raise ExpectedException()
model = nn.Sequential(Raise())
model = Pipe(model, chunks=1)
with pytest.raises(ExpectedException):
model(torch.rand(1))
def test_exception_early_stop_asap(setup_rpc):
"""Even the first partitions have finished to process, the partition before
the failed partition should be killed as soon as possible.
"""
class ExpectedException(Exception):
pass
class Pass(nn.Module):
def forward(self, x):
return x
counter = 0
class Counter(nn.Module):
def forward(self, x):
time.sleep(0.1)
nonlocal counter
counter += 1
return x
class Raise(nn.Module):
def forward(self, x):
raise ExpectedException()
model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
model = Pipe(model, chunks=3)
with pytest.raises(ExpectedException):
model(torch.rand(3))
# If the early stop doesn't work, it would be 3 instead.
assert counter == 2
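# Roughly why 2: the Raise partition fails on its first micro-batch while
# Counter is still busy with its second; that second chunk finishes, but the
# third should be cancelled before Counter ever sees it.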
def test_nested_input(setup_rpc):
class NestedInput(nn.Module):
def __init__(self):
super().__init__()
self.fc_a = nn.Linear(1, 1)
self.fc_b = nn.Linear(1, 1)
def forward(self, inp):
return inp
model = nn.Sequential(NestedInput())
model = Pipe(model, chunks=2)
a = torch.rand(10, 1, requires_grad=True)
b = torch.rand(10, 1, requires_grad=True)
# TypeError: expected Tensor, but got tuple
with pytest.raises(TypeError):
model((a, (a, b))).local_value()
# TypeError: expected Tensor, but got list
with pytest.raises(TypeError):
model((a, [a, b])).local_value()
def test_input_pair(setup_rpc):
class Two(nn.Module):
def __init__(self):
super().__init__()
self.fc_a = nn.Linear(1, 1)
self.fc_b = nn.Linear(1, 1)
def forward(self, a, b):
return (self.fc_a(a), self.fc_b(b))
model = nn.Sequential(Two())
model = Pipe(model, chunks=2)
a = torch.rand(10, 1, requires_grad=True)
b = torch.rand(10, 1, requires_grad=True)
a_out, b_out = model(a, b).local_value()
loss = (a_out + b_out).mean()
loss.backward()
assert a.grad is not None
assert b.grad is not None
def test_multi_sequence_input(setup_rpc):
class MultiSeq(nn.Module):
def forward(self, tup1, tup2):
return tup1, tup2
model = Pipe(nn.Sequential(MultiSeq()))
with pytest.raises(TypeError):
model(
[torch.rand(10), torch.rand(10)],
[torch.rand(10), torch.rand(10)]
)
def test_input_singleton(setup_rpc):
class One(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, a):
return (self.fc(a),)
model = nn.Sequential(One())
model = Pipe(model, chunks=2)
a = torch.rand(10, 1, requires_grad=True)
(a_out,) = model(a).local_value()
loss = a_out.mean()
loss.backward()
assert all(p.grad is not None for p in model.parameters())
assert a.grad is not None
def test_input_varargs(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
model = Pipe(model)
a = torch.rand(1)
b = torch.rand(1)
# TypeError: forward() takes 2 positional arguments but 3 were given
with pytest.raises(TypeError):
model(a, b)
def test_non_tensor(setup_rpc):
class NonTensor(nn.Module):
def forward(self, _):
return "hello"
model = nn.Sequential(NonTensor())
model = Pipe(model)
x = torch.rand(1)
with pytest.raises(TypeError):
model(x)
with pytest.raises(TypeError):
model("hello")
def test_non_tensor_sequence(setup_rpc):
class NonTensorTuple(nn.Module):
def forward(self, x):
return (x, "hello")
class NonTensorArgs(nn.Module):
def forward(self, x: str, y: bool):
return x, y
model = nn.Sequential(NonTensorTuple())
model = Pipe(model)
x = torch.rand(1)
with pytest.raises(TypeError):
model((x, "hello"))
with pytest.raises(TypeError):
model([x, "hello"])
model = nn.Sequential(NonTensorArgs())
model = Pipe(model)
with pytest.raises(TypeError):
# Need at least one Tensor.
model("hello", True)
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_valid_non_tensor(checkpoint, setup_rpc):
class NonTensor1(nn.Module):
def forward(self, a: int, b: Tensor, c: bool, d: Tensor):
res = b + a if c else b * a
if d is not None:
res += d
return res, c, a, b, "hello", d
class NonTensor2(nn.Module):
def forward(self, a: Tensor, b: bool, c: int, d: Tensor, e: str, f: Tensor):
res = a * c if b else a + c
res += d
return c, res, a, d + f if f is not None else d, b, e, f
model = Pipe(nn.Sequential(NonTensor1(), NonTensor2()), chunks=5, checkpoint=checkpoint)
a = random.randint(0, 10)
b = torch.rand(10, 10)
c = random.randint(0, 1) == 0
d = torch.rand(10, 10)
res = model(a, b, c, d).local_value()
assert 7 == len(res)
assert [a] * 5 == res[0]
if c:
assert torch.allclose(((b + a + d) * a) + b, res[1])
assert torch.allclose(b + a + d, res[2])
else:
assert torch.allclose(((b * a) + d + a) + b, res[1])
assert torch.allclose(b * a + d, res[2])
assert torch.allclose(b + d, res[3])
assert [c] * 5 == res[4]
assert ["hello"] * 5 == res[5]
assert torch.allclose(d, res[6])
# Test that one of the tensors can be None
res = model(a, b, c, None).local_value()
assert 7 == len(res)
assert [a] * 5 == res[0]
if c:
assert torch.allclose(((b + a) * a) + b, res[1])
assert torch.allclose(b + a, res[2])
else:
assert torch.allclose(((b * a) + a) + b, res[1])
assert torch.allclose(b * a, res[2])
assert torch.allclose(b, res[3])
assert [c] * 5 == res[4]
assert ["hello"] * 5 == res[5]
assert [None] * 5 == res[6]
# Need at least one tensor.
with pytest.raises(TypeError):
model(a, None, c, None)
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_no_tensor_output(checkpoint, setup_rpc):
class Model1(nn.Module):
def forward(self, a: int, b: Tensor, c: bool):
return a, c, "hello"
class Model2(nn.Module):
def forward(self, a: int, b: bool, c: str):
return a, c, b
model = Pipe(nn.Sequential(Model1(), Model2()), chunks=5)
a = random.randint(0, 10)
b = torch.rand(10, 10)
c = random.randint(0, 1) == 0
# Need at least one tensor across partitions too.
with pytest.raises(TypeError):
res = model(a, b, c).local_value()
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_uneven_batch_size(checkpoint, setup_rpc):
class Model(nn.Module):
def forward(self, a: Tensor, b: int, c: Tensor):
return a, b, c
model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)
a = torch.rand(3, 10)
b = random.randint(0, 10)
c = torch.rand(6, 10)
res = model(a, b, c).local_value()
assert torch.allclose(a, res[0])
assert [b] * 3 == res[1] # 3 chunks
assert torch.allclose(c, res[2])
# Two tensors producing uneven chunks would fail.
model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)
a = torch.rand(3, 10)
b = random.randint(0, 10)
c = torch.rand(4, 10)
with pytest.raises(RuntimeError, match='Found different number of chunks'):
model(a, b, c)
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_no_chunk(checkpoint, setup_rpc):
class Model(nn.Module):
def forward(self, a: Tensor, b: int, c: Tensor):
return a, b, c
model = Pipe(nn.Sequential(Model()), checkpoint=checkpoint, chunks=5)
a = torch.rand(10, 10)
b = random.randint(0, 10)
c = torch.rand(10, 10)
res = model(a, b, NoChunk(c)).local_value()
assert torch.allclose(a, res[0])
assert [b] * 5 == res[1]
# c gets replicated due to NoChunk and the same tensor gets concatenated 5
# times in the output.
assert torch.allclose(torch.cat((c, c, c, c, c)), res[2])
# Test invalid type for NoChunk
with pytest.raises(TypeError, match='NoChunk only supported for tensors'):
NoChunk(b)
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_deferred_batch_norm(checkpoint, setup_rpc):
bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn)
pipe = Pipe(
nn.Sequential(pipe_bn), chunks=2, checkpoint=checkpoint, deferred_batch_norm=True
)
x = torch.rand(4, 3, 10, 10)
pipe(x).local_value().mean().backward()
bn(x).mean().backward()
assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4)
assert torch.allclose(pipe[0].running_var, bn.running_var, atol=1e-4)
@pytest.mark.parametrize("checkpoint", ["never", "always"])
def test_deferred_batch_norm_params(checkpoint, setup_rpc):
bn = nn.BatchNorm2d(3)
pipe_bn = deepcopy(bn)
pipe = Pipe(
nn.Sequential(pipe_bn), chunks=1, checkpoint=checkpoint, deferred_batch_norm=True
)
x = torch.rand(4, 3, 10, 10)
pipe(x).local_value().mean().backward()
bn(x).mean().backward()
assert pipe[0].weight.grad is not None
assert pipe[0].bias.grad is not None
assert torch.allclose(pipe[0].weight.grad, bn.weight.grad, atol=1e-4)
assert torch.allclose(pipe[0].bias.grad, bn.bias.grad, atol=1e-4)
def test_devices(setup_rpc):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
c = nn.Linear(1, 1)
# There are two extra devices.
model = nn.Sequential(a, b, c)
model = Pipe(model)
cpu = torch.device("cpu")
# Extra devices must be discarded.
assert model.devices == [cpu, cpu, cpu]
def test_partitions(setup_rpc):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model)
assert isinstance(model.partitions, nn.ModuleList)
assert isinstance(model.partitions[0], nn.Sequential)
assert isinstance(model.partitions[1], nn.Sequential)
assert "partitions.0.0.weight" in model.state_dict()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_merged_partitions(setup_rpc):
a = nn.Linear(1, 1).to(0)
b = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).to(0)
c = nn.Linear(1, 1)
d = nn.Linear(1, 2)
model = nn.Sequential(a, b, c, d)
model = Pipe(model)
assert isinstance(model.partitions, nn.ModuleList)
assert isinstance(model.partitions[0], PipeSequential)
assert isinstance(model.partitions[1], PipeSequential)
assert list(model.partitions[0]) == [a, b[0], b[1]]
assert list(model.partitions[1]) == [c]
assert list(model.partitions[2]) == [d]
def test_deny_moving(setup_rpc):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(a, b)
model = Pipe(model)
# Moving is denied.
with pytest.raises(TypeError):
model.cuda()
with pytest.raises(TypeError):
model.cpu()
with pytest.raises(TypeError):
model.to(torch.device("cuda"))
with pytest.raises(TypeError):
model.to(0)
with pytest.raises(TypeError):
model.to("cuda")
with pytest.raises(TypeError):
model.to(device=0)
with pytest.raises(TypeError):
model.to(torch.rand(1))
with pytest.raises(TypeError):
model.to(tensor=torch.rand(1))
# Casting is allowed.
model.half()
model.to(torch.double)
model.to(dtype=torch.float)
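# The asymmetry above is presumably because Pipe pins each partition to the
# device its parameters lived on at construction time, so whole-model device
# moves like .cuda()/.to(device) are rejected, while dtype-only casts such as
# .half() or .to(torch.double) don't change placement and remain allowed.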
def test_empty_module(setup_rpc):
# Empty sequential module is not illegal.
model = nn.Sequential()
model = Pipe(model)
assert model(torch.tensor(42)).local_value() == torch.tensor(42)
# But only a tensor or a sequence of tensors is a legal input to Pipe.
with pytest.raises(TypeError):
model(42)
def test_named_children(setup_rpc):
a = nn.Linear(1, 1)
b = nn.Linear(1, 1)
model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
model = Pipe(model)
names = set(n for n, _ in model.named_modules())
assert "partitions.0.0" in names
assert "partitions.1.0" in names
# Pipe doesn't support __getattr__ access to named children. Unlike
# nn.Sequential, Pipe needs several methods of its own in its namespace.
with pytest.raises(AttributeError):
model.a
def test_verify_module_non_sequential(setup_rpc):
with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"):
Pipe(nn.Module())
def test_verify_module_duplicate_children(setup_rpc):
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(conv, conv)
with pytest.raises(ValueError, match="module with duplicate children is not supported"):
Pipe(model)
@skip_if_no_cuda
def test_verify_module_params_on_same_device(setup_rpc):
class Surrogate(nn.Module):
def __init__(self, param1, param2):
super().__init__()
self.param1 = param1
self.param2 = param2
conv1 = nn.Conv2d(3, 3, 1)
conv2 = nn.Conv2d(3, 3, 1)
model = nn.Sequential(Surrogate(conv1, conv2.cuda()))
with pytest.raises(
ValueError,
match=r'should have all parameters on a single device, please use .to\(\)'
' to place the module on a single device'):
Pipe(model)
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least two GPUs")
def test_verify_nested_modules(setup_rpc):
model = nn.Sequential(
nn.Sequential(
nn.Linear(32, 16).cuda(0),
nn.Linear(16, 8).cuda(0)
),
nn.Sequential(
nn.Linear(8, 4).cuda(1),
nn.Linear(4, 2).cuda(1)
),
)
pipe = Pipe(model)
out = pipe(torch.rand(10, 32).cuda(0))
assert out.local_value().device == torch.device("cuda:1")
assert out.local_value().size() == torch.Size([10, 2])
def test_verify_module_duplicate_parameters_on_same_device(setup_rpc):
class Surrogate(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(Surrogate(conv), Surrogate(conv))
Pipe(model)
def test_forward_lockstep(setup_rpc):
timeline = []
class DelayedLog(nn.Module):
def __init__(self, j, seconds):
super().__init__()
self.i = 0
self.j = j
self.seconds = seconds
def forward(self, x):
time.sleep(self.seconds)
timeline.append((self.i, self.j))
self.i += 1
return x
model = nn.Sequential(DelayedLog(0, seconds=0), DelayedLog(1, seconds=0.1))
model = Pipe(model, chunks=3)
model(torch.rand(3, 1))
# Expected timeline (logs are recorded at the '!' marks):
#
# Partition #0: 0! 1! 2!
# Partition #1: 000! 111! 222!
#
assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
@skip_if_no_cuda
def test_multiple_inputs(checkpoint, setup_rpc):
class Module1(nn.Module):
def forward(self, a, b, c):
return a + b + c, a * b * c
class Module2(nn.Module):
def forward(self, a, b):
return a + b
model = Pipe(nn.Sequential(Module1().cuda(0), Module2().cuda(0)), chunks=2, checkpoint=checkpoint)
t = torch.rand(10)
res = model(t, t, t).local_value()
assert torch.equal(res, (t + t + t) + (t * t * t))
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least two GPUs")
def test_inputs_wrong_device(setup_rpc):
class Module1(nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(5))
def forward(self, a, b):
return a + b + self.param, b
# Start inputs on the wrong device and ensure Pipe rejects them with an error.
a = torch.rand(10).cuda(1)
b = torch.rand(10).cuda(1)
model = Pipe(nn.Sequential(Module1().cuda(0), Module1().cuda(1)), chunks=2)
with pytest.raises(ValueError, match='All inputs should be on the same device as the first partition'):
model(a, b)
@skip_if_no_cuda
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least two GPUs")
def test_with_device_wrapper(setup_rpc):
fc1 = nn.Linear(16, 8).cuda(0)
fc2 = nn.Linear(8, 4).cuda(1)
dropout = nn.Dropout()
model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1'))
model = Pipe(model, chunks=8)
assert torch.device('cuda:1') == model(torch.rand(16, 16).cuda(0)).local_value().device
assert [torch.device('cuda:0'), torch.device('cuda:1')] == model.devices
model = nn.Sequential(fc1, WithDevice(dropout, 'cuda:1'))
model = Pipe(model, chunks=8)
assert torch.device('cuda:1') == model(torch.rand(16, 16).cuda(0)).local_value().device
assert [torch.device('cuda:0'), torch.device('cuda:1')] == model.devices
model = nn.Sequential(fc1, WithDevice(fc2, 'cuda:0'))
model = Pipe(model, chunks=8)
assert torch.device('cuda:0') == model(torch.rand(16, 16).cuda(0)).local_value().device
assert [torch.device('cuda:0')] == model.devices
assert torch.device('cuda:0') == fc2.weight.device
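# For reference, typical end-to-end use of Pipe outside this test suite looks
# roughly like the sketch below (assumes two CUDA devices and an initialized
# RPC framework; the shapes and chunk count are illustrative, not taken from
# the tests above):
#
#   fc1 = nn.Linear(16, 8).cuda(0)
#   fc2 = nn.Linear(8, 4).cuda(1)
#   model = Pipe(nn.Sequential(fc1, fc2), chunks=4)
#   out = model(torch.rand(32, 16).cuda(0)).local_value()  # lands on cuda:1
#   out.sum().backward()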