mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "Add device argument to torch.random.get_rng_state (#163034)"
This reverts commit 9580539e2f73d68e89544c713ff460bea3038701. Reverted https://github.com/pytorch/pytorch/pull/163034 on behalf of https://github.com/cyyever due to It cased partially initialised torch module ([comment](https://github.com/pytorch/pytorch/pull/163034#issuecomment-3368349209))
This commit is contained in:
@ -560,29 +560,6 @@ class TestCheckpoint(TestCase):
|
||||
self.assertTrue("first device type: cuda" in warning_msg)
|
||||
|
||||
|
||||
class TestRNGState(TestCase):
|
||||
def test_get_rng_state_cpu(self):
|
||||
# get state from get_rng_state with no arguments
|
||||
state_no_args = torch.get_rng_state()
|
||||
|
||||
# get state from get_rng_state with device passed in
|
||||
device = torch.device("cpu")
|
||||
state_device_arg = torch.get_rng_state(device)
|
||||
|
||||
self.assertEqual(state_no_args, state_device_arg)
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, "No CUDA")
|
||||
def test_get_rng_state_cuda(self):
|
||||
# get state from cuda's get_rng_state
|
||||
state_no_args = torch.cuda.get_rng_state()
|
||||
|
||||
# get state from get_rng_state with device passed in
|
||||
device = torch.device("cuda")
|
||||
state_device_arg = torch.get_rng_state(device)
|
||||
|
||||
self.assertEqual(state_no_args, state_device_arg)
|
||||
|
||||
|
||||
class TestDataLoaderUtils(TestCase):
|
||||
MAX_TIMEOUT_IN_SECOND = 300
|
||||
|
||||
|
@ -19,20 +19,14 @@ def set_rng_state(new_state: torch.Tensor) -> None:
|
||||
default_generator.set_state(new_state)
|
||||
|
||||
|
||||
def get_rng_state(device: torch.device = torch.device("cpu")) -> torch.Tensor:
|
||||
def get_rng_state() -> torch.Tensor:
|
||||
r"""Returns the random number generator state as a `torch.ByteTensor`.
|
||||
|
||||
Args:
|
||||
device (torch.device, optional): The device to return the RNG state of.
|
||||
Default: ``torch.device('cpu')``, the current CPU device.
|
||||
.. note:: The returned state is for the default generator on CPU only.
|
||||
|
||||
See also: :func:`torch.random.fork_rng`.
|
||||
"""
|
||||
|
||||
if device.type == "cpu":
|
||||
return default_generator.get_state()
|
||||
else:
|
||||
return torch.get_device_module(device).get_rng_state(device)
|
||||
|
||||
|
||||
def manual_seed(seed) -> torch._C.Generator:
|
||||
|
Reference in New Issue
Block a user