mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Add device argument to torch.random.get_rng_state (#163034)"
This reverts commit 9580539e2f73d68e89544c713ff460bea3038701. Reverted https://github.com/pytorch/pytorch/pull/163034 on behalf of https://github.com/cyyever due to It cased partially initialised torch module ([comment](https://github.com/pytorch/pytorch/pull/163034#issuecomment-3368349209))
This commit is contained in:
@ -560,29 +560,6 @@ class TestCheckpoint(TestCase):
|
|||||||
self.assertTrue("first device type: cuda" in warning_msg)
|
self.assertTrue("first device type: cuda" in warning_msg)
|
||||||
|
|
||||||
|
|
||||||
class TestRNGState(TestCase):
|
|
||||||
def test_get_rng_state_cpu(self):
|
|
||||||
# get state from get_rng_state with no arguments
|
|
||||||
state_no_args = torch.get_rng_state()
|
|
||||||
|
|
||||||
# get state from get_rng_state with device passed in
|
|
||||||
device = torch.device("cpu")
|
|
||||||
state_device_arg = torch.get_rng_state(device)
|
|
||||||
|
|
||||||
self.assertEqual(state_no_args, state_device_arg)
|
|
||||||
|
|
||||||
@unittest.skipIf(not HAS_CUDA, "No CUDA")
|
|
||||||
def test_get_rng_state_cuda(self):
|
|
||||||
# get state from cuda's get_rng_state
|
|
||||||
state_no_args = torch.cuda.get_rng_state()
|
|
||||||
|
|
||||||
# get state from get_rng_state with device passed in
|
|
||||||
device = torch.device("cuda")
|
|
||||||
state_device_arg = torch.get_rng_state(device)
|
|
||||||
|
|
||||||
self.assertEqual(state_no_args, state_device_arg)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataLoaderUtils(TestCase):
|
class TestDataLoaderUtils(TestCase):
|
||||||
MAX_TIMEOUT_IN_SECOND = 300
|
MAX_TIMEOUT_IN_SECOND = 300
|
||||||
|
|
||||||
|
@ -19,20 +19,14 @@ def set_rng_state(new_state: torch.Tensor) -> None:
|
|||||||
default_generator.set_state(new_state)
|
default_generator.set_state(new_state)
|
||||||
|
|
||||||
|
|
||||||
def get_rng_state(device: torch.device = torch.device("cpu")) -> torch.Tensor:
|
def get_rng_state() -> torch.Tensor:
|
||||||
r"""Returns the random number generator state as a `torch.ByteTensor`.
|
r"""Returns the random number generator state as a `torch.ByteTensor`.
|
||||||
|
|
||||||
Args:
|
.. note:: The returned state is for the default generator on CPU only.
|
||||||
device (torch.device, optional): The device to return the RNG state of.
|
|
||||||
Default: ``torch.device('cpu')``, the current CPU device.
|
|
||||||
|
|
||||||
See also: :func:`torch.random.fork_rng`.
|
See also: :func:`torch.random.fork_rng`.
|
||||||
"""
|
"""
|
||||||
|
return default_generator.get_state()
|
||||||
if device.type == "cpu":
|
|
||||||
return default_generator.get_state()
|
|
||||||
else:
|
|
||||||
return torch.get_device_module(device).get_rng_state(device)
|
|
||||||
|
|
||||||
|
|
||||||
def manual_seed(seed) -> torch._C.Generator:
|
def manual_seed(seed) -> torch._C.Generator:
|
||||||
|
Reference in New Issue
Block a user