[dynamo] Fix TorchFunctionMode handling with get_rng_state (#163412)
Fixes #162624
Fixes #162586

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163412
Approved by: https://github.com/eellison
ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419, #163434, #163393
committed by: PyTorch MergeBot
parent: 9c4d9f940b
commit: 6ef74879f6
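The diff below wraps dynamo's CUDA RNG state save/restore in torch._C.DisableTorchFunction() so that an active TorchFunctionMode cannot intercept those internal calls. A minimal sketch of the interception problem and the bypass pattern, assuming no other mode is active; ConstantReturnMode here is an illustrative stand-in that mirrors the mode added in the test below, not part of the actual change:

import torch
from torch.overrides import TorchFunctionMode

# Illustrative mode: it hijacks every overridable torch call and returns -42.
class ConstantReturnMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        return -42

with ConstantReturnMode():
    # Overridable calls made while the mode is active are routed through
    # __torch_function__ and come back as -42 instead of real tensors.
    assert torch.ones(3) == -42
    # DisableTorchFunction bypasses the active mode, so internal bookkeeping
    # sees real values again.  This is the pattern the hunks below apply
    # around torch.cuda.get_rng_state() and torch.cuda.set_rng_state().
    with torch._C.DisableTorchFunction():
        assert torch.ones(3).sum() == 3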
@@ -12,7 +12,11 @@ from torch._C import (
     _push_on_torch_function_stack,
 )
 from torch._dynamo.utils import counters
-from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
+from torch.overrides import (
+    _get_current_function_mode_stack,
+    BaseTorchFunctionMode,
+    TorchFunctionMode,
+)
 from torch.testing._internal.common_utils import skipIfXpu
 from torch.testing._internal.inductor_utils import GPU_TYPE
 from torch.testing._internal.triton_utils import requires_gpu
@@ -190,6 +194,19 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
     def test_torch_function_mode_guards_cpp(self):
         self._run_torch_function_mode_guard_test()
 
+    @requires_gpu
+    def test_torch_function_mode_preserves_cuda_rng_state(self):
+        class ConstantReturnMode(TorchFunctionMode):
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                return -42
+
+        @torch._dynamo.optimize("eager")
+        def fn():
+            with ConstantReturnMode():
+                return 123
+
+        self.assertEqual(fn(), 123)
+
     def test_stack_state_mutation_default_device(self):
         m = BaseTorchFunctionMode()
         m1 = BaseTorchFunctionMode()
@@ -297,7 +297,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
         torch_rng_state = torch.random.get_rng_state()
         cuda_rng_state = None
         if torch.cuda.is_available():
-            cuda_rng_state = torch.cuda.get_rng_state()
+            with torch._C.DisableTorchFunction():
+                cuda_rng_state = torch.cuda.get_rng_state()
         cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter(
             "cuda", "matmul"
         )
@@ -331,7 +332,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
             if prior_mobile_allocator_state != curr_mobile_allocator_state:
                 torch._C._unset_default_mobile_cpu_allocator()
             if cuda_rng_state is not None:
-                torch.cuda.set_rng_state(cuda_rng_state)
+                with torch._C.DisableTorchFunction():
+                    torch.cuda.set_rng_state(cuda_rng_state)
             torch._C._set_fp32_precision_setter(
                 "cuda", "matmul", cuda_matmul_fp32_prec
             )
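Taken together, the two preserve_global_state hunks read the CUDA RNG state before running the wrapped function and restore it afterwards, with both calls shielded from any active TorchFunctionMode. A simplified sketch of that flow, limited to the RNG-state handling this commit touches; preserve_cuda_rng_state is a hypothetical name and this is not the real torch._dynamo helper:

import functools
import torch

def preserve_cuda_rng_state(fn):
    # Simplified stand-in for dynamo's preserve_global_state decorator.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        torch_rng_state = torch.random.get_rng_state()
        cuda_rng_state = None
        if torch.cuda.is_available():
            # Read the real state even if a TorchFunctionMode is active.
            with torch._C.DisableTorchFunction():
                cuda_rng_state = torch.cuda.get_rng_state()
        try:
            return fn(*args, **kwargs)
        finally:
            torch.random.set_rng_state(torch_rng_state)
            if cuda_rng_state is not None:
                # Restore under the same bypass so the mode cannot intercept
                # the saved state tensor.
                with torch._C.DisableTorchFunction():
                    torch.cuda.set_rng_state(cuda_rng_state)
    return wrapper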