convert output_device in data_parallel from torch.device to index (#10189)

Summary:
- fixes #9984
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10189

Differential Revision: D9545390

Pulled By: weiyangfb

fbshipit-source-id: 3a6a705437553ba319e9fd4b7f676ff73857a27e
Author: Wei Yang
Date: 2018-09-11 20:20:54 -07:00
Committed by: Facebook GitHub Bot
Parent: 045f862574
Commit: 54107ae8cf
8 changed files with 74 additions and 32 deletions
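For context, here is a minimal usage sketch of the call pattern this change enables. It is a hypothetical example inferred from the commit title and the docstring update below, not copied from issue #9984: after the conversion, device_ids and output_device may be given as torch.device objects instead of plain integer indices.

import torch
import torch.nn as nn

# Hypothetical example (assumed, not taken verbatim from issue #9984):
# torch.device objects are now accepted wherever an integer index was expected.
model = nn.Linear(10, 10).cuda(0)
wrapped = nn.DataParallel(
    model,
    device_ids=[torch.device("cuda:0"), torch.device("cuda:1")],
    output_device=torch.device("cuda:0"),  # previously an int such as 0
)
out = wrapped(torch.randn(4, 10, device="cuda:0"))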


@@ -12,6 +12,7 @@ from ..modules import Module
 from .replicate import replicate
 from .scatter_gather import scatter_kwargs, gather
 from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index


 class DistributedDataParallel(Module):
@@ -90,10 +91,10 @@ class DistributedDataParallel(Module):
     :meth:`forward` method.

     Args:
-        module: module to be parallelized
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
-        broadcast_buffers: flag that enables syncing (broadcasting) buffers of
+        module (Module): module to be parallelized
+        device_ids (list of int or torch.device): CUDA devices (default: all devices)
+        output_device (int or torch.device): device location of output (default: device_ids[0])
+        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
             the module at beginning of the forward function.
             (default: True)
         process_group: the c10d process group to be used for distributed data
@@ -133,8 +134,8 @@ class DistributedDataParallel(Module):
         self.dim = dim
         self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
+        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers
         self.allreduce_opts = dist.AllreduceOptions()
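The diff only shows how _get_device_index from torch.cuda._utils is used, not its body. Below is a rough sketch of the normalization it is being asked to perform here; the helper name device_to_index and the fallback to the current device when the second argument is True are assumptions, not something this diff shows.

import torch

def device_to_index(device, optional=False):
    # Sketch only: mirrors how _get_device_index is used above,
    # not the actual torch.cuda._utils implementation.
    if isinstance(device, torch.device):
        if device.type != "cuda":
            raise ValueError("expected a CUDA device, got {}".format(device))
        device = device.index  # may be None for a bare torch.device("cuda")
    if device is None:
        if optional:
            return torch.cuda.current_device()  # assumed fallback behavior
        raise ValueError("expected a device index or torch.device, got None")
    return int(device)

Storing self.device_ids and self.output_device as plain integer indices keeps the rest of the module operating on ints regardless of how the caller specified the devices.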