From 2e1742dd63c2168fd9649dbba96a95abf1f57cae Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 4 Oct 2025 15:25:45 +0000 Subject: [PATCH] Revert "Add device argument to torch.random.get_rng_state (#163034)" This reverts commit 9580539e2f73d68e89544c713ff460bea3038701. Reverted https://github.com/pytorch/pytorch/pull/163034 on behalf of https://github.com/cyyever due to It cased partially initialised torch module ([comment](https://github.com/pytorch/pytorch/pull/163034#issuecomment-3368349209)) --- test/test_utils.py | 23 ----------------------- torch/random.py | 12 +++--------- 2 files changed, 3 insertions(+), 32 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index c9a485179b76..7c9e4c1d334f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -560,29 +560,6 @@ class TestCheckpoint(TestCase): self.assertTrue("first device type: cuda" in warning_msg) -class TestRNGState(TestCase): - def test_get_rng_state_cpu(self): - # get state from get_rng_state with no arguments - state_no_args = torch.get_rng_state() - - # get state from get_rng_state with device passed in - device = torch.device("cpu") - state_device_arg = torch.get_rng_state(device) - - self.assertEqual(state_no_args, state_device_arg) - - @unittest.skipIf(not HAS_CUDA, "No CUDA") - def test_get_rng_state_cuda(self): - # get state from cuda's get_rng_state - state_no_args = torch.cuda.get_rng_state() - - # get state from get_rng_state with device passed in - device = torch.device("cuda") - state_device_arg = torch.get_rng_state(device) - - self.assertEqual(state_no_args, state_device_arg) - - class TestDataLoaderUtils(TestCase): MAX_TIMEOUT_IN_SECOND = 300 diff --git a/torch/random.py b/torch/random.py index d221900ccaff..18a1cf9a5d5d 100644 --- a/torch/random.py +++ b/torch/random.py @@ -19,20 +19,14 @@ def set_rng_state(new_state: torch.Tensor) -> None: default_generator.set_state(new_state) -def get_rng_state(device: torch.device = torch.device("cpu")) -> torch.Tensor: +def get_rng_state() -> torch.Tensor: r"""Returns the random number generator state as a `torch.ByteTensor`. - Args: - device (torch.device, optional): The device to return the RNG state of. - Default: ``torch.device('cpu')``, the current CPU device. + .. note:: The returned state is for the default generator on CPU only. See also: :func:`torch.random.fork_rng`. """ - - if device.type == "cpu": - return default_generator.get_state() - else: - return torch.get_device_module(device).get_rng_state(device) + return default_generator.get_state() def manual_seed(seed) -> torch._C.Generator: