mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix mp serialization for integer nn.Parameter on CUDA (#56529)
Summary: Fixes https://github.com/pytorch/pytorch/issues/56342 Pull Request resolved: https://github.com/pytorch/pytorch/pull/56529 Reviewed By: albanD Differential Revision: D27896094 Pulled By: ngimel fbshipit-source-id: fe817781eb7139ea57c78acfd56e7c11b61eb4ed
This commit is contained in:
committed by
Facebook GitHub Bot
parent
febff45900
commit
bac4cfd54d
@ -832,14 +832,31 @@ if __name__ == "__main__":
|
||||
|
||||
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
|
||||
don't support multiprocessing with spawn start method")
|
||||
def test_integer_parameter_serialization(self):
|
||||
iparam = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64), requires_grad=False)
|
||||
def test_integer_parameter_serialization_cpu(self):
|
||||
self._test_integer_parameter_serialization(device='cpu')
|
||||
|
||||
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
|
||||
don't support multiprocessing with spawn start method")
|
||||
@unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
|
||||
def test_integer_parameter_serialization_cuda(self):
|
||||
self._test_integer_parameter_serialization(device='cuda')
|
||||
|
||||
def _test_integer_parameter_serialization(self, device):
|
||||
param = torch.nn.Parameter(
|
||||
torch.tensor(0, dtype=torch.int64, device=device),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
ctx = mp.get_context('spawn')
|
||||
p = ctx.Process(target=integer_parameter_serialization, args=(iparam,))
|
||||
p = ctx.Process(target=integer_parameter_serialization, args=(param,))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
self.assertEqual(
|
||||
0, p.exitcode,
|
||||
msg=f'Failed to serialize successfully for "{device}" device!'
|
||||
)
|
||||
|
||||
def test_empty_shared(self):
|
||||
t = torch.tensor([])
|
||||
t.share_memory_()
|
||||
|
Reference in New Issue
Block a user