Fix mp serialization for integer nn.Parameter on CUDA (#56529)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/56342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56529

Reviewed By: albanD

Differential Revision: D27896094

Pulled By: ngimel

fbshipit-source-id: fe817781eb7139ea57c78acfd56e7c11b61eb4ed
Committed by: Facebook GitHub Bot
Parent: febff45900
Commit: bac4cfd54d
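For context beyond the commit message: issue #56342 reports that an integer-dtype nn.Parameter on a CUDA device could not be sent to another process. The underlying constraint, which the reducer change below works around, is that torch.nn.Parameter defaults to requires_grad=True and integer tensors are not allowed to require gradients, so rebuilding such a parameter in the receiving process raised before requires_grad could be reset. A minimal sketch of that constraint (illustrative only; the exact error text may vary across PyTorch versions):

import torch

t = torch.tensor(0, dtype=torch.int64)

# Default construction implies requires_grad=True, which is illegal for
# integer dtypes, so this raises a RuntimeError along the lines of
# "Only Tensors of floating point (and complex) dtype can require gradients".
try:
    torch.nn.Parameter(t)
except RuntimeError as err:
    print('constructor rejected int64 parameter:', err)

# Passing requires_grad=False to the constructor works, which is exactly
# what the fixed rebuild path does.
ok = torch.nn.Parameter(t, requires_grad=False)
print(ok.dtype, ok.requires_grad)  # torch.int64 False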
@@ -832,14 +832,31 @@ if __name__ == "__main__":
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                      don't support multiprocessing with spawn start method")
-    def test_integer_parameter_serialization(self):
-        iparam = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64), requires_grad=False)
+    def test_integer_parameter_serialization_cpu(self):
+        self._test_integer_parameter_serialization(device='cpu')
+
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
+    def test_integer_parameter_serialization_cuda(self):
+        self._test_integer_parameter_serialization(device='cuda')
+
+    def _test_integer_parameter_serialization(self, device):
+        param = torch.nn.Parameter(
+            torch.tensor(0, dtype=torch.int64, device=device),
+            requires_grad=False
+        )
 
         ctx = mp.get_context('spawn')
-        p = ctx.Process(target=integer_parameter_serialization, args=(iparam,))
+        p = ctx.Process(target=integer_parameter_serialization, args=(param,))
         p.start()
         p.join()
 
         self.assertEqual(
             0, p.exitcode,
+            msg=f'Failed to serialize successfully for "{device}" device!'
         )
 
     def test_empty_shared(self):
         t = torch.tensor([])
         t.share_memory_()
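The rewritten tests hand the parameter to integer_parameter_serialization, a module-level helper defined elsewhere in test/test_multiprocessing.py and not shown in this hunk. A rough, self-contained sketch of what the CPU variant exercises (the worker body here is an assumption, not the test suite's actual helper):

import torch
import torch.multiprocessing as mp

def integer_parameter_serialization(param):
    # If pickling or unpickling of the Parameter had failed, the child would
    # have crashed before reaching this line and exited with a non-zero code.
    param + 1

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    param = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64),
                               requires_grad=False)
    p = ctx.Process(target=integer_parameter_serialization, args=(param,))
    p.start()
    p.join()
    print('child exitcode:', p.exitcode)  # 0 on success, mirroring the assertEqual above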
@@ -123,9 +123,14 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
             storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)
 
     t = torch._utils._rebuild_tensor(storage, tensor_offset, tensor_size, tensor_stride)
 
     if tensor_cls == torch.nn.parameter.Parameter:
-        t = torch.nn.parameter.Parameter(t)
-    t.requires_grad = requires_grad
+        # It is crucial for integer tensors to receive
+        # the requires_grad=False as an argument in the constructor
+        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
+    else:
+        t.requires_grad = requires_grad
 
     return t
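rebuild_cuda_tensor lives in torch/multiprocessing/reductions.py and runs in the receiving process whenever a shared CUDA tensor is unpickled. With this change, a Parameter is reconstructed with requires_grad forwarded to the constructor, so integer parameters no longer trip the requires_grad check, while plain tensors keep the old attribute-assignment path. A hedged sketch of a user-level round trip that goes through this code path (it needs a CUDA device and a platform with CUDA IPC support, e.g. Linux; the consumer function is illustrative):

import torch
import torch.multiprocessing as mp

def consumer(q):
    # q.get() unpickles the shared CUDA tensor in this process, which invokes
    # rebuild_cuda_tensor; with the fix, an int64 Parameter is rebuilt with
    # requires_grad=False instead of raising in the Parameter constructor.
    p = q.get()
    assert isinstance(p, torch.nn.Parameter)
    assert p.dtype == torch.int64 and not p.requires_grad

if __name__ == '__main__':
    if torch.cuda.is_available():
        ctx = mp.get_context('spawn')
        q = ctx.Queue()
        param = torch.nn.Parameter(
            torch.tensor(0, dtype=torch.int64, device='cuda'),
            requires_grad=False)
        w = ctx.Process(target=consumer, args=(q,))
        w.start()
        q.put(param)  # keep param alive in the producer until the child has used it
        w.join()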