Revert replicate.py to disallow replicating multi-device modules (#19278)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19278

Based on the discussion in https://github.com/pytorch/pytorch/pull/19278 and https://github.com/pytorch/pytorch/pull/18687, the changes to replicate.py are reverted so that replicating multi-device modules is disallowed again.

Reviewed By: pietern

Differential Revision: D14940018

fbshipit-source-id: 7504c0f4325c2639264c52dcbb499e61c9ad2c26
Shen Li
2019-04-16 09:35:36 -07:00
committed by Facebook Github Bot
parent b9c20d5224
commit 344acaa0ca
3 changed files with 13 additions and 318 deletions

test/common_cuda.py

@@ -7,7 +7,6 @@ from common_utils import TEST_WITH_ROCM, TEST_NUMBA
 TEST_CUDA = torch.cuda.is_available()
 TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
-TEST_GEQ4GPU = TEST_CUDA and torch.cuda.device_count() >= 4
 CUDA_DEVICE = TEST_CUDA and torch.device("cuda:0")
 # note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
 TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))

test/test_nn.py

@@ -30,8 +30,7 @@ from torch.nn.parallel._functions import Broadcast
 from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
     TEST_NUMPY, TEST_SCIPY, download_file, PY3, PY34, to_gpu, \
     get_function_arglist, load_tests
-from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_GEQ4GPU, TEST_CUDNN, \
-    TEST_CUDNN_VERSION
+from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
 from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
     module_tests, criterion_tests, loss_reference_fns, get_reduction, \
     get_weight, smoothl1loss_reference, kldivloss_reference, \
@@ -3709,7 +3708,7 @@ class TestNN(NNTestCase):
         module = nn.Linear(10, 5).float().cuda()
         input = Variable(torch.randn(2, 10).float().cuda())
         expected_output = module(input).data
-        for devices in [(0, 1), [[0], [1]]]:
+        for devices in [(0, 1), [0, 1]]:
             replicas = dp.replicate(module, devices)
             for i, replica in enumerate(replicas):
                 for p in replica.parameters():
@@ -3717,144 +3716,12 @@ class TestNN(NNTestCase):
                 replica_input = input.cuda(i)
                 self.assertEqual(replica(replica_input).data, expected_output)
 
-    @unittest.skipIf(not TEST_GEQ4GPU, "less than 4 GPUs")
-    def test_replicate_multi_gpu_module(self):
-        class MultiGpuModule(nn.Module):
-            def __init__(self):
-                super(MultiGpuModule, self).__init__()
-                self.net1 = torch.nn.Linear(10, 5).cuda(0)
-                self.net2 = torch.nn.Linear(5, 5).cuda(1)
-                self.bn = nn.BatchNorm2d(10).cuda(0)
-
-            def forward(self, x):
-                out = self.net1(x.cuda(self.net1.weight.get_device()))
-                return self.net2(out.cuda(self.net2.weight.get_device()))
-
-        module = MultiGpuModule()
-        input = torch.rand(2, 10).cuda(0)
-        expected_output = module(input).cpu()
-
-        for devices in ([[0, 1], [2, 3]], [[1, 0], [3, 2]]):
-            replicas = dp.replicate(module, devices)
-
-            for i, replica in enumerate(replicas):
-                self.assertEqual(replica.net1.weight.get_device(), 2 * i)
-                self.assertEqual(replica.net1.bias.get_device(), 2 * i)
-                self.assertEqual(replica.net2.weight.get_device(), 2 * i + 1)
-                self.assertEqual(replica.net2.bias.get_device(), 2 * i + 1)
-                self.assertEqual(replica.bn.running_mean.get_device(), 2 * i)
-                self.assertEqual(replica.bn.running_var.get_device(), 2 * i)
-                self.assertEqual(
-                    replica.bn.num_batches_tracked.get_device(), 2 * i)
-
-                replica_input = input.cuda(2 * i)
-                replica_output = replica(replica_input)
-                self.assertEqual(replica_output.get_device(), 2 * i + 1)
-                self.assertEqual(replica_output.cpu(), expected_output)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_replicate_device_indices(self):
-        from torch.nn.parallel.replicate import _to_device_index as f
-
-        self.assertEqual(
-            f([['cuda:0', 'cuda:1', 'cuda:2'],
-               ['cuda:4', 'cuda:3', 'cuda:6']]),
-            [[0, 1, 2], [4, 3, 6]])
-        self.assertEqual(f(('cuda:0', 'cuda:1', 'cuda:2')), [0, 1, 2])
-        self.assertEqual(
-            len(set([0, 1, 2]).intersection(f({'cuda:0', 'cuda:1', 'cuda:2'}))),
-            3)
-        self.assertEqual(
-            f([['cuda:0'], ['cuda:1'], ['cuda:2']]), [[0], [1], [2]])
-
-        msg = "empty device list"
-        for devices in (None, (), [], [[]]):
-            with self.assertRaisesRegex(RuntimeError, msg):
-                f(devices)
-
-        msg = "unidentical number of devices"
-        for devices in ([[0, 1], [2]], [[0], [1, 2]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-        msg = "shared by multiple replicas"
-        for devices in ([[0, 1], [1, 2]], [[0], [1], [0]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-        msg = "Duplicated device ids"
-        for devices in ([[0, 1, 2, 1]], [0, 1, 1], [0, 0]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                f(devices)
-
-    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
-    def test_replicate_tensor_grouping_multi_gpu(self):
-        from torch.nn.parallel.replicate import _group_by_device
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(1)
-        d = torch.Tensor(4).cuda(0)
-        e = torch.Tensor(5).cuda(1)
-
-        tensors = [a, b, c, d, e]
-        for devices in ([[0, 1], [2, 3]], [[1, 4, 0], [3, 5, 2]]):
-            grouped_tensors, grouped_devices, original_index = \
-                _group_by_device(tensors, devices)
-
-            self.assertEqual(grouped_tensors, [[a, b, d], [c, e]])
-            self.assertEqual(grouped_devices, [[0, 2], [1, 3]])
-            self.assertEqual(original_index, [[0, 1, 3], [2, 4]])
-
-        msg = "missing from devices"
-        for devices in ([[0, 2], [1, 3]], [[1, 2], [0, 3]], [[2, 3], [0, 1]]):
-            with self.assertRaisesRegex(AssertionError, msg):
-                grouped_tensors, grouped_devices, original_index = \
-                    _group_by_device(tensors, devices)
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_replicate_tensor_grouping(self):
-        from torch.nn.parallel.replicate import _group_by_device
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(0)
-
-        tensors = [a, b, c]
-        grouped_tensors, grouped_devices, original_index = \
-            _group_by_device(tensors, [0, 1])
-
-        self.assertEqual(grouped_tensors, [[a, b, c]])
-        self.assertEqual(grouped_devices, [[0, 1]])
-        self.assertEqual(original_index, [[0, 1, 2]])
-
-    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
-    def test_replicate_reshape(self):
-        from torch.nn.parallel.replicate import _broadcast_coalesced_reshape
-
-        a = torch.Tensor(1).cuda(0)
-        b = torch.Tensor(2).cuda(0)
-        c = torch.Tensor(3).cuda(1)
-        d = torch.Tensor(4).cuda(0)
-        e = torch.Tensor(5).cuda(1)
-
-        tensors = [a, b, c, d, e]
-        outputs = _broadcast_coalesced_reshape(tensors, [[0, 1], [1, 0]])
-
-        self.assertEqual(len(outputs), 2)
-        self.assertEqual(outputs[0], [a, b, c, d, e])
-        self.assertEqual(
-            outputs[1], [a.cuda(1), b.cuda(1), c.cuda(0), d.cuda(1), e.cuda(0)])
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_replicate_buffers(self):
         net = nn.Module()
         net.bn = nn.BatchNorm2d(10)
         net.cuda()
-        for devices in [(0, 1), [[0], [1]]]:
+        for devices in [(0, 1), [0, 1]]:
             replicas = dp.replicate(net, devices)
             for i, replica in enumerate(replicas):
                 self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
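
For intuition about the grouping expectations asserted by the removed tests above, here is a device-free sketch of the same rule. Note that group_by_device below is an illustrative stand-in for the removed _group_by_device helper, not the shipped code; plain ints act as tensors, so it runs without CUDA:

    def group_by_device(tensor_devices, devices):
        # tensor_devices: the device id each input tensor lives on.
        # devices: 2D replica device list; devices[0] must cover tensor_devices.
        dev_to_group = {}
        for d in tensor_devices:
            assert d in devices[0], "tensor device {} missing from devices[0]".format(d)
            if d not in dev_to_group:
                dev_to_group[d] = len(dev_to_group)
        # remember each tensor's original position so input order can be restored
        original_index = [[] for _ in dev_to_group]
        for i, d in enumerate(tensor_devices):
            original_index[dev_to_group[d]].append(i)
        # each column of `devices` is one broadcast destination group
        grouped_devices = [[] for _ in dev_to_group]
        for col in zip(*devices):
            if col[0] in dev_to_group:
                grouped_devices[dev_to_group[col[0]]] = list(col)
        return grouped_devices, original_index

    # tensors a, b, d on cuda:0 and c, e on cuda:1, replicated to [[0, 1], [2, 3]]:
    assert group_by_device([0, 0, 1, 0, 1], [[0, 1], [2, 3]]) == \
        ([[0, 2], [1, 3]], [[0, 1, 3], [2, 4]])

This reproduces the grouped_devices and original_index values that test_replicate_tensor_grouping_multi_gpu expected.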

torch/nn/parallel/replicate.py

@@ -56,34 +56,6 @@ def _replicatable_module(module, memo=None):
     return True
 
-
-def _to_device_index(devices):
-    if not devices:
-        raise RuntimeError("Cannot replicate using an empty device list.")
-
-    if isinstance(devices, list) and isinstance(devices[0], list):
-        device_ids = []
-        seen = set()
-        for i, replica_devs in enumerate(devices):
-            assert len(replica_devs) == len(devices[0]), (
-                "Cannot replicate to unidentical number of devices, but got "
-                "device list {} and {} for replica {} and {}."
-            ).format(devices[0], devices[i], 0, i)
-
-            assert len(seen.intersection(replica_devs)) == 0, (
-                "Devices {} are shared by multiple replicas."
-            ).format(seen.intersection(replica_devs))
-            seen.update(replica_devs)
-
-            device_ids.append(_to_device_index(replica_devs))
-        return device_ids
-    else:
-        assert len(devices) == len(set(devices)), (
-            "Duplicated device ids {}."
-        ).format(devices)
-        return list(map(lambda x: _get_device_index(x, True), devices))
 
 def _build_param_dict(modules, module_copies, module_indices):
     param_dict = {}
     for module in modules:
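
The hunk above removes the device-list validation. As a standalone illustration, the following sketch re-creates its normalization and invariant checks without torch; to_device_index is a hypothetical re-implementation (parsing "cuda:N" strings directly instead of calling _get_device_index), not the removed helper itself:

    def to_device_index(devices):
        # Normalize 'cuda:1'-style strings or plain ints to device ids,
        # recursing into 2D lists and enforcing the removed invariants.
        if not devices:
            raise RuntimeError("Cannot replicate using an empty device list.")
        if isinstance(devices, list) and isinstance(devices[0], list):
            assert all(len(r) == len(devices[0]) for r in devices), \
                "unidentical number of devices per replica"
            ids = [to_device_index(r) for r in devices]
            flat = [d for r in ids for d in r]
            assert len(flat) == len(set(flat)), \
                "some devices are shared by multiple replicas"
            return ids
        ids = [int(str(d).rsplit(':', 1)[-1]) for d in devices]
        assert len(ids) == len(set(ids)), "Duplicated device ids {}.".format(ids)
        return ids

    assert to_device_index([['cuda:0', 'cuda:1'], ['cuda:4', 'cuda:3']]) == [[0, 1], [4, 3]]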
@@ -111,169 +83,26 @@ def _copy_scriptmodule_methods(modules, module_copies, module_indices):
         replica._copy_method(method_name, param_list, module)
 
-
-# Group tensors on the same device together, which can later be broadcast to
-# a list of devices. For example, consider 5 tensors on 2 devices
-#   a = torch.Tensor(0).cuda(0)
-#   b = torch.Tensor(0).cuda(0)
-#   c = torch.Tensor(0).cuda(1)
-#   d = torch.Tensor(0).cuda(0)
-#   e = torch.Tensor(0).cuda(1).
-# Let inputs be
-#   tensors = [a, b, c, d, e] and
-#   devices = [[0, 1], [2, 3]].
-# Then, the outputs will be
-#   grouped_tensors = [[a, b, d], [c, e]],
-#   grouped_devices = [[0, 2], [1, 3]],
-#   original_index = [[0, 1, 3], [2, 4]],
-# meaning that grouped_tensors[i] will be broadcast to grouped_devices[i].
-def _group_by_device(tensors, devices):
-    if isinstance(devices[0], list):
-        # all tensor devices must appear in devices[0]
-        missing_devs = [t.device.index for t in tensors
-                        if t.device.index not in devices[0]]
-        assert not missing_devs, (
-            "tensor devices {} are missing from devices[0] {}."
-        ).format(missing_devs, devices[0])
-
-        # device id to output group index; this is necessary when `tensors`
-        # only use a subset of devices in `devices[0]`
-        dev_to_group_idx = {}
-        for t in tensors:
-            if t.device.index not in dev_to_group_idx:
-                dev_to_group_idx[t.device.index] = len(dev_to_group_idx)
-
-        # Group tensors by devices and remember each tensor's original index.
-        # The original_index helps to recover the original input tensor order
-        # from grouped tensors.
-        grouped_tensors = [[] for _ in range(len(dev_to_group_idx))]
-        original_index = [[] for _ in range(len(dev_to_group_idx))]
-        for i, t in enumerate(tensors):
-            group_id = dev_to_group_idx[t.device.index]
-            original_index[group_id].append(i)
-            grouped_tensors[group_id].append(t)
-
-        # group devices together if they should be in the same broadcast call
-        grouped_devices = [[] for _ in range(len(dev_to_group_idx))]
-        transpose = list(zip(*devices))
-        for row in transpose:
-            if row[0] in dev_to_group_idx:
-                grouped_devices[dev_to_group_idx[row[0]]] = list(row)
-
-        return grouped_tensors, grouped_devices, original_index
-    else:
-        return [tensors], [devices], [list(range(len(tensors)))]
-
-
-# Return len(devices) replicas of input tensors. If input tensors reside on
-# multiple GPUs, devices must be a 2D list with devices[0] matching input
-# tensors' devices. For example, consider 5 tensors on 2 devices
-#   a = torch.Tensor(0).cuda(0)
-#   b = torch.Tensor(0).cuda(0)
-#   c = torch.Tensor(0).cuda(1)
-#   d = torch.Tensor(0).cuda(0)
-#   e = torch.Tensor(0).cuda(1).
-# Let inputs be
-#   tensors = [a, b, c, d, e] and
-#   devices = [[0, 1], [2, 3]].
-#
-# The output will be a 2D list of tensors:
-#   [[a0, b0, c0, d0, e0],
-#    [a1, b1, c1, d1, e1]], where
-#   a0, b0, d0 are on device 0
-#   a1, b1, d1 are on device 2
-#   c0, e0 are on device 1
-#   c1, e1 are on device 3
-#
-# This example will be used throughout the implementation of this function.
 def _broadcast_coalesced_reshape(tensors, devices, detach=False):
     from ._functions import Broadcast
-
-    # a triply-nested list of 1) broadcast group, 2) tensor list replica,
-    # 3) tensors on the same device.
-    grouped_replicas = []
-
-    grouped_tensors, grouped_devices, original_index = \
-        _group_by_device(tensors, devices)
-
-    # For the example input described above, we have
-    #   grouped_tensors = [[a, b, d], [c, e]]
-    #   grouped_devices = [[0, 2], [1, 3]]
-    #   original_index = [[0, 1, 3], [2, 4]]
-    for tensor_group, device_group in zip(grouped_tensors, grouped_devices):
-        if detach:
-            grouped_replicas.append(
-                comm.broadcast_coalesced(tensor_group, device_group))
-        else:
-            if len(tensor_group) > 0:
-                # Use the autograd function to broadcast if not detach
-                tensor_copies = Broadcast.apply(device_group, *tensor_group)
-                grouped_replicas.append(
-                    [tensor_copies[i:i + len(tensor_group)]
-                     for i in range(
-                         0, len(tensor_copies), len(tensor_group))])
-            else:
-                grouped_replicas.append([])
-
-    if isinstance(devices[0], list):
-        # convert the triply-nested list into a doubly-nested list of
-        # 1) replica, 2) tensors in the same replica (can be on different
-        # devices)
-        #
-        # For the example input described above, we have
-        #   grouped_replicas = [
-        #       [[a0, b0, d0],   # on device 0
-        #        [a1, b1, d1]],  # on device 2
-        #       [[c0, e0],       # on device 1
-        #        [c1, e1]]       # on device 3
-        #   ]
-        #
-        # The code below re-organizes elements in grouped_replicas into the
-        # expected form:
-        #   [[a0, b0, c0, d0, e0],
-        #    [a1, b1, c1, d1, e1]].
-        transpose = [0 for _ in tensors]
-        for g_idx in range(len(original_index)):
-            for t_idx in range(len(original_index[g_idx])):
-                # g_idx is the broadcast group index.
-                # t_idx is the tensor's index in a replica within a group.
-                # Tensors in grouped_replicas[g_idx, :, t_idx] are replicas of
-                # input tensor[original_index[g_idx][t_idx]]. Retrieve the
-                # column and add it as the original_index[g_idx][t_idx]'s row
-                # in transpose.
-                transpose[original_index[g_idx][t_idx]] = \
-                    [replica[t_idx] for replica in grouped_replicas[g_idx]]
-        # transpose the result to stay consistent with the 1D devices case.
-        return list(zip(*transpose))
+    if detach:
+        return comm.broadcast_coalesced(tensors, devices)
     else:
-        return grouped_replicas[0]
+        # Use the autograd function to broadcast if not detach
+        if len(tensors) > 0:
+            tensor_copies = Broadcast.apply(devices, *tensors)
+            return [tensor_copies[i:i + len(tensors)]
+                    for i in range(0, len(tensor_copies), len(tensors))]
+        else:
+            return []
 
 
 def replicate(network, devices, detach=False):
-    r"""Replicate the input :attr:`network` to given :attr:`devices`. If
-    :attr:`network` resides on CPU or a single GPU, :attr:`devices` must be a
-    1D list of destination devices. If :attr:`network` resides on multiple
-    GPUs, :attr:`devices` must satisfy the following conditions:
-
-    1. :attr:`devices` must be a 2D list;
-    2. ``devices[0]`` must match the :attr:`network`'s devices, in any order;
-    3. all ``devices[i]`` must have the same length.
-
-    For example, say :attr:`network` is a ``Sequential`` module with two
-    ``Linear`` layers stored on ``cuda:0`` and ``cuda:1`` respectively.
-    Setting :attr:`devices` to ``[[0, 1], [2, 3], [4, 5]]`` will replicate
-    :attr:`network` three times, with replicas stored on devices
-    ``[cuda:0, cuda:1]``, ``[cuda:2, cuda:3]``, and ``[cuda:4, cuda:5]``
-    respectively.
-
-    Args:
-        network (Module): the module to be replicated
-        devices (1D or 2D list of int or torch.device): CUDA devices
-        detach (bool, optional): whether to detach replicas from the current
-            graph.
-    """
     if not _replicatable_module(network):
         raise RuntimeError("Cannot replicate network where python modules are "
                            "childrens of ScriptModule")
 
-    devices = _to_device_index(devices)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     num_replicas = len(devices)
     params = list(network.parameters())
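
After this revert, replicate() accepts only a flat device list again. A minimal usage sketch of the surviving code path (assumes a machine with at least two CUDA devices):

    import torch.nn as nn
    from torch.nn.parallel import replicate

    module = nn.Linear(10, 5).cuda(0)
    # devices must now be 1D; each entry yields one full replica.
    replicas = replicate(module, [0, 1])
    for i, r in enumerate(replicas):
        assert all(p.get_device() == i for p in r.parameters())

Each replica's parameters land on the single device given for it, matching the assertions in the updated tests above.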