Switch most Python RAII guard usages to context manager (#102642)

There are some usages I can't easily switch, for reasons like:
- Dynamo modelling the guard
- BC concerns (for torch.autograd.set_multithreading_enabled)

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102642
Approved by: https://github.com/albanD
Author: Richard Zou
Date: 2023-05-31 14:02:16 -07:00
Committed by: PyTorch MergeBot
Parent: dcf0c5fb6e
Commit: 74f10b9ea5

9 changed files with 33 additions and 73 deletions
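
For context, the pattern being replaced across these files looks like the following. A minimal before/after sketch (torch._C._DisableAutocast is one of the guard bindings actually touched by this commit; the surrounding tensor code is illustrative):

    import torch

    x = torch.rand(4, 4)

    # Before: the C++ RAII guard is bound to a Python local and its
    # lifetime is ended by hand with `del`, wrapped in try/finally
    # for exception safety.
    guard = torch._C._DisableAutocast()
    try:
        y = x @ x
    finally:
        del guard

    # After: the same C++ binding implements __enter__/__exit__, so the
    # guarded region is expressed directly as a context manager.
    with torch._C._DisableAutocast():
        y = x @ x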


@@ -1744,13 +1744,10 @@ def forward(self, tangents_1):
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_autocast_disable_guard(self):
-        guard = torch._C._DisableAutocast()
-        try:
+        with torch._C._DisableAutocast():
             x = torch.rand([4, 4]).cuda()
             y = x @ x
             self.assertEqual(y.dtype, torch.float32)
-        finally:
-            del guard
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
     def test_nonidempotent_amp(self):


@@ -64,11 +64,8 @@ def construct_autograd_kernel(
         def forward(ctx, *flat_args):
             ctx.set_materialize_grads(True)
             args = pytree.tree_unflatten(list(flat_args), spec)
-            guard = torch._C._AutoDispatchBelowAutograd()
-            try:
+            with torch._C._AutoDispatchBelowAutograd():
                 output = forward_op(*args)
-            finally:
-                del guard
 
             # We use the info about args to give better error messages in backward
             args_info = namedtuple_args(


@@ -9,21 +9,8 @@ import torch._ops
 
 __all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
 
-@contextmanager
-def no_python_dispatcher():
-    g = torch._C._DisablePythonDispatcher()
-    try:
-        yield
-    finally:
-        del g
-
-@contextmanager
-def enable_python_dispatcher():
-    g = torch._C._EnablePythonDispatcher()
-    try:
-        yield
-    finally:
-        del g
+no_python_dispatcher = torch._C._DisablePythonDispatcher
+enable_python_dispatcher = torch._C._EnablePythonDispatcher
 
 CROSSREF_FUNCTIONALIZE = False


@@ -54,11 +54,8 @@ class ConstPropPass(ExportPassBase):
             (not op_is_q_dq and not self.propogate_quant)
             or (op_is_q_dq and self.propogate_quant)
         ) and is_const([args, kwargs]):
-            guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
-            try:
+            with torch._C._DisableTorchDispatch():
                 result = op(*args, **kwargs)
-            finally:
-                del guard
             return result.to_tensor() if isinstance(result, ProxyValue) else result
         else:
             return super().call_operator(op, args, kwargs, meta)


@@ -1457,9 +1457,8 @@ def call_func_with_args(f, args, steal_args=False, disable_amp=False):
         args = list(args)
     assert isinstance(args, list)
 
-    if disable_amp:
-        guard = torch._C._DisableAutocast()
-    try:
+    context = torch._C._DisableAutocast if disable_amp else nullcontext
+    with context():
         if hasattr(f, "_boxed_call"):
             out = normalize_as_list(f(args))
         else:
@@ -1471,9 +1470,6 @@ def call_func_with_args(f, args, steal_args=False, disable_amp=False):
                 "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
             )
             out = normalize_as_list(f(*args))
-    finally:
-        if disable_amp:
-            del guard
     return out
 
 
 def aot_dispatch_base_graph(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
@@ -1519,7 +1515,7 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
     fw_module = aot_dispatch_base_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
 
     disable_amp = torch._C._is_any_autocast_enabled()
-    context = disable_autocast_manager if disable_amp else nullcontext
+    context = torch._C._DisableAutocast if disable_amp else nullcontext
 
     with context(), track_graph_compiling(aot_config, "inference"):
         compiler = aot_config.inference_compiler if aot_config.inference_compiler is not None else aot_config.fw_compiler
@@ -1586,15 +1582,6 @@ def assert_functional_graph(fx_g: torch.fx.Graph, *, allow_input_mutations: bool
     return copy_count
 
 
-@contextmanager
-def disable_autocast_manager():
-    guard = torch._C._DisableAutocast()
-    try:
-        yield
-    finally:
-        del guard
-
-
 def are_differentiable_views(view1, view2):
     if view1 is view2:
         return True
@@ -2993,7 +2980,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
         def call_compiled_backward():
             if CompiledFunction.compiled_bw is None:
                 assert all(a is not None for a in all_args)
-                context = disable_autocast_manager if disable_amp else nullcontext
+                context = torch._C._DisableAutocast if disable_amp else nullcontext
                 with tracing(saved_context), context(), track_graph_compiling(aot_config, "backward"):
                     CompiledFunction.compiled_bw = aot_config.bw_compiler(
                         bw_module, fx_placeholder_vals(bw_module)


@@ -216,16 +216,16 @@ class inference_mode(_DecoratorContextManager):
     def __init__(self, mode: bool = True) -> None:
         if not torch._jit_internal.is_scripting():
             super().__init__()
-        # Holds a python binding to a RAII guard that can enable or disable
-        # inference mode
-        self._inference_mode_raii_guard: Optional[torch._C._InferenceMode] = None
+        # Holds a context manager that can enable or disable inference mode
+        self._inference_mode_context: Optional[torch._C._InferenceMode] = None
         self.mode = mode
 
     def __enter__(self) -> None:
-        self._inference_mode_raii_guard = torch._C._InferenceMode(self.mode)
+        self._inference_mode_context = torch._C._InferenceMode(self.mode)
+        self._inference_mode_context.__enter__()
 
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        del self._inference_mode_raii_guard
+        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
 
     def clone(self) -> "inference_mode":
         return self.__class__(self.mode)
@@ -287,13 +287,13 @@ class _force_original_view_tracking(_DecoratorContextManager):
     def __init__(self, mode: bool) -> None:
         self.mode = mode
-        self._force_original_view_tracking_guard = torch._C._ViewReplayEnabled(mode)
 
     def __enter__(self) -> None:
-        pass
+        self._force_original_view_tracking_context = torch._C._ViewReplayEnabled(self.mode)
+        self._force_original_view_tracking_context.__enter__()
 
     def __exit__(self, *args) -> None:
-        del self._force_original_view_tracking_guard
+        self._force_original_view_tracking_context.__exit__(*args)
 
     def clone(self):
         return self.__class__(self.mode)
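
Since __enter__/__exit__ now delegate to the underlying torch._C._InferenceMode object, the public behavior of torch.inference_mode is unchanged. A quick sanity check of the public API (not part of the diff):

    import torch

    x = torch.ones(2, 2, requires_grad=True)
    with torch.inference_mode():
        y = x * 2  # runs without autograd tracking
    assert not y.requires_grad and y.is_inference()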


@@ -1876,12 +1876,20 @@ class BaseTorchFunctionMode(TorchFunctionMode):
         return func(*args, **kwargs)
 
 
-class enable_reentrant_dispatch():
-    def __enter__(self):
-        self._raii_guard = torch._C._RestorePythonTLSSnapshot()
-
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        del self._raii_guard
+@contextlib.contextmanager
+def enable_reentrant_dispatch():
+    # NB: this can't simply be
+    # `enable_reentrant_dispatch = torch._C._RestorePythonTLSSnapshot`
+    # because:
+    # 1. torch._C._RestorePythonTLSSnapshot is unavailable when this file
+    #    initially gets imported. Probably an import order thing.
+    # 2. enable_reentrant_dispatch is technically public API; assigning
+    #    it the object would change the __module__ to look private.
+    with torch._C._RestorePythonTLSSnapshot():
+        try:
+            yield
+        finally:
+            pass
 
 def get_buffer(tensor_subclass, data, prefix):
     import ctypes
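
The __module__ point in the NB comment above can be reproduced with any plain aliasing; an illustrative pure-stdlib snippet, not part of the commit:

    import contextlib

    # Plain aliasing: the public name reports the module that defined the
    # object, which for a torch._C binding would look like private API.
    alias = contextlib.nullcontext
    print(alias.__module__)    # -> "contextlib"

    # A local @contextmanager wrapper keeps __module__ pointing at the
    # module that defines the wrapper instead.
    @contextlib.contextmanager
    def wrapped():
        with contextlib.nullcontext():
            yield
    print(wrapped.__module__)  # -> "__main__" when run as a script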


@@ -1483,13 +1483,7 @@ def set_rng_seed(seed):
     np.random.seed(seed)
 
 
-@contextmanager
-def disable_functorch():
-    guard = torch._C._DisableFuncTorch()  # type: ignore[attr-defined]
-    try:
-        yield
-    finally:
-        del guard
+disable_functorch = torch._C._DisableFuncTorch
 
 
 @contextlib.contextmanager


@@ -1,6 +1,5 @@
 import torch
 from typing import TypeVar
-from contextlib import contextmanager
 
 T = TypeVar('T')
@@ -8,10 +7,4 @@ T = TypeVar('T')
 def all_same_mode(modes):
     return all(tuple(mode == modes[0] for mode in modes))
 
-@contextmanager
-def no_dispatch():
-    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
-    try:
-        yield
-    finally:
-        del guard
+no_dispatch = torch._C._DisableTorchDispatch
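
With the aliasing above, callers use the binding directly as a context manager. A minimal usage sketch (the import path is an assumption, since file names are not shown in this view):

    import torch
    from torch.utils._mode_utils import no_dispatch  # assumed module path

    # no_dispatch() temporarily disables __torch_dispatch__ interposition,
    # so ops inside the block dispatch to the regular kernels directly.
    with no_dispatch():
        t = torch.add(torch.ones(2), torch.ones(2))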