mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 07:27:32 +08:00
convert output_device at data_parallel from torch.device to index (#10189)
Summary: - fixes #9984 Pull Request resolved: https://github.com/pytorch/pytorch/pull/10189 Differential Revision: D9545390 Pulled By: weiyangfb fbshipit-source-id: 3a6a705437553ba319e9fd4b7f676ff73857a27e
This commit is contained in:
committed by
Facebook Github Bot
parent
045f862574
commit
54107ae8cf
@ -12,6 +12,7 @@ from ..modules import Module
|
||||
from .replicate import replicate
|
||||
from .scatter_gather import scatter_kwargs, gather
|
||||
from .parallel_apply import parallel_apply
|
||||
from torch.cuda._utils import _get_device_index
|
||||
|
||||
|
||||
class DistributedDataParallel(Module):
|
||||
@ -90,10 +91,10 @@ class DistributedDataParallel(Module):
|
||||
:meth:`forward` method.
|
||||
|
||||
Args:
|
||||
module: module to be parallelized
|
||||
device_ids: CUDA devices (default: all devices)
|
||||
output_device: device location of output (default: device_ids[0])
|
||||
broadcast_buffers: flag that enables syncing (broadcasting) buffers of
|
||||
module (Module): module to be parallelized
|
||||
device_ids (list of int or torch.device): CUDA devices (default: all devices)
|
||||
output_device (int or torch.device): device location of output (default: device_ids[0])
|
||||
broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
|
||||
the module at beginning of the forward function.
|
||||
(default: True)
|
||||
process_group: the c10d process group to be used for distributed data
|
||||
@ -133,8 +134,8 @@ class DistributedDataParallel(Module):
|
||||
|
||||
self.dim = dim
|
||||
self.module = module
|
||||
self.device_ids = device_ids
|
||||
self.output_device = output_device
|
||||
self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
|
||||
self.output_device = _get_device_index(output_device, True)
|
||||
self.broadcast_buffers = broadcast_buffers
|
||||
|
||||
self.allreduce_opts = dist.AllreduceOptions()
|
||||
|
||||
Reference in New Issue
Block a user