convert output_device at data_parallel from torch.device to index (#10189)
Summary:
- fixes #9984

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10189

Differential Revision: D9545390

Pulled By: weiyangfb

fbshipit-source-id: 3a6a705437553ba319e9fd4b7f676ff73857a27e
Committed by: Facebook Github Bot
Parent: 045f862574
Commit: 54107ae8cf
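
The change routes every device argument through torch.cuda._utils._get_device_index, so callers may pass either plain integer GPU ids or torch.device objects, and the parallel utilities keep storing plain indices internally. A minimal sketch of that normalization, written here only to illustrate the behaviour (the real helper lives in torch/cuda/_utils.py and may differ in details):

import torch

def get_device_index_sketch(device, optional=False):
    # Accept a torch.device, a string like 'cuda:1', or a plain int and
    # return the integer index that the parallel utilities store.
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(device, torch.device):
        if device.type != 'cuda':
            raise ValueError('expected a CUDA device, got {}'.format(device))
        if device.index is None:
            if optional:
                # a bare 'cuda' device falls back to the current device
                return torch.cuda.current_device()
            raise ValueError('expected a CUDA device with an index')
        return device.index
    return int(device)

assert get_device_index_sketch(torch.device('cuda:1')) == 1
assert get_device_index_sketch(0) == 0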
@@ -567,8 +567,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def world_size(self):
         return 2
 
-    def _test_ddp_with_process_group(self, process_group):
-        gpus = gpus_for_rank(self.world_size)[self.rank]
+    def _test_ddp_with_process_group(self, process_group, gpus):
         model = Net()
         ddp_model = DistributedDataParallel(
             copy.deepcopy(model).cuda(gpus[0]),
@@ -620,14 +619,18 @@ class DistributedDataParallelTest(MultiProcessTestCase):
         options = c10d.ProcessGroupGloo.Options()
         options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
-        self._test_ddp_with_process_group(process_group)
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_ddp_with_process_group(process_group, gpus)
+        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
 
     @skip_if_not_multigpu
     @skip_if_not_nccl
     def test_nccl_backend(self):
         store = c10d.TCPStore('localhost', self.port, self.is_master)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
-        self._test_ddp_with_process_group(process_group)
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_ddp_with_process_group(process_group, gpus)
+        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
 
     @skip_if_not_multigpu
     def test_dist_broadcast_coalesced(self):
@@ -1126,24 +1126,15 @@ class _DistTestBase(object):
         # Shuffle the input so that DDP input is different
         input = input[torch.randperm(batch_size)]
 
-    @unittest.skipIf(
-        BACKEND != "nccl" and BACKEND != "gloo",
-        "Only Nccl & Gloo backend support DistributedDataParallel",
-    )
-    @skip_if_no_cuda_distributed
-    @skip_if_no_gpu
-    def test_DistributedDataParallel(self):
+    def _test_DistributedDataParallel(self, gpu_subset, rank, output_device=None):
         # Run a simple end to end DDP model, use result of single node model
         # as baseline
-        group, group_id, rank = self._init_global_test()
-        rank_to_GPU = self._init_multigpu_helper()
 
         # cpu training setup
         model = self._create_Net()
 
         # single gpu training setup
         model_gpu = copy.deepcopy(model)
-        gpu_subset = list(rank_to_GPU[rank])
         model_gpu.cuda(gpu_subset[0])
 
         # DDP training setup
@@ -1195,6 +1186,22 @@ class _DistTestBase(object):
         )
         self._barrier()
 
+    @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
+                     "Only Nccl & Gloo backend support DistributedDataParallel")
+    @skip_if_no_cuda_distributed
+    @skip_if_no_gpu
+    def test_DistributedDataParallel(self):
+        group, group_id, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+        gpus = list(rank_to_GPU[rank])
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank)
+
+        # test output_device
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
+
+        # test device_ids
+        gpus = list(map(lambda i: torch.device('cuda:' + str(i)), gpus))
+        self._test_DistributedDataParallel(gpu_subset=gpus, rank=rank, output_device=torch.device('cuda'))
 
 if BACKEND == "gloo" or BACKEND == "nccl":
     WORLD_SIZE = os.environ["WORLD_SIZE"]
@@ -3154,6 +3154,24 @@ class TestNN(NNTestCase):
         self.assertEqual(out.get_device(), 0)
         self.assertEqual(out.data, expected_out)
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_device_args(self):
+        cuda0 = torch.device('cuda:0')
+        cuda1 = torch.device('cuda:1')
+
+        # test output_device
+        l = nn.Linear(10, 5).to(cuda0, torch.float)
+        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
+        out = dp.data_parallel(l, i, device_ids=(0, 1), output_device=cuda0)
+        self.assertEqual(out, l(i))
+
+        # test device_ids
+        l = nn.Linear(10, 5).to(cuda0, torch.float)
+        i = torch.randn(20, 10, dtype=torch.float, device=cuda0, requires_grad=True)
+        out = dp.data_parallel(l, i, device_ids=(cuda0, cuda1), output_device=cuda0)
+        self.assertEqual(out, l(i))
+
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()
@@ -3,6 +3,7 @@ import warnings
 import torch
 import torch.cuda.comm as comm
 from torch.autograd import Function
+from torch.cuda._utils import _get_device_index
 
 
 class Broadcast(Function):
@@ -11,6 +12,7 @@ class Broadcast(Function):
     def forward(ctx, target_gpus, *inputs):
         if not all(input.is_cuda for input in inputs):
             raise TypeError('Broadcast function not implemented for CPU tensors')
+        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
         ctx.target_gpus = target_gpus
         if len(inputs) == 0:
             return tuple()
@@ -50,6 +52,7 @@ class Gather(Function):
     @staticmethod
     def forward(ctx, target_device, dim, *inputs):
         assert all(map(lambda i: i.is_cuda, inputs))
+        target_device = _get_device_index(target_device, True)
         ctx.target_device = target_device
         ctx.dim = dim
         ctx.input_gpus = tuple(map(lambda i: i.get_device(), inputs))
@@ -76,6 +79,7 @@ class Scatter(Function):
 
     @staticmethod
     def forward(ctx, target_gpus, chunk_sizes, dim, input):
+        target_gpus = list(map(lambda x: _get_device_index(x, True), target_gpus))
         ctx.dim = dim
         ctx.input_device = input.get_device() if input.is_cuda else -1
         streams = None
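
With the normalization added to Broadcast, Gather and Scatter, their target-GPU arguments may be given as ints or torch.device objects interchangeably. A small hedged example, not taken from the test suite (assumes a machine with at least two visible GPUs):

import torch
from torch.nn.parallel._functions import Broadcast

t = torch.randn(2, 3, device='cuda:0')
# After this change the two calls below are equivalent: torch.device entries
# are reduced to integer indices before being stored on the autograd context.
copies_by_index = Broadcast.apply([0, 1], t)
copies_by_device = Broadcast.apply([torch.device('cuda:0'), torch.device('cuda:1')], t)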
@@ -5,6 +5,7 @@ from ..modules import Module
 from .scatter_gather import scatter_kwargs, gather
 from .replicate import replicate
 from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index
 
 
 def _check_balance(device_ids):
@@ -13,7 +14,7 @@ def _check_balance(device_ids):
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
+    device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
     dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
 
     def warn_imbalance(get_prop):
@@ -77,9 +78,9 @@ class DataParallel(Module):
 
 
     Args:
-        module: module to be parallelized
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
+        module (Module): module to be parallelized
+        device_ids (list of int or torch.device): CUDA devices (default: all devices)
+        output_device (int or torch.device): device location of output (default: device_ids[0])
 
     Attributes:
         module (Module): the module to be parallelized
@@ -104,10 +105,11 @@ class DataParallel(Module):
             device_ids = list(range(torch.cuda.device_count()))
         if output_device is None:
             output_device = device_ids[0]
 
         self.dim = dim
         self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
+        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.output_device = _get_device_index(output_device, True)
 
         _check_balance(self.device_ids)
 
@@ -143,10 +145,10 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
    This is the functional version of the DataParallel module.
 
    Args:
-        module: the module to evaluate in parallel
-        inputs: inputs to the module
-        device_ids: GPU ids on which to replicate module
-        output_device: GPU location of the output Use -1 to indicate the CPU.
+        module (Module): the module to evaluate in parallel
+        inputs (tensor): inputs to the module
+        device_ids (list of int or torch.device): GPU ids on which to replicate module
+        output_device (list of int or torch.device): GPU location of the output Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Tensor containing the result of module(input) located on
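
Because DataParallel now stores device_ids and output_device as plain indices, a module can be wrapped with torch.device arguments and behave exactly as before. An illustrative sketch, not taken from the test suite (assumes two GPUs):

import torch
import torch.nn as nn

model = nn.Linear(10, 5).cuda(0)
dp_model = nn.DataParallel(
    model,
    device_ids=[torch.device('cuda:0'), torch.device('cuda:1')],
    output_device=torch.device('cuda:0'),
)
# The constructor normalizes the arguments, so:
#   dp_model.device_ids == [0, 1]
#   dp_model.output_device == 0
out = dp_model(torch.randn(20, 10, device='cuda:0'))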
@@ -12,6 +12,7 @@ from ..modules import Module
 from .replicate import replicate
 from .scatter_gather import scatter_kwargs, gather
 from .parallel_apply import parallel_apply
+from torch.cuda._utils import _get_device_index
 
 
 class DistributedDataParallel(Module):
@@ -90,10 +91,10 @@ class DistributedDataParallel(Module):
    :meth:`forward` method.
 
    Args:
-        module: module to be parallelized
-        device_ids: CUDA devices (default: all devices)
-        output_device: device location of output (default: device_ids[0])
-        broadcast_buffers: flag that enables syncing (broadcasting) buffers of
+        module (Module): module to be parallelized
+        device_ids (list of int or torch.device): CUDA devices (default: all devices)
+        output_device (int or torch.device): device location of output (default: device_ids[0])
+        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
            the module at beginning of the forward function.
            (default: True)
        process_group: the c10d process group to be used for distributed data
@@ -133,8 +134,8 @@ class DistributedDataParallel(Module):
 
         self.dim = dim
         self.module = module
-        self.device_ids = device_ids
-        self.output_device = output_device
+        self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
+        self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers
 
         self.allreduce_opts = dist.AllreduceOptions()
@@ -1,5 +1,6 @@
 import threading
 import torch
+from torch.cuda._utils import _get_device_index
 
 
 def get_a_var(obj):
@@ -22,6 +23,11 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.
 
+    Args:
+        modules (Module): modules to be parallelized
+        inputs (tensor): inputs to the modules
+        devices (list of int or torch.device): CUDA devices
+
    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
@@ -36,7 +42,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
         assert len(modules) == len(devices)
     else:
         devices = [None] * len(modules)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     lock = threading.Lock()
     results = {}
     grad_enabled = torch.is_grad_enabled()
@@ -1,10 +1,11 @@
 import torch.cuda.comm as comm
+from torch.cuda._utils import _get_device_index
 
 
 def replicate(network, devices, detach=False):
     from ._functions import Broadcast
 
-    devices = tuple(devices)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
     num_replicas = len(devices)
 
     params = list(network.parameters())
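
replicate and parallel_apply go through the same conversion, so a list of torch.device objects can drive replication and parallel execution directly. A hedged usage sketch, not part of this commit (two GPUs assumed):

import torch
from torch.nn.parallel import replicate, parallel_apply

net = torch.nn.Linear(4, 4).cuda(0)
devices = [torch.device('cuda:0'), torch.device('cuda:1')]
replicas = replicate(net, devices)
inputs = [torch.randn(8, 4, device=d) for d in devices]
# Each replica runs on its own device; the devices entries are reduced to
# integer indices internally.
outputs = parallel_apply(replicas, inputs, devices=devices)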