Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Fix mp serialization for integer nn.Parameter on CUDA (#56529)
Summary: Fixes https://github.com/pytorch/pytorch/issues/56342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56529

Reviewed By: albanD
Differential Revision: D27896094
Pulled By: ngimel
fbshipit-source-id: fe817781eb7139ea57c78acfd56e7c11b61eb4ed
Committed by: Facebook GitHub Bot
Parent: febff45900
Commit: bac4cfd54d
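Background for the diff below: when an nn.Parameter backed by a CUDA tensor is sent to another process, the child rebuilds it via rebuild_cuda_tensor. The old code re-wrapped the rebuilt tensor with torch.nn.parameter.Parameter(t), whose constructor defaults to requires_grad=True, and that default is rejected for integer dtypes before the subsequent t.requires_grad = requires_grad assignment could run. A minimal sketch of the round trip this commit fixes (check_param is an illustrative helper, not code from the PR):

import torch
import torch.multiprocessing as mp

def check_param(param):
    # Runs in the spawned child after the Parameter has been rebuilt
    # from CUDA IPC handles; before this fix, the rebuild itself raised
    # for integer dtypes, so the child exited with a non-zero code.
    assert isinstance(param, torch.nn.Parameter)
    assert param.dtype == torch.int64
    assert not param.requires_grad

if __name__ == '__main__':
    param = torch.nn.Parameter(
        torch.tensor(0, dtype=torch.int64, device='cuda'),
        requires_grad=False,
    )
    ctx = mp.get_context('spawn')
    p = ctx.Process(target=check_param, args=(param,))
    p.start()
    p.join()
    assert p.exitcode == 0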
test/test_multiprocessing.py
@@ -832,14 +832,31 @@ if __name__ == "__main__":
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                      don't support multiprocessing with spawn start method")
-    def test_integer_parameter_serialization(self):
-        iparam = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64), requires_grad=False)
+    def test_integer_parameter_serialization_cpu(self):
+        self._test_integer_parameter_serialization(device='cpu')
+
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_integer_parameter_serialization_cuda(self):
+        self._test_integer_parameter_serialization(device='cuda')
+
+    def _test_integer_parameter_serialization(self, device):
+        param = torch.nn.Parameter(
+            torch.tensor(0, dtype=torch.int64, device=device),
+            requires_grad=False
+        )

         ctx = mp.get_context('spawn')
-        p = ctx.Process(target=integer_parameter_serialization, args=(iparam,))
+        p = ctx.Process(target=integer_parameter_serialization, args=(param,))
         p.start()
         p.join()

+        self.assertEqual(
+            0, p.exitcode,
+            msg=f'Failed to serialize successfully for "{device}" device!'
+        )
+
     def test_empty_shared(self):
         t = torch.tensor([])
         t.share_memory_()
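A note on the test's structure: the spawn start method requires the Process target to be a picklable, module-level function, so integer_parameter_serialization is defined outside the TestCase and its body is not part of this hunk. A sketch of what such a target plausibly checks, assuming it only validates the rebuilt Parameter in the child:

import torch

def integer_parameter_serialization(param):
    # Executed in the spawned child. The Parameter has already been
    # rebuilt from shared memory (CPU) or CUDA IPC (GPU) during
    # unpickling, so any failure surfaces as a non-zero exitcode
    # that the parent asserts on.
    assert isinstance(param, torch.nn.Parameter)
    assert param.dtype == torch.int64
    assert not param.requires_grad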
torch/multiprocessing/reductions.py
@@ -123,9 +123,14 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
         storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)

     t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)

     if tensor_cls == torch.nn.parameter.Parameter:
-        t = torch.nn.parameter.Parameter(t)
-    t.requires_grad = requires_grad
+        # It is crucial for integer tensors to receive
+        # the requires_grad=False as an argument in the constructor
+        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
+    else:
+        t.requires_grad = requires_grad

     return t
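Why the constructor argument matters: nn.Parameter defaults to requires_grad=True, and only floating-point and complex tensors may require gradients, so wrapping first and assigning the flag afterwards never gets past the wrap for integer data. A standalone illustration (not from the PR):

import torch

t = torch.tensor(0, dtype=torch.int64)

try:
    # Default requires_grad=True: raises RuntimeError, since only
    # floating point and complex dtypes can require gradients.
    p = torch.nn.Parameter(t)
except RuntimeError as e:
    print(e)

# Passing the flag through the constructor, as the fixed rebuild does,
# works for any dtype.
p = torch.nn.Parameter(t, requires_grad=False)
assert not p.requires_grad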