Revert "Add device argument to torch.random.get_rng_state (#163034)"

This reverts commit 9580539e2f73d68e89544c713ff460bea3038701.

Reverted https://github.com/pytorch/pytorch/pull/163034 on behalf of https://github.com/cyyever due to It cased partially initialised torch module ([comment](https://github.com/pytorch/pytorch/pull/163034#issuecomment-3368349209))
This commit is contained in:
PyTorch MergeBot
2025-10-04 15:25:45 +00:00
parent f7ad6dbad6
commit 2e1742dd63
2 changed files with 3 additions and 32 deletions

View File

@ -560,29 +560,6 @@ class TestCheckpoint(TestCase):
self.assertTrue("first device type: cuda" in warning_msg) self.assertTrue("first device type: cuda" in warning_msg)
class TestRNGState(TestCase):
def test_get_rng_state_cpu(self):
# get state from get_rng_state with no arguments
state_no_args = torch.get_rng_state()
# get state from get_rng_state with device passed in
device = torch.device("cpu")
state_device_arg = torch.get_rng_state(device)
self.assertEqual(state_no_args, state_device_arg)
@unittest.skipIf(not HAS_CUDA, "No CUDA")
def test_get_rng_state_cuda(self):
# get state from cuda's get_rng_state
state_no_args = torch.cuda.get_rng_state()
# get state from get_rng_state with device passed in
device = torch.device("cuda")
state_device_arg = torch.get_rng_state(device)
self.assertEqual(state_no_args, state_device_arg)
class TestDataLoaderUtils(TestCase): class TestDataLoaderUtils(TestCase):
MAX_TIMEOUT_IN_SECOND = 300 MAX_TIMEOUT_IN_SECOND = 300

View File

@ -19,20 +19,14 @@ def set_rng_state(new_state: torch.Tensor) -> None:
default_generator.set_state(new_state) default_generator.set_state(new_state)
def get_rng_state(device: torch.device = torch.device("cpu")) -> torch.Tensor: def get_rng_state() -> torch.Tensor:
r"""Returns the random number generator state as a `torch.ByteTensor`. r"""Returns the random number generator state as a `torch.ByteTensor`.
Args: .. note:: The returned state is for the default generator on CPU only.
device (torch.device, optional): The device to return the RNG state of.
Default: ``torch.device('cpu')``, the current CPU device.
See also: :func:`torch.random.fork_rng`. See also: :func:`torch.random.fork_rng`.
""" """
if device.type == "cpu":
return default_generator.get_state() return default_generator.get_state()
else:
return torch.get_device_module(device).get_rng_state(device)
def manual_seed(seed) -> torch._C.Generator: def manual_seed(seed) -> torch._C.Generator: