Fix mp serialization for integer nn.Parameter on CUDA (#56529)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/56342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56529

Reviewed By: albanD

Differential Revision: D27896094

Pulled By: ngimel

fbshipit-source-id: fe817781eb7139ea57c78acfd56e7c11b61eb4ed
This commit is contained in:
Vasiliy Alekseev
2021-04-22 16:15:41 -07:00
committed by Facebook GitHub Bot
parent febff45900
commit bac4cfd54d
2 changed files with 27 additions and 5 deletions

View File

@ -832,14 +832,31 @@ if __name__ == "__main__":
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
def test_integer_parameter_serialization(self):
iparam = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64), requires_grad=False)
def test_integer_parameter_serialization_cpu(self):
self._test_integer_parameter_serialization(device='cpu')
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
def test_integer_parameter_serialization_cuda(self):
self._test_integer_parameter_serialization(device='cuda')
def _test_integer_parameter_serialization(self, device):
param = torch.nn.Parameter(
torch.tensor(0, dtype=torch.int64, device=device),
requires_grad=False
)
ctx = mp.get_context('spawn')
p = ctx.Process(target=integer_parameter_serialization, args=(iparam,))
p = ctx.Process(target=integer_parameter_serialization, args=(param,))
p.start()
p.join()
self.assertEqual(
0, p.exitcode,
msg=f'Failed to serialize successfully for "{device}" device!'
)
def test_empty_shared(self):
t = torch.tensor([])
t.share_memory_()

View File

@ -123,9 +123,14 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)
t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)
if tensor_cls == torch.nn.parameter.Parameter:
t = torch.nn.parameter.Parameter(t)
# It is crucial for integer tensors to receive
# the requires_grad=False as an argument in the constructor
t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
else:
t.requires_grad = requires_grad
return t