Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[dynamo] Fix TorchFunctionMode handling with get_rng_state (#163412)
Fixes #162624
Fixes #162586

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163412
Approved by: https://github.com/eellison
ghstack dependencies: #163386, #163398, #163387, #163414, #163415, #163419, #163434, #163393
commit 6ef74879f6, parent 9c4d9f940b, committed by PyTorch MergeBot
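Background for the fix: while a TorchFunctionMode is active, every overridable call into the torch API is routed through its __torch_function__ handler, including the calls dynamo makes internally to snapshot RNG state. A minimal sketch of that interception and of the torch._C.DisableTorchFunction escape hatch the fix relies on (the mode mirrors the one in the new test below; the snippet itself is illustrative, not part of this PR):

    import torch
    from torch.overrides import TorchFunctionMode

    class ConstantReturnMode(TorchFunctionMode):
        # While the mode is active, every overridable torch.* call
        # lands here and yields -42 instead of running the real op.
        def __torch_function__(self, func, types, args=(), kwargs=None):
            return -42

    x = torch.ones(2)
    with ConstantReturnMode():
        print(torch.add(x, x))      # -42: the mode intercepted the call
        with torch._C.DisableTorchFunction():
            print(torch.add(x, x))  # tensor([2., 2.]): mode stack bypassed

Dynamo's own bookkeeping must see the real RNG state tensor, not whatever a user mode returns, which is what the two preserve_global_state hunks below enforce.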
@@ -12,7 +12,11 @@ from torch._C import (
     _push_on_torch_function_stack,
 )
 from torch._dynamo.utils import counters
-from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
+from torch.overrides import (
+    _get_current_function_mode_stack,
+    BaseTorchFunctionMode,
+    TorchFunctionMode,
+)
 from torch.testing._internal.common_utils import skipIfXpu
 from torch.testing._internal.inductor_utils import GPU_TYPE
 from torch.testing._internal.triton_utils import requires_gpu
@@ -190,6 +194,19 @@ class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
     def test_torch_function_mode_guards_cpp(self):
         self._run_torch_function_mode_guard_test()
 
+    @requires_gpu
+    def test_torch_function_mode_preserves_cuda_rng_state(self):
+        class ConstantReturnMode(TorchFunctionMode):
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                return -42
+
+        @torch._dynamo.optimize("eager")
+        def fn():
+            with ConstantReturnMode():
+                return 123
+
+        self.assertEqual(fn(), 123)
+
     def test_stack_state_mutation_default_device(self):
         m = BaseTorchFunctionMode()
         m1 = BaseTorchFunctionMode()
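Guarded by @requires_gpu, the new test compiles a function that enters a mode hijacking every torch call and asserts that the real return value (123) comes back; per the linked issues, dynamo's RNG-state snapshotting previously reached torch.cuda.get_rng_state through such a mode and received -42 instead of a state tensor.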
@@ -297,7 +297,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     torch_rng_state = torch.random.get_rng_state()
     cuda_rng_state = None
     if torch.cuda.is_available():
-        cuda_rng_state = torch.cuda.get_rng_state()
+        with torch._C.DisableTorchFunction():
+            cuda_rng_state = torch.cuda.get_rng_state()
     cuda_matmul_fp32_prec = torch._C._get_fp32_precision_getter(
         "cuda", "matmul"
     )
@@ -331,7 +332,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     if prior_mobile_allocator_state != curr_mobile_allocator_state:
         torch._C._unset_default_mobile_cpu_allocator()
     if cuda_rng_state is not None:
-        torch.cuda.set_rng_state(cuda_rng_state)
+        with torch._C.DisableTorchFunction():
+            torch.cuda.set_rng_state(cuda_rng_state)
     torch._C._set_fp32_precision_setter(
         "cuda", "matmul", cuda_matmul_fp32_prec
     )
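Taken together, the two hunks wrap both the save and the restore side of the CUDA RNG snapshot. Reduced to a standalone sketch (hypothetical decorator name; the real preserve_global_state also snapshots the torch RNG state, fp32 matmul precision, and allocator state, as the surrounding context shows):

    import functools
    import torch

    def preserve_cuda_rng_state(fn):
        # Save/restore pattern from this commit: the RNG state is read and
        # written with torch function modes disabled, so a user-defined
        # TorchFunctionMode cannot intercept and corrupt the snapshot.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            cuda_rng_state = None
            if torch.cuda.is_available():
                with torch._C.DisableTorchFunction():
                    cuda_rng_state = torch.cuda.get_rng_state()
            try:
                return fn(*args, **kwargs)
            finally:
                if cuda_rng_state is not None:
                    with torch._C.DisableTorchFunction():
                        torch.cuda.set_rng_state(cuda_rng_state)
        return wrapper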