convert output_device at data_parallel from torch.device to index (#10189)

Summary:
- fixes #9984
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10189

Differential Revision: D9545390

Pulled By: weiyangfb

fbshipit-source-id: 3a6a705437553ba319e9fd4b7f676ff73857a27e
Author:    Wei Yang
Date:      2018-09-11 20:20:54 -07:00
Committer: Facebook Github Bot
Parent:    045f862574
Commit:    54107ae8cf
8 changed files with 74 additions and 32 deletions
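
The net effect: device_ids and output_device for DataParallel, DistributedDataParallel, and the functional helpers can now be given either as integer GPU indices or as torch.device objects; both forms are normalized to plain indices internally. A minimal sketch of the new call pattern (assumes a machine with at least two GPUs; the names below are illustrative, not from the patch):

import torch
import torch.nn as nn

# device_ids / output_device may now be torch.device objects instead of
# bare integer indices; they are converted to indices inside DataParallel.
cuda0 = torch.device('cuda:0')
cuda1 = torch.device('cuda:1')

model = nn.Linear(10, 5).to(cuda0)
dp_model = nn.DataParallel(model, device_ids=[cuda0, cuda1], output_device=cuda0)

x = torch.randn(20, 10, device=cuda0)
y = dp_model(x)   # gathered on cuda0, i.e. device index 0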

View File

@@ -567,8 +567,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def world_size(self):
         return 2

-    def _test_ddp_with_process_group(self, process_group):
-        gpus = gpus_for_rank(self.world_size)[self.rank]
+    def _test_ddp_with_process_group(self, process_group, gpus):
         model = Net()
         ddp_model = DistributedDataParallel(
             copy.deepcopy(model).cuda(gpus[0]),

@@ -620,14 +619,18 @@ class DistributedDataParallelTest(MultiProcessTestCase):
         options = c10d.ProcessGroupGloo.Options()
         options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
-        self._test_ddp_with_process_group(process_group)
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_ddp_with_process_group(process_group, gpus)
+        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))

     @skip_if_not_multigpu
     @skip_if_not_nccl
     def test_nccl_backend(self):
         store = c10d.TCPStore('localhost', self.port, self.is_master)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
-        self._test_ddp_with_process_group(process_group)
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_ddp_with_process_group(process_group, gpus)
+        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))

     @skip_if_not_multigpu
     def test_dist_broadcast_coalesced(self):

View File

@@ -1126,24 +1126,15 @@ class _DistTestBase(object):
             # Shuffle the input so that DDP input is different
             input = input[torch.randperm(batch_size)]

-    @unittest.skipIf(
-        BACKEND != "nccl" and BACKEND != "gloo",
-        "Only Nccl & Gloo backend support DistributedDataParallel",
-    )
-    @skip_if_no_cuda_distributed
-    @skip_if_no_gpu
-    def test_DistributedDataParallel(self):
+    def _test_DistributedDataParallel(self, gpu_subset, rank, output_device=None):
         # Run a simple end to end DDP model, use result of single node model
         # as baseline
-        group, group_id, rank = self._init_global_test()
-        rank_to_GPU = self._init_multigpu_helper()

         # cpu training setup
         model = self._create_Net()

         # single gpu training setup
         model_gpu = copy.deepcopy(model)
-        gpu_subset = list(rank_to_GPU[rank])
         model_gpu.cuda(gpu_subset[0])

         # DDP training setup

@@ -1195,6 +1186,22 @@ class _DistTestBase(object):
         )
         self._barrier()

+    @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
+                     "Only Nccl & Gloo backend support DistributedDataParallel")
+    @skip_if_no_cuda_distributed
+    @skip_if_no_gpu
+    def test_DistributedDataParallel(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        gpus = list(rank_to_GPU[rank])
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank)
+
+        # test output_device
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+
+        # test device_ids
+        gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+
     if BACKEND == "gloo" or BACKEND == "nccl":
         WORLD_SIZE = os.environ["WORLD_SIZE"]

View File

@@ -3154,6 +3154,24 @@ class TestNN(NNTestCase):
         self.assertEqual(out.get_device(), 0)
         self.assertEqual(out.data, expected_out)

+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_device_args(self):
+        cuda0 = torch.device('cuda:0')
+        cuda1 = torch.device('cuda:1')
+
+        # test output_device
+        l = nn.Linear(10, 5).to(cuda0, torch.float)
+        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
+        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
+        self.assertEqual(out, l(i))
+
+        # test device_ids
+        l = nn.Linear(10, 5).to(cuda0, torch.float)
+        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
+        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
+        self.assertEqual(out, l(i))
+
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()

View File

@@ -3,6 +3,7 @@ import warnings
 import torch
 import torch.cuda.comm as comm
 from torch.autograd import Function
+from torch.cuda._utils import _get_device_index


 class Broadcast(Function):

@@ -11,6 +12,7 @@ class Broadcast(Function):
     def forward(ctx, target_gpus, *inputs):
         if not all(input.is_cuda for input in inputs):
             raise TypeError('Broadcast function not implemented for CPU tensors')
+        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
         ctx.target_gpus = target_gpus
         if len(inputs) == 0:
             return tuple()

@@ -50,6 +52,7 @@ class Gather(Function):
     @staticmethod
     def forward(ctx, target_device, dim, *inputs):
         assert all(map(lambda i: i.is_cuda, inputs))
+        target_device = _get_device_index(target_device, True)
         ctx.target_device = target_device
         ctx.dim = dim
         ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs))

@@ -76,6 +79,7 @@ class Scatter(Function):
     @staticmethod
     def forward(ctx, target_gpus, chunk_sizes, dim, input):
+        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
         ctx.dim = dim
         ctx.input_device = input.get_device() if input.is_cuda else -1
         streams = None
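
All of the call sites above funnel through torch.cuda._utils._get_device_index(x, True), which normalizes whatever the caller passed into a plain integer index. The sketch below is only an approximation of that behavior for orientation, not the upstream implementation:

import torch

def _get_device_index_sketch(device, optional=False):
    # Approximation: accept a torch.device, an int, or None and return an int index.
    if isinstance(device, torch.device):
        if device.type != 'cuda':
            raise ValueError('Expected a cuda device, but got: {}'.format(device))
        index = device.index
    else:
        index = device
    if index is None:
        if optional:
            # No explicit index given: fall back to the currently selected CUDA device.
            return torch.cuda.current_device()
        raise ValueError('Expected a cuda device with an index, or an integer')
    return index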

View File

@@ -5,6 +5,7 @@ from ..modules import Module
 from .scatter_gather import scatter_kwargs, gather
 from .replicate import replicate
 from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index


 def _check_balance(device_ids):

@@ -13,7 +14,7 @@ def _check_balance(device_ids):
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
-
+    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
     dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]

     def warn_imbalance(get_prop):

@@ -77,9 +78,9 @@ class DataParallel(Module):

     Args:
-        module: module to be parallelized
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
+        module (Module): module to be parallelized
+        device_ids (list of int or torch.device): CUDA devices (default: all devices)
+        output_device (int or torch.device): device location of output (default: device_ids[0])

     Attributes:
         module (Module): the module to be parallelized

@@ -104,10 +105,11 @@ class DataParallel(Module):
             device_ids = list(range(torch.cuda.device_count()))
         if output_device is None:
             output_device = device_ids[0]
         self.dim = dim
         self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
+        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.output_device = _get_device_index(output_device, True)

         _check_balance(self.device_ids)

@@ -143,10 +145,10 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
     This is the functional version of the DataParallel module.

     Args:
-        module: the module to evaluate in parallel
-        inputs: inputs to the module
-        device_ids: GPU ids on which to replicate module
-        output_device: GPU location of the output Use -1 to indicate the CPU.
+        module (Module): the module to evaluate in parallel
+        inputs (tensor): inputs to the module
+        device_ids (list of int or torch.device): GPU ids on which to replicate module
+        output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU.
             (default: device_ids[0])
     Returns:
         a Tensor containing the result of module(input) located on
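
With the updated docstring, the functional form accepts the same mix of argument styles as the module version. A small usage sketch (two GPUs assumed; names are illustrative):

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

cuda0 = torch.device('cuda:0')
module = nn.Linear(10, 5).to(cuda0)
inputs = torch.randn(20, 10, device=cuda0)

# Integer ids and torch.device objects can be mixed; the result is gathered on cuda0.
out = data_parallel(module, inputs, device_ids=[0, 1], output_device=cuda0)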

View File

@@ -12,6 +12,7 @@ from ..modules import Module
 from .replicate import replicate
 from .scatter_gather import scatter_kwargs, gather
 from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index


 class DistributedDataParallel(Module):

@@ -90,10 +91,10 @@ class DistributedDataParallel(Module):
         :meth:`forward` method.

     Args:
-        module: module to be parallelized
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
-        broadcast_buffers: flag that enables syncing (broadcasting) buffers of
+        module (Module): module to be parallelized
+        device_ids (list of int or torch.device): CUDA devices (default: all devices)
+        output_device (int or torch.device): device location of output (default: device_ids[0])
+        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
                            the module at beginning of the forward function.
                            (default: True)
         process_group: the c10d process group to be used for distributed data

@@ -133,8 +134,8 @@ class DistributedDataParallel(Module):
         self.dim = dim
         self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
+        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers

         self.allreduce_opts = dist.AllreduceOptions()
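
The c10d-based DistributedDataParallel performs the same normalization in its constructor. A sketch, assuming a process group created as in the c10d test above and a rank that owns GPUs 0 and 1 (process_group is a placeholder here, not defined in this snippet):

import copy
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

devices = [torch.device('cuda:0'), torch.device('cuda:1')]
model = nn.Linear(10, 5)
ddp_model = DistributedDataParallel(
    copy.deepcopy(model).cuda(devices[0]),
    device_ids=devices,           # torch.device entries are accepted
    output_device=devices[0],     # stored internally as index 0
    process_group=process_group,  # assumed to exist, e.g. a ProcessGroupGloo as above
)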

View File

@@ -1,5 +1,6 @@
 import threading
 import torch
+from torch.cuda._utils import _get_device_index


 def get_a_var(obj):

@@ -22,6 +23,11 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
     contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
     on each of :attr:`devices`.

+    Args:
+        modules (Module): modules to be parallelized
+        inputs (tensor): inputs to the modules
+        devices (list of int or torch.device): CUDA devices
+
     :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
     :attr:`devices` (if given) should all have same length. Moreover, each
     element of :attr:`inputs` can either be a single object as the only argument

@@ -36,7 +42,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
         assert len(modules) == len(devices)
     else:
         devices = [None] * len(modules)
-
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     lock = threading.Lock()
     results = {}
     grad_enabled = torch.is_grad_enabled()
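
parallel_apply gets the same treatment, so its devices argument accepts torch.device objects (integer indices still work). A small sketch with one module and one input per device (two GPUs assumed):

import torch
import torch.nn as nn
from torch.nn.parallel import parallel_apply

devices = [torch.device('cuda:0'), torch.device('cuda:1')]
modules = [nn.Linear(10, 5).to(d) for d in devices]
inputs = [torch.randn(8, 10, device=d) for d in devices]

# Each module is applied to its own input on its own device, in parallel threads.
outputs = parallel_apply(modules, inputs, devices=devices)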

View File

@@ -1,10 +1,11 @@
 import torch.cuda.comm as comm
+from torch.cuda._utils import _get_device_index


 def replicate(network, devices, detach=False):
     from ._functions import Broadcast

-    devices = tuple(devices)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     num_replicas = len(devices)
     params = list(network.parameters())
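
replicate previously only did tuple(devices); it now converts each entry with _get_device_index, so torch.device objects work here as well. A usage sketch (two GPUs assumed; names are illustrative):

import torch
import torch.nn as nn
from torch.nn.parallel import replicate

devices = [torch.device('cuda:0'), torch.device('cuda:1')]
net = nn.Linear(10, 5).to(devices[0])

# Broadcast the module's parameters and buffers, returning one replica per device.
replicas = replicate(net, devices)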