Fix mp serialization for integer nn.Parameter on CUDA (#56529)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/56342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56529

Reviewed By: albanD

Differential Revision: D27896094

Pulled By: ngimel

fbshipit-source-id: fe817781eb7139ea57c78acfd56e7c11b61eb4ed
This commit is contained in:
Vasiliy Alekseev
2021-04-22 16:15:41 -07:00
committed by Facebook GitHub Bot
parent febff45900
commit bac4cfd54d
2 changed files with 27 additions and 5 deletions

View File

@ -832,14 +832,31 @@ if __name__ == "__main__":
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
def test_integer_parameter_serialization(self):
iparam = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64), requires_grad=False)
def test_integer_parameter_serialization_cpu(self):
self._test_integer_parameter_serialization(device='cpu')
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
def test_integer_parameter_serialization_cuda(self):
self._test_integer_parameter_serialization(device='cuda')
def _test_integer_parameter_serialization(self, device):
param = torch.nn.Parameter(
torch.tensor(0, dtype=torch.int64, device=device),
requires_grad=False
)
ctx = mp.get_context('spawn')
p = ctx.Process(target=integer_parameter_serialization, args=(iparam,))
p = ctx.Process(target=integer_parameter_serialization, args=(param,))
p.start()
p.join()
self.assertEqual(
0, p.exitcode,
msg=f'Failed to serialize successfully for "{device}" device!'
)
def test_empty_shared(self):
t = torch.tensor([])
t.share_memory_()