[dynamo] Fix TorchFunctionMode handling with get_rng_state (#163412)

Fixes #162624
Fixes #162586

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163412
Approved by: https://github.com/eellison
ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419, #163434, #163393
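
A minimal sketch of the failure mode, adapted from the regression test added below (the mode class name is illustrative): dynamo's preserve_global_state wrapper snapshots and restores the CUDA RNG state around compilation, and with a TorchFunctionMode active those torch.cuda.get_rng_state()/set_rng_state() calls were themselves routed through the mode's __torch_function__, so the snapshot could receive the mode's return value instead of the real state tensor:

import torch
from torch.overrides import TorchFunctionMode

class ConstantReturnMode(TorchFunctionMode):
    # Intercepts every torch-level call and returns a constant.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return -42

@torch._dynamo.optimize("eager")
def fn():
    with ConstantReturnMode():
        return 123

# Before this fix, compiling fn on a CUDA machine broke because the RNG
# snapshot ran under the active mode; with the fix, this prints 123.
print(fn())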
Authored by Jason Ansel on 2025-09-22 21:41:16 -07:00
Committed by PyTorch MergeBot
Parent: 9c4d9f940b
Commit: 6ef74879f6
2 changed files with 22 additions and 3 deletions


@@ -12,7 +12,11 @@ from torch._C import (
     _push_on_torch_function_stack,
 )
 from torch._dynamo.utils import counters
-from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
+from torch.overrides import (
+    _get_current_function_mode_stack,
+    BaseTorchFunctionMode,
+    TorchFunctionMode,
+)
 from torch.testing._internal.common_utils import skipIfXpu
 from torch.testing._internal.inductor_utils import GPU_TYPE
 from torch.testing._internal.triton_utils import requires_gpu
@@ -190,6 +194,19 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
     def test_torch_function_mode_guards_cpp(self):
         self._run_torch_function_mode_guard_test()
 
+    @requires_gpu
+    def test_torch_function_mode_preserves_cuda_rng_state(self):
+        class ConstantReturnMode(TorchFunctionMode):
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                return -42
+
+        @torch._dynamo.optimize("eager")
+        def fn():
+            with ConstantReturnMode():
+                return 123
+
+        self.assertEqual(fn(), 123)
+
     def test_stack_state_mutation_default_device(self):
         m = BaseTorchFunctionMode()
         m1 = BaseTorchFunctionMode()


@@ -297,7 +297,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
         torch_rng_state = torch.random.get_rng_state()
         cuda_rng_state = None
         if torch.cuda.is_available():
-            cuda_rng_state = torch.cuda.get_rng_state()
+            with torch._C.DisableTorchFunction():
+                cuda_rng_state = torch.cuda.get_rng_state()
         cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter(
             "cuda", "matmul"
         )
@@ -331,7 +332,8 @@
             if prior_mobile_allocator_state != curr_mobile_allocator_state:
                 torch._C._unset_default_mobile_cpu_allocator()
             if cuda_rng_state is not None:
-                torch.cuda.set_rng_state(cuda_rng_state)
+                with torch._C.DisableTorchFunction():
+                    torch.cuda.set_rng_state(cuda_rng_state)
             torch._C._set_fp32_precision_setter(
                 "cuda", "matmul", cuda_matmul_fp32_prec
             )
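
The fix works because torch._C.DisableTorchFunction is a context manager that suppresses all __torch_function__ dispatch, including active modes, so the RNG-state calls see real tensors again. A hedged sketch of that behavior (the mode class is illustrative):

import torch
from torch.overrides import TorchFunctionMode

class ConstantReturnMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return -42

with ConstantReturnMode():
    print(torch.ones(2))      # -42: the mode intercepts the factory call
    with torch._C.DisableTorchFunction():
        print(torch.ones(2))  # tensor([1., 1.]): dispatch is disabled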