mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Switch most Python RAII guard usages to context manager (#102642)
There are some I can't easily switch due to reasons like: - Dynamo modelling the guard - BC concerns (for torch.autograd.set_multithreading_enabled) Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/102642 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
dcf0c5fb6e
commit
74f10b9ea5
@ -1744,13 +1744,10 @@ def forward(self, tangents_1):
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
|
||||
def test_autocast_disable_guard(self):
|
||||
guard = torch._C._DisableAutocast()
|
||||
try:
|
||||
with torch._C._DisableAutocast():
|
||||
x = torch.rand([4, 4]).cuda()
|
||||
y = x @ x
|
||||
self.assertEqual(y.dtype, torch.float32)
|
||||
finally:
|
||||
del guard
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
|
||||
def test_nonidempotent_amp(self):
|
||||
|
@ -64,11 +64,8 @@ def construct_autograd_kernel(
|
||||
def forward(ctx, *flat_args):
|
||||
ctx.set_materialize_grads(True)
|
||||
args = pytree.tree_unflatten(list(flat_args), spec)
|
||||
guard = torch._C._AutoDispatchBelowAutograd()
|
||||
try:
|
||||
with torch._C._AutoDispatchBelowAutograd():
|
||||
output = forward_op(*args)
|
||||
finally:
|
||||
del guard
|
||||
|
||||
# We use the info about args to give better error messages in backward
|
||||
args_info = namedtuple_args(
|
||||
|
@ -9,21 +9,8 @@ import torch._ops
|
||||
|
||||
__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']
|
||||
|
||||
@contextmanager
|
||||
def no_python_dispatcher():
|
||||
g = torch._C._DisablePythonDispatcher()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del g
|
||||
|
||||
@contextmanager
|
||||
def enable_python_dispatcher():
|
||||
g = torch._C._EnablePythonDispatcher()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del g
|
||||
no_python_dispatcher = torch._C._DisablePythonDispatcher
|
||||
enable_python_dispatcher = torch._C._EnablePythonDispatcher
|
||||
|
||||
CROSSREF_FUNCTIONALIZE = False
|
||||
|
||||
|
@ -54,11 +54,8 @@ class ConstPropPass(ExportPassBase):
|
||||
(not op_is_q_dq and not self.propogate_quant)
|
||||
or (op_is_q_dq and self.propogate_quant)
|
||||
) and is_const([args, kwargs]):
|
||||
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
|
||||
try:
|
||||
with torch._C._DisableTorchDispatch():
|
||||
result = op(*args, **kwargs)
|
||||
finally:
|
||||
del guard
|
||||
return result.to_tensor() if isinstance(result, ProxyValue) else result
|
||||
else:
|
||||
return super().call_operator(op, args, kwargs, meta)
|
||||
|
@ -1457,9 +1457,8 @@ def call_func_with_args(f, args, steal_args=False, disable_amp=False):
|
||||
args = list(args)
|
||||
assert isinstance(args, list)
|
||||
|
||||
if disable_amp:
|
||||
guard = torch._C._DisableAutocast()
|
||||
try:
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
with context():
|
||||
if hasattr(f, "_boxed_call"):
|
||||
out = normalize_as_list(f(args))
|
||||
else:
|
||||
@ -1471,9 +1470,6 @@ def call_func_with_args(f, args, steal_args=False, disable_amp=False):
|
||||
"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
|
||||
)
|
||||
out = normalize_as_list(f(*args))
|
||||
finally:
|
||||
if disable_amp:
|
||||
del guard
|
||||
return out
|
||||
|
||||
def aot_dispatch_base_graph(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
|
||||
@ -1519,7 +1515,7 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
|
||||
fw_module = aot_dispatch_base_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
|
||||
|
||||
disable_amp = torch._C._is_any_autocast_enabled()
|
||||
context = disable_autocast_manager if disable_amp else nullcontext
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
|
||||
with context(), track_graph_compiling(aot_config, "inference"):
|
||||
compiler = aot_config.inference_compiler if aot_config.inference_compiler is not None else aot_config.fw_compiler
|
||||
@ -1586,15 +1582,6 @@ def assert_functional_graph(fx_g: torch.fx.Graph, *, allow_input_mutations: bool
|
||||
return copy_count
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_autocast_manager():
|
||||
guard = torch._C._DisableAutocast()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del guard
|
||||
|
||||
|
||||
def are_differentiable_views(view1, view2):
|
||||
if view1 is view2:
|
||||
return True
|
||||
@ -2993,7 +2980,7 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig,
|
||||
def call_compiled_backward():
|
||||
if CompiledFunction.compiled_bw is None:
|
||||
assert all(a is not None for a in all_args)
|
||||
context = disable_autocast_manager if disable_amp else nullcontext
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
with tracing(saved_context), context(), track_graph_compiling(aot_config, "backward"):
|
||||
CompiledFunction.compiled_bw = aot_config.bw_compiler(
|
||||
bw_module, fx_placeholder_vals(bw_module)
|
||||
|
@ -216,16 +216,16 @@ class inference_mode(_DecoratorContextManager):
|
||||
def __init__(self, mode: bool = True) -> None:
|
||||
if not torch._jit_internal.is_scripting():
|
||||
super().__init__()
|
||||
# Holds a python binding to a RAII guard that can enable or disable
|
||||
# inference mode
|
||||
self._inference_mode_raii_guard: Optional[torch._C._InferenceMode] = None
|
||||
# Holds a context manager that can enable or disable inference mode
|
||||
self._inference_mode_raii_context: Optional[torch._C._InferenceMode] = None
|
||||
self.mode = mode
|
||||
|
||||
def __enter__(self) -> None:
|
||||
self._inference_mode_raii_guard = torch._C._InferenceMode(self.mode)
|
||||
self._inference_mode_context = torch._C._InferenceMode(self.mode)
|
||||
self._inference_mode_context.__enter__()
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
del self._inference_mode_raii_guard
|
||||
self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
def clone(self) -> "inference_mode":
|
||||
return self.__class__(self.mode)
|
||||
@ -287,13 +287,13 @@ class _force_original_view_tracking(_DecoratorContextManager):
|
||||
|
||||
def __init__(self, mode: bool) -> None:
|
||||
self.mode = mode
|
||||
self._force_original_view_tracking_guard = torch._C._ViewReplayEnabled(mode)
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
self._force_original_view_tracking_context = torch._C._ViewReplayEnabled(self.mode)
|
||||
self._force_original_view_tracking_context.__enter__()
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
del self._force_original_view_tracking_guard
|
||||
self._force_original_view_tracking_context.__exit__(*args)
|
||||
|
||||
def clone(self):
|
||||
return self.__class__(self.mode)
|
||||
|
@ -1876,12 +1876,20 @@ class BaseTorchFunctionMode(TorchFunctionMode):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
class enable_reentrant_dispatch():
|
||||
def __enter__(self):
|
||||
self._raii_guard = torch._C._RestorePythonTLSSnapshot()
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
del self._raii_guard
|
||||
@contextlib.contextmanager
|
||||
def enable_reentrant_dispatch():
|
||||
# NB: this can't simply be
|
||||
# `enable_reentrant_dispatch = torch._C._RestorePythonTLSSnapshot`
|
||||
# because:
|
||||
# 1. torch._C._RestorePythonTLSSnapshot is unavailable when this file
|
||||
# initially gets imported. Probably an import order thing.
|
||||
# 2. enable_reentrant_dispatch is technically public API; assigning
|
||||
# it the object would change the __module__ to look private.
|
||||
with torch._C._RestorePythonTLSSnapshot():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
def get_buffer(tensor_subclass, data, prefix):
|
||||
import ctypes
|
||||
|
@ -1483,13 +1483,7 @@ def set_rng_seed(seed):
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_functorch():
|
||||
guard = torch._C._DisableFuncTorch() # type: ignore[attr-defined]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del guard
|
||||
disable_functorch = torch._C._DisableFuncTorch
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -1,6 +1,5 @@
|
||||
import torch
|
||||
from typing import TypeVar
|
||||
from contextlib import contextmanager
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
@ -8,10 +7,4 @@ T = TypeVar('T')
|
||||
def all_same_mode(modes):
|
||||
return all(tuple(mode == modes[0] for mode in modes))
|
||||
|
||||
@contextmanager
|
||||
def no_dispatch():
|
||||
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
del guard
|
||||
no_dispatch = torch._C._DisableTorchDispatch
|
||||
|
Reference in New Issue
Block a user