Revert replicate.py to disallow replicating multi-device modules (#19278)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19278

Based on discussion in https://github.com/pytorch/pytorch/pull/19278 and https://github.com/pytorch/pytorch/pull/18687, changes to replicate.py will be reverted to disallow replicating multi-device modules.

Reviewed By: pietern

Differential Revision: D14940018

fbshipit-source-id: 7504c0f4325c2639264c52dcbb499e61c9ad2c26
Committed by: Facebook Github Bot
Parent: b9c20d5224
Commit: 344acaa0ca
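For context (this illustration is not part of the commit): after the revert, replicate() again accepts only a flat list of device ids, which is the call pattern the retained test_replicate test below still exercises. A minimal sketch, assuming a toy nn.Linear module and a two-GPU guard of my own choosing:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate

if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    module = nn.Linear(10, 5).cuda(0)
    # One replica per entry of the flat device list; nested per-replica
    # device lists such as [[0], [1]] are no longer supported after this revert.
    replicas = replicate(module, [0, 1])
    for i, replica in enumerate(replicas):
        assert all(p.get_device() == i for p in replica.parameters())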
test/common_cuda.py

@@ -7,7 +7,6 @@ from common_utils import TEST_WITH_ROCM, TEST_NUMBA
 
 TEST_CUDA = torch.cuda.is_available()
 TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
-TEST_GEQ4GPU = TEST_CUDA and torch.cuda.device_count() >= 4
 CUDA_DEVICE = TEST_CUDA and torch.device("cuda:0")
 # note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
 TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
test/test_nn.py (139 changed lines)
@@ -30,8 +30,7 @@ from torch.nn.parallel._functions import Broadcast
 from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
     TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \
     get_function_arglist, load_tests
-from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_GEQ4GPU, TEST_CUDNN, \
-    TEST_CUDNN_VERSION
+from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
     module_tests, criterion_tests, loss_reference_fns, get_reduction, \
     get_weight, smoothl1loss_reference, kldivloss_reference, \
@@ -3709,7 +3708,7 @@ class TestNN(NNTestCase):
         module = nn.Linear(10, 5).float().cuda()
         input = Variable(torch.randn(2, 10).float().cuda())
         expected_output = module(input).data
-        for devices in [(0, 1), [[0], [1]]]:
+        for devices in [(0, 1), [0, 1]]:
             replicas = dp.replicate(module, devices)
             for i, replica in enumerate(replicas):
                 for p in replica.parameters():
@@ -3717,144 +3716,12 @@ class TestNN(NNTestCase):
                 replica_input = input.cuda(i)
                 self.assertEqual(replica(replica_input).data, expected_output)
 
-    @unittest.skipIf(not TEST_GEQ4GPU, "less than 4 GPUs")
-    def test_replicate_multi_gpu_module(self):
-        class MultiGpuModule(nn.Module):
-            def __init__(self):
-                super(MultiGpuModule, self).__init__()
-                self.net1 = torch.nn.Linear(10, 5).cuda(0)
-                self.net2 = torch.nn.Linear(5, 5).cuda(1)
-                self.bn = nn.BatchNorm2d(10).cuda(0)
-
-            def forward(self, x):
-                out = self.net1(x.cuda(self.net1.weight.get_device()))
-                return self.net2(out.cuda(self.net2.weight.get_device()))
-
-        module = MultiGpuModule()
-
-        input = torch.rand(2, 10).cuda(0)
-        expected_output = module(input).cpu()
-
-        for devices in ([[0, 1], [2, 3]], [[1, 0], [3, 2]]):
-            replicas = dp.replicate(module, devices)
-            for i, replica in enumerate(replicas):
-                self.assertEqual(replica.net1.weight.get_device(), 2 * i)
-                self.assertEqual(replica.net1.bias.get_device(), 2 * i)
-                self.assertEqual(replica.net2.weight.get_device(), 2 * i + 1)
-                self.assertEqual(replica.net2.bias.get_device(), 2 * i + 1)
-                self.assertEqual(replica.bn.running_mean.get_device(), 2 * i)
-                self.assertEqual(replica.bn.running_var.get_device(), 2 * i)
-                self.assertEqual(
-                    replica.bn.num_batches_tracked.get_device(), 2 * i)
-
-                replica_input = input.cuda(2 * i)
-                replica_output = replica(replica_input)
-                self.assertEqual(replica_output.get_device(), 2 * i + 1)
-                self.assertEqual(replica_output.cpu(), expected_output)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_replicate_device_indices(self):
-        from torch.nn.parallel.replicate import _to_device_index as f
-
-        self.assertEqual(
-            f([['cuda:0', 'cuda:1', 'cuda:2'],
-               ['cuda:4', 'cuda:3', 'cuda:6']]),
-            [[0, 1, 2], [4, 3, 6]])
-
-        self.assertEqual(f(('cuda:0', 'cuda:1', 'cuda:2')), [0, 1, 2])
-
-        self.assertEqual(
-            len(set([0, 1, 2]).intersection(f({'cuda:0', 'cuda:1', 'cuda:2'}))),
-            3)
-        self.assertEqual(
-            f([['cuda:0'], ['cuda:1'], ['cuda:2']]), [[0], [1], [2]])
-
-        msg = "empty device list"
-        for devices in (None, (), [], [[]]):
-            with self.assertRaisesRegex(RuntimeError, msg):
-                f(devices)
-
-        msg = "unidentical number of devices"
-        for devices in ([[0, 1], [2]], [[0], [1, 2]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-        msg = "shared by multiple replicas"
-        for devices in ([[0, 1], [1, 2]], [[0], [1], [0]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-        msg = "Duplicated device ids"
-        for devices in ([[0, 1, 2, 1]], [0, 1, 1], [0, 0]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
-    def test_replicate_tensor_grouping_multi_gpu(self):
-        from torch.nn.parallel.replicate import _group_by_device
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(1)
-        d = torch.Tensor(4).cuda(0)
-        e = torch.Tensor(5).cuda(1)
-
-        tensors = [a, b, c, d, e]
-        for devices in ([[0, 1], [2, 3]], [[1, 4, 0], [3, 5, 2]]):
-            grouped_tensors, grouped_devices, original_index = \
-                _group_by_device(tensors, devices)
-
-            self.assertEqual(grouped_tensors, [[a, b, d], [c, e]])
-            self.assertEqual(grouped_devices, [[0, 2], [1, 3]])
-            self.assertEqual(original_index, [[0, 1, 3], [2, 4]])
-
-        msg = "missing from devices"
-        for devices in ([[0, 2], [1, 3]], [[1, 2], [0, 3]], [[2, 3], [0, 1]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                grouped_tensors, grouped_devices, original_index = \
-                    _group_by_device(tensors, devices)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_replicate_tensor_grouping(self):
-        from torch.nn.parallel.replicate import _group_by_device
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(0)
-
-        tensors = [a, b, c]
-
-        grouped_tensors, grouped_devices, original_index = \
-            _group_by_device(tensors, [0, 1])
-
-        self.assertEqual(grouped_tensors, [[a, b, c]])
-        self.assertEqual(grouped_devices, [[0, 1]])
-        self.assertEqual(original_index, [[0, 1, 2]])
-
-    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
-    def test_replicate_reshape(self):
-        from torch.nn.parallel.replicate import _broadcast_coalesced_reshape
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(1)
-        d = torch.Tensor(4).cuda(0)
-        e = torch.Tensor(5).cuda(1)
-
-        tensors = [a, b, c, d, e]
-        outputs = _broadcast_coalesced_reshape(tensors, [[0, 1], [1, 0]])
-
-        self.assertEqual(len(outputs), 2)
-        self.assertEqual(outputs[0], [a, b, c, d, e])
-        self.assertEqual(
-            outputs[1], [a.cuda(1), b.cuda(1), c.cuda(0), d.cuda(1), e.cuda(0)])
-
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_replicate_buffers(self):
         net = nn.Module()
         net.bn = nn.BatchNorm2d(10)
         net.cuda()
-        for devices in [(0, 1), [[0], [1]]]:
+        for devices in [(0, 1), [0, 1]]:
             replicas = dp.replicate(net, devices)
             for i, replica in enumerate(replicas):
                 self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
torch/nn/parallel/replicate.py

@@ -56,34 +56,6 @@ def _replicatable_module(module, memo=None):
     return True
 
 
-def _to_device_index(devices):
-    if not devices:
-        raise RuntimeError("Cannot replicate using an empty device list.")
-
-    if isinstance(devices, list) and isinstance(devices[0], list):
-        device_ids = []
-        seen = set()
-        for i, replica_devs in enumerate(devices):
-            assert len(replica_devs) == len(devices[0]), (
-                "Cannot replicate to unidentical number of devices, but got "
-                "device list {} and {} for replica {} and {}."
-            ).format(devices[0], devices[i], 0, i)
-
-            assert len(seen.intersection(replica_devs)) == 0, (
-                "Devices {} are shared by multiple replicas."
-            ).format(seen.intersection(replica_devs))
-            seen.update(replica_devs)
-
-            device_ids.append(_to_device_index(replica_devs))
-        return device_ids
-    else:
-        assert len(devices) == len(set(devices)), (
-            "Duplicated device ids {}."
-        ).format(devices)
-
-        return list(map(lambda x: _get_device_index(x, True), devices))
-
-
 def _build_param_dict(modules, module_copies, module_indices):
     param_dict = {}
     for module in modules:
@@ -111,169 +83,26 @@ def _copy_scriptmodule_methods(modules, module_copies, module_indices):
             replica._copy_method(method_name, param_list, module)
 
 
-# Group tensors on the same device together, which can later be broadcast to
-# a list of devices. For example,consider 5 tensors on 2 devices
-# a = torch.Tensor(0).cuda(0)
-# b = torch.Tensor(0).cuda(0)
-# c = torch.Tensor(0).cuda(1)
-# d = torch.Tensor(0).cuda(0)
-# e = torch.Tensor(0).cuda(1).
-# Let inputs be
-# tensors = [a, b, c, d, e] and
-# devices = [[0, 1], [2, 3]].
-# Then, outputs will be:
-# grouped_tensors = [[a, b, d], [c, e]],
-# grouped_devices = [[0, 2], [1, 3]],
-# original_index = [[0, 1, 3], [2, 4]],
-# meaning that grouped_tensors[i] will be broadcast to grouped_devices[i].
-def _group_by_device(tensors, devices):
-    if isinstance(devices[0], list):
-        # all tensor devices must appear in devices[0]
-        missing_devs = [t.device.index for t in tensors
-                        if t.device.index not in devices[0]]
-        assert not missing_devs, (
-            "tensor devices {} are missing from devices[0] {}."
-        ).format(missing_devs, devices[0])
-
-        # device id to output group index, this is necessary when `tensors` only
-        # use a subset of devices in `devices[0]`
-        dev_to_group_idx = {}
-        for t in tensors:
-            if t.device.index not in dev_to_group_idx:
-                dev_to_group_idx[t.device.index] = len(dev_to_group_idx)
-
-        # Group tensors by devices and remember each tensor's original index.
-        # The original_index helps to recover the original input tensor order
-        # from grouped tensors.
-        grouped_tensors = [[] for _ in range(len(dev_to_group_idx))]
-        original_index = [[] for _ in range(len(dev_to_group_idx))]
-        for i, t in enumerate(tensors):
-            group_id = dev_to_group_idx[t.device.index]
-            original_index[group_id].append(i)
-            grouped_tensors[group_id].append(t)
-
-        # group devices together if they should be in the same broadcast call
-        grouped_devices = [[] for _ in range(len(dev_to_group_idx))]
-        transpose = list(zip(*devices))
-        for row in transpose:
-            if row[0] in dev_to_group_idx:
-                grouped_devices[dev_to_group_idx[row[0]]] = list(row)
-
-        return grouped_tensors, grouped_devices, original_index
-    else:
-        return [tensors], [devices], [list(range(len(tensors)))]
-
-
-# Return len(devices) replicas of input tensors. If input tensors reside on
-# multiple GPUs, devices must be a 2D list with devices[0] matching input
-# tensors' devices. For example,consider 5 tensors on 2 devices
-# a = torch.Tensor(0).cuda(0)
-# b = torch.Tensor(0).cuda(0)
-# c = torch.Tensor(0).cuda(1)
-# d = torch.Tensor(0).cuda(0)
-# e = torch.Tensor(0).cuda(1).
-# Let inputs be
-# tensors = [a, b, c, d, e] and
-# devices = [[0, 1], [2, 3]].
-#
-# The output will be a 2D list of tensors:
-# [[a0, b0, c0, d0, e0],
-#  [a1, b1, c1, d1, e1]], where
-# a0, b0, d0 are on device 0
-# a1, b1, d1 are on device 2
-# c0, e0 are on device 1
-# c1, e1 are on device 3
-#
-# This example will be used throughout the implementation of this function.
 def _broadcast_coalesced_reshape(tensors, devices, detach=False):
     from ._functions import Broadcast
-
-    # a triply-nested list of 1) broadcast group, 2) tensor list replica,
-    # 3) tensors on the same device.
-    grouped_replicas = []
-    grouped_tensors, grouped_devices, original_index = \
-        _group_by_device(tensors, devices)
-    # For the example input described above, we have
-    # grouped_tensors =[[a, b, d], [c, e]]
-    # grouped_devices = [[0, 2], [1, 3]]
-    # original_index = [[0, 1, 3], [2, 4]]
-    for tensor_group, device_group in zip(grouped_tensors, grouped_devices):
-        if detach:
-            grouped_replicas.append(
-                comm.broadcast_coalesced(tensor_group, device_group))
-        else:
-            if len(tensor_group) > 0:
-                # Use the autograd function to broadcast if not detach
-                tensor_copies = Broadcast.apply(device_group, *tensor_group)
-                grouped_replicas.append(
-                    [tensor_copies[i:i + len(tensor_group)]
-                        for i in range(
-                            0, len(tensor_copies), len(tensor_group))])
-            else:
-                grouped_replicas.append([])
-
-    if isinstance(devices[0], list):
-        # convert the triply-nested list into a doubly-nested list of 1) replica
-        # 2) tensors in the same replica (can be on different devices)
-        #
-        # For the example input described above, we have
-        # grouped_replicas = [
-        #   [[a0, b0, d0],  # on device 0
-        #    [a1, b1, d1]], # on device 2
-        #   [[c0, e0],      # on device 1
-        #    [c1, e1]]      # on device 3
-        # ]
-        #
-        # The code below re-organize elements in grouped_replicas to the
-        # expected form:
-        # [[a0, b0, c0, d0, e0],
-        #  [a1, b1, c1, d1, e1]].
-        transpose = [0 for _ in tensors]
-        for g_idx in range(len(original_index)):
-            for t_idx in range(len(original_index[g_idx])):
-                # g_idx is the broadcast group index.
-                # t_idx is the tensor's index in a replica within a group.
-                # Tensors in grouped_replicas[g_idx, :, t_idx] are replicas of
-                # input tensor[original_index[g_idx][t_idx]]. Retrieve the
-                # column and add it as the original_index[g_idx][t_idx]'s row in
-                # transpose.
-                transpose[original_index[g_idx][t_idx]] = \
-                    [replica[t_idx] for replica in grouped_replicas[g_idx]]
-
-        # transpose the result to stay consistent with the 1D devices case.
-        return list(zip(*transpose))
+    if detach:
+        return comm.broadcast_coalesced(tensors, devices)
     else:
-        return grouped_replicas[0]
+        # Use the autograd function to broadcast if not detach
+        if len(tensors) > 0:
+            tensor_copies = Broadcast.apply(devices, *tensors)
+            return [tensor_copies[i:i + len(tensors)]
+                    for i in range(0, len(tensor_copies), len(tensors))]
+        else:
+            return []
 
 
 def replicate(network, devices, detach=False):
-    r"""Replicate the input :attr:`network` to given :attr:`devices`. If
-    :attr:`network` resides on CPU or a single GPU, :attr:`devices` must be a 1D
-    list of destination devices. If :attr:`network` resides on multiple GPUs,
-    :attr:`devices` must be satisfy the following conditions:
-
-    1. :attr:`devices` must be a 2D list,
-    2. ``devices[0]`` must match the :attr:`network`'s devices, in any order.
-    3. All ``devices[i]`` must have the same length.
-
-    For example, :attr:`network` is a ``Sequential`` module with two ``Linear``
-    layers stored on ``cuda:0`` and ``cuda:1`` respectively. Setting
-    :attr:`devices` to ``[[0, 1], [2, 3], [4, 5]]`` will replicate
-    :attr:`network` three times with replicas stored on devices
-    ``[cuda:0, cuda:1]``, ``[cuda:2, cuda:3]``, and ``[cuda:4, cuda:5]``
-    respectively.
-
-
-    Args:
-        network (Module): modules to be replicate
-        devices (1D or 2D list of int or torch.device): CUDA devices
-        detach (bool, optional): detached replicas from the current graph.
-    """
     if not _replicatable_module(network):
         raise RuntimeError("Cannot replicate network where python modules are "
                            "childrens of ScriptModule")
 
-    devices = _to_device_index(devices)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     num_replicas = len(devices)
 
     params = list(network.parameters())
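For reference (again not part of the commit), here is a small sketch of the device normalization performed by the restored line devices = list(map(lambda x: _get_device_index(x, True), devices)). It assumes _get_device_index is the private helper from torch.cuda._utils that replicate.py imports; ints, device strings, and torch.device objects all reduce to plain GPU indices:

import torch
from torch.cuda._utils import _get_device_index  # private helper; assumed import path

devices = [0, 'cuda:1', torch.device('cuda:2')]
# Each entry is normalized to its integer index, mirroring the restored line above.
print(list(map(lambda x: _get_device_index(x, True), devices)))  # -> [0, 1, 2]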