diff --git a/aten/src/ATen/PythonTorchFunctionTLS.cpp b/aten/src/ATen/PythonTorchFunctionTLS.cpp
index ae9f722de60a..c4e1241805a8 100644
--- a/aten/src/ATen/PythonTorchFunctionTLS.cpp
+++ b/aten/src/ATen/PythonTorchFunctionTLS.cpp
@@ -18,6 +18,26 @@ void PythonTorchFunctionTLS::swap_mode(std::shared_ptr<SafePyObject>& mode) {
   pythonTorchFunctionState.mode_.swap(mode);
 }
 
+void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
+  pythonTorchFunctionState.stack_.push_back(std::move(mode));
+}
+
+const std::shared_ptr<SafePyObject> PythonTorchFunctionTLS::pop_stack() {
+  TORCH_CHECK(pythonTorchFunctionState.stack_.size() > 0, "trying to pop from empty mode stack");
+  const auto out = pythonTorchFunctionState.stack_.back();
+  pythonTorchFunctionState.stack_.pop_back();
+  return out;
+}
+
+const std::shared_ptr<SafePyObject>& PythonTorchFunctionTLS::get_stack_at(int64_t idx) {
+  TORCH_CHECK(idx < static_cast<int64_t>(pythonTorchFunctionState.stack_.size()), "Tried to get stack at idx that's too big");
+  return pythonTorchFunctionState.stack_[idx];
+}
+
+int64_t PythonTorchFunctionTLS::stack_len() {
+  return pythonTorchFunctionState.stack_.size();
+}
+
 void PythonTorchFunctionTLS::set_disabled(bool disabled) {
   pythonTorchFunctionState.disabled_ = disabled;
 }
@@ -34,5 +54,9 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
   return pythonTorchFunctionState;
 }
 
+bool function_mode_enabled() {
+  return static_cast<bool>(PythonTorchFunctionTLS::get_mode());
+}
+
 } // namespace impl
 } // namespace at
diff --git a/aten/src/ATen/PythonTorchFunctionTLS.h b/aten/src/ATen/PythonTorchFunctionTLS.h
index 003dcef1e90f..ef283164246d 100644
--- a/aten/src/ATen/PythonTorchFunctionTLS.h
+++ b/aten/src/ATen/PythonTorchFunctionTLS.h
@@ -14,13 +14,28 @@ struct TORCH_API PythonTorchFunctionTLS {
   static const std::shared_ptr<SafePyObject>& get_mode();
   static void swap_mode(std::shared_ptr<SafePyObject>&);
 
-  static void set_state(const PythonTorchFunctionTLS& state);
+  static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
+  static const std::shared_ptr<SafePyObject> pop_stack();
+  static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
+  static int64_t stack_len();
+
   static const PythonTorchFunctionTLS& get_state();
+  static void set_state(const PythonTorchFunctionTLS& state);
 
  private:
+  // The mode TLS is split into
+  //   - disabled_, which says whether or not to disable all torch function
+  //     modes
+  //   - mode_, which is the C++ mode, that can only be the mode handling mode
+  //     or null
+  //   - stack_, which is a vector of modes representing the stack of user
+  //     defined modes
   bool disabled_;
-  std::shared_ptr<SafePyObject> mode_;
+  std::shared_ptr<SafePyObject> mode_ = nullptr;
+  std::vector<std::shared_ptr<SafePyObject>> stack_;
 };
 
+TORCH_API bool function_mode_enabled();
+
 } // namespace impl
 } // namespace at
diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h
index 529c98b99723..cfae7dbb69f6 100644
--- a/aten/src/ATen/ThreadLocalState.h
+++ b/aten/src/ATen/ThreadLocalState.h
@@ -56,7 +56,7 @@ class TORCH_API ThreadLocalState {
   AutogradState autograd_tls_;
 
   // TLS for enable_torch_dispatch_mode
-  std::shared_ptr<c10::SafePyObject> torch_dispatch_mode_state_;
+  c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
 
   // TLS for enable_python_dispatcher
   c10::impl::PyInterpreter* python_dispatcher_state_;
diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp
index 381c8a45aca6..ace0b5e89494 100644
--- a/aten/src/ATen/core/PythonFallbackKernel.cpp
+++ b/aten/src/ATen/core/PythonFallbackKernel.cpp
@@ -52,7 +52,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
-  const auto& maybe_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_state();
+  const auto& maybe_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_mode();
   if (maybe_torch_dispatch_mode_state) {
     maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
     return;
diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp
index 627645b1f402..869a74df4186 100644
--- a/c10/core/TensorImpl.cpp
+++ b/c10/core/TensorImpl.cpp
@@ -547,7 +547,7 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach_core(
     bool allow_tensor_metadata_change) const {
   c10::intrusive_ptr<TensorImpl> r;
   const auto& maybe_torch_dispatch_mode_state =
-      c10::impl::TorchDispatchModeTLS::get_state();
+      c10::impl::TorchDispatchModeTLS::get_mode();
   // TODO: do we have to exclude after Python dispatch key set?
   if (maybe_torch_dispatch_mode_state &&
       !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
diff --git a/c10/core/impl/TorchDispatchModeTLS.cpp b/c10/core/impl/TorchDispatchModeTLS.cpp
index fbf9504f7b5a..5f0268658425 100644
--- a/c10/core/impl/TorchDispatchModeTLS.cpp
+++ b/c10/core/impl/TorchDispatchModeTLS.cpp
@@ -6,32 +6,85 @@ namespace c10 {
 namespace impl {
 
-thread_local std::shared_ptr<SafePyObject> torchDispatchModeState;
+thread_local TorchDispatchModeTLS torchDispatchModeState;
 
-void TorchDispatchModeTLS::set_state(std::shared_ptr<SafePyObject> state) {
-  if (state) {
+// MODE
+void TorchDispatchModeTLS::set_mode(std::shared_ptr<SafePyObject> mode) {
+  if (mode) {
     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
     c10::impl::tls_set_dispatch_key_included(
         DispatchKey::PythonTLSSnapshot, true);
   } else {
-    TorchDispatchModeTLS::reset_state();
+    TorchDispatchModeTLS::reset_mode();
   }
-  torchDispatchModeState = std::move(state);
+  torchDispatchModeState.mode_ = std::move(mode);
 }
 
-const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_state() {
-  return torchDispatchModeState;
+const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_mode() {
+  return torchDispatchModeState.mode_;
 }
 
-void TorchDispatchModeTLS::reset_state() {
-  torchDispatchModeState.reset();
+void TorchDispatchModeTLS::reset_mode() {
+  torchDispatchModeState.mode_.reset();
   c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
   c10::impl::tls_set_dispatch_key_included(
       DispatchKey::PythonTLSSnapshot, false);
 }
 
+void TorchDispatchModeTLS::swap_mode(std::shared_ptr<SafePyObject>& mode) {
+  if (mode) {
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
+    c10::impl::tls_set_dispatch_key_included(
+        DispatchKey::PythonTLSSnapshot, true);
+  } else {
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
+    c10::impl::tls_set_dispatch_key_included(
+        DispatchKey::PythonTLSSnapshot, false);
+  }
+  torchDispatchModeState.mode_.swap(mode);
+}
+
+// STACK
+void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
+  torchDispatchModeState.stack_.push_back(std::move(mode));
+}
+
+const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
+  TORCH_CHECK(
+      torchDispatchModeState.stack_.size() > 0,
+      "trying to pop from empty mode stack");
+  const std::shared_ptr<SafePyObject> out =
+      torchDispatchModeState.stack_.back();
+  torchDispatchModeState.stack_.pop_back();
+  return out;
+}
+
+const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
+    int64_t idx) {
+  TORCH_CHECK(
+      idx < static_cast<int64_t>(torchDispatchModeState.stack_.size()),
+      "Tried to get stack at idx that's too big");
+  return torchDispatchModeState.stack_[idx];
+}
+
+int64_t TorchDispatchModeTLS::stack_len() {
+  return torchDispatchModeState.stack_.size();
+}
+
+// STATE
+
+const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
+  return torchDispatchModeState;
+}
+
+void TorchDispatchModeTLS::set_state(const TorchDispatchModeTLS& state) {
+  torchDispatchModeState = state;
+}
+
+// UTIL
+
 bool dispatch_mode_enabled() {
-  return static_cast<bool>(c10::impl::TorchDispatchModeTLS::get_state());
+  return static_cast<bool>(c10::impl::TorchDispatchModeTLS::get_mode());
 }
 
 } // namespace impl
diff --git a/c10/core/impl/TorchDispatchModeTLS.h b/c10/core/impl/TorchDispatchModeTLS.h
index 81aa34b11c5f..708c22e014ad 100644
--- a/c10/core/impl/TorchDispatchModeTLS.h
+++ b/c10/core/impl/TorchDispatchModeTLS.h
@@ -9,9 +9,27 @@ namespace c10 {
 namespace impl {
 
 struct C10_API TorchDispatchModeTLS {
-  static void set_state(std::shared_ptr<SafePyObject> state);
-  static const std::shared_ptr<SafePyObject>& get_state();
-  static void reset_state();
+  static void set_mode(std::shared_ptr<SafePyObject> mode);
+  static const std::shared_ptr<SafePyObject>& get_mode();
+  static void reset_mode();
+  static void swap_mode(std::shared_ptr<SafePyObject>& mode);
+
+  static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
+  static const std::shared_ptr<SafePyObject> pop_stack();
+  static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
+  static int64_t stack_len();
+
+  static const TorchDispatchModeTLS& get_state();
+  static void set_state(const TorchDispatchModeTLS& state);
+
+ private:
+  // The mode TLS is split into
+  //   - mode_, which is the C++ mode, that can only be the mode handling mode
+  //     or null
+  //   - stack_, which is a vector of modes representing the stack of user
+  //     defined modes
+  std::shared_ptr<SafePyObject> mode_;
+  std::vector<std::shared_ptr<SafePyObject>> stack_;
 };
 
 C10_API bool dispatch_mode_enabled();
diff --git a/functorch/experimental/cond.py b/functorch/experimental/cond.py
index 95cc387f4c72..9a7b2ecd5503 100644
--- a/functorch/experimental/cond.py
+++ b/functorch/experimental/cond.py
@@ -4,7 +4,7 @@ from torch._ops import PyOperator
 from torch.utils._pytree import tree_flatten
 from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot
 import torch.utils._pytree as pytree
-from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode
 from torch.fx.experimental.proxy_tensor import track_tensor_tree
 from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
 
@@ -29,15 +29,6 @@ def suspend_mode(mode):
     finally:
         torch._C._set_torch_dispatch_mode(mode)
 
-@contextmanager
-def enable_mode(mode):
-    curr_mode = torch._C._get_torch_dispatch_mode()
-    torch._C._set_torch_dispatch_mode(mode)
-    try:
-        yield
-    finally:
-        torch._C._set_torch_dispatch_mode(curr_mode)
-
 
 def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     def _unwrap_proxy(e):
@@ -104,7 +95,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
 
 @cond.py_impl(DispatchKey.CPU)
 def cond_dense(pred, true_fn, false_fn, operands):
-    mode = torch._C._get_torch_dispatch_mode()
+    mode = _get_current_dispatch_mode()
     assert (mode is None), "Mode should never be enabled for CPU key"
     if pred:
         return true_fn(*operands)
@@ -125,7 +116,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
 
 @cond.py_impl(ProxyTorchDispatchMode)
 def inner(pred, true_fn, false_fn, operands):
-    mode = torch._C._get_torch_dispatch_mode()
+    mode = _get_current_dispatch_mode()
     assert (mode is not None), "Mode should always be enabled for python fallback key"
     with suspend_mode(mode):
        res = trace_cond(mode,
cond, pred, true_fn, false_fn, operands) diff --git a/test/test_decomp.py b/test/test_decomp.py index db96ac49629c..211493f258c7 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -3,8 +3,8 @@ from collections import defaultdict from torch import Tensor import torch.autograd -from torch.utils._python_dispatch import enable_torch_dispatch_mode from torch._decomp import decomposition_table +from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten from torch.utils._mode_utils import no_dispatch @@ -384,17 +384,16 @@ class TestDecomp(TestCase): saved_precision = self.precision saved_rel_tol = self.rel_tol + test_case = self - class DecompCrossRefMode(torch.Tensor): - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + class DecompCrossRefMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): with no_dispatch(): - return cls._torch_dispatch(func, types, args, kwargs) + return self._torch_dispatch(func, types, args, kwargs) - @classmethod - def _torch_dispatch(cls, func, types, args=(), kwargs=None): - self.precision = saved_precision - self.rel_tol = saved_rel_tol + def _torch_dispatch(self, func, types, args=(), kwargs=None): + test_case.precision = saved_precision + test_case.rel_tol = saved_rel_tol called.add(func) all_called[func] += 1 @@ -441,14 +440,14 @@ class TestDecomp(TestCase): assert type(orig) == type(decomp) assert orig == decomp continue - op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs) + op_assert_ref(test_case, func, test_dtype, i, orig, decomp, ref, args, kwargs) else: for orig, decomp in zip(real_out, decomp_out): if not isinstance(orig, torch.Tensor): assert type(orig) == type(decomp) assert orig == decomp continue - op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs) + op_assert_equal(test_case, func, test_dtype, orig, decomp, args, kwargs) return real_out_unflat @@ -487,7 +486,7 @@ class TestDecomp(TestCase): # explicit clearing is necessary as I will create a fresh mode # for each region decomposed.clear() - with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher(): + with DecompCrossRefMode(), enable_python_dispatcher(): decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) if aten_name in decomposition_names: check_decomposed(aten_name) @@ -496,7 +495,7 @@ class TestDecomp(TestCase): cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out) decomposed.clear() - with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher(): + with DecompCrossRefMode(), enable_python_dispatcher(): decomp_vjp_fn(cotangents) if not run_all: check_decomposed(op.aten_backward_name) @@ -505,7 +504,7 @@ class TestDecomp(TestCase): args = [sample_input.input] + list(sample_input.args) kwargs = sample_input.kwargs decomposed.clear() - with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher(): + with DecompCrossRefMode(), enable_python_dispatcher(): func(*args, **kwargs) if not run_all: check_decomposed(aten_name) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index cc656ed239dd..cc8cf5553706 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -12,7 +12,6 @@ from torch._subclasses.fake_tensor import ( DynamicOutputShapeException, ) from torch.testing import FileCheck -from torch.utils._python_dispatch import enable_torch_dispatch_mode from torch import nn import unittest import torch._prims as 
prims @@ -27,10 +26,9 @@ class FakeTensorTest(TestCase): self.assertEqual(list(t.size()), size) def test_basic(self): - mode = FakeTensorMode() x = torch.empty(2, 2, device="cpu") y = torch.empty(4, 2, 2, device="cpu") - with enable_torch_dispatch_mode(mode): + with FakeTensorMode() as mode: x = mode.from_tensor(x) y = mode.from_tensor(y) z = x + y @@ -39,7 +37,7 @@ class FakeTensorTest(TestCase): self.assertTrue(isinstance(z, FakeTensor)) def test_parameter_instantiation(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.rand([4]) y = torch.nn.parameter.Parameter(x) self.assertTrue(isinstance(y, torch.nn.Parameter)) @@ -52,14 +50,14 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_index_cuda_with_cpu(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.rand([2048], device='cuda') out = x[torch.zeros([36], dtype=torch.int64)] self.checkType(out, "cuda", [36]) @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_shape_take_not_device(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.empty(1, device="cpu") y = torch.empty(8, 8, device="cuda") out = x.resize_as_(y) @@ -69,8 +67,7 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_zero_dim(self): - mode = FakeTensorMode() - with enable_torch_dispatch_mode(mode): + with FakeTensorMode() as mode: x = torch.tensor(0.) y = torch.rand([4, 4], device="cuda") out = x + y @@ -79,8 +76,7 @@ class FakeTensorTest(TestCase): self.assertTrue(isinstance(out, FakeTensor)) def test_nan_to_num(self): - mode = FakeTensorMode() - with enable_torch_dispatch_mode(mode): + with FakeTensorMode(): for dtype in [torch.float16, torch.float32]: x = torch.rand([4], dtype=dtype) y = torch.nan_to_num(x, nan=None) @@ -90,9 +86,8 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_throw(self): - mode = FakeTensorMode() x = torch.tensor(0.) 
# TODO: tensor() errors - with enable_torch_dispatch_mode(mode): + with FakeTensorMode() as mode: x_conv = mode.from_tensor(x) y = torch.rand([4, 4], device="cuda") z = torch.rand([4, 4], device="cpu") @@ -100,7 +95,7 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_type_as(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.rand([16, 1], device="cpu") y = torch.rand([4, 4], device="cuda") out = x.type_as(y) @@ -110,12 +105,12 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_setitem(self): for device in ["cpu", "cuda"]: - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.rand([16, 1], device=device) x[..., 0] = 0 def test_fake_dispatch_keys(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.rand([4]) f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU") f.run(torch._C._dispatch_key_set(x)) @@ -127,14 +122,14 @@ class FakeTensorTest(TestCase): FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y)) def test_constructor(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.rand([4, 4], device="cpu") self.assertTrue(isinstance(x, FakeTensor)) self.assertTrue(x.device.type == "cpu") def test_mode(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): y = torch.rand([4], device="cpu") out = y + y @@ -142,7 +137,7 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_non_kwarg_device(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.rand([16, 1], device="cpu") y = x.to(torch.device("cpu")) self.assertIs(x, y) @@ -153,7 +148,7 @@ class FakeTensorTest(TestCase): x = torch.rand([4, 4]) with self.assertRaisesRegex(Exception, "non-Fake Tensor inputs"): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): y = x[0] def test_fake_grad_copy(self): @@ -168,7 +163,7 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_like_constructor(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.rand([4, 4]) y = torch.ones_like(x) self.assertTrue(isinstance(y, FakeTensor)) @@ -178,7 +173,7 @@ class FakeTensorTest(TestCase): self.assertEqual(z.device.type, "cuda") def test_binary_op_type_promotion(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.empty([2, 2], dtype=torch.float) y = torch.empty([2, 2], dtype=torch.int64) out = x / y @@ -186,14 +181,14 @@ class FakeTensorTest(TestCase): self.assertEqual(out.device.type, "cpu") def test_from_numpy(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.tensor(np.zeros([4, 4])) self.checkType(x, "cpu", [4, 4]) def test_randperm(self): x = torch.randperm(10) y = torch.randperm(5, device="cpu") - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x1 = torch.randperm(10) prims.utils.compare_tensor_meta(x, x1) y1 = torch.randperm(5, device="cpu") @@ -202,21 +197,21 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_cpu_fallback(self): - with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=False)): + with FakeTensorMode(allow_fallback_kernels=False): filters = 
torch.randn(8, 4, 3, 3).cuda() inputs = torch.randn(1, 4, 5, 5).cuda() out = torch.nn.functional.conv2d(inputs, filters, padding=1) self.assertEqual(out.device.type, "cuda") self.assertEqual(list(out.size()), [1, 8, 5, 5]) - with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=True)): + with FakeTensorMode(allow_fallback_kernels=True): # intentionally bad inputs filters = torch.randn(8, 20, 3, 3).cuda() inputs = torch.randn(1, 7, 10, 5).cuda() with self.assertRaises(RuntimeError): torch.nn.functional.conv2d(inputs, filters, padding=1) - with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=True)): + with FakeTensorMode(allow_fallback_kernels=True): filters = torch.randn(8, 4, 3, 3).cuda() inputs = torch.randn(1, 4, 5, 5).cuda() @@ -295,7 +290,7 @@ class FakeTensorTest(TestCase): ) mode = FakeTensorMode() - for i, context in enumerate([contextlib.nullcontext, lambda: enable_torch_dispatch_mode(mode)]): + for i, context in enumerate([contextlib.nullcontext, lambda: mode]): with context(): inps = ( torch.randn([92, 8, 2048]).cuda(), @@ -336,16 +331,14 @@ class FakeTensorTest(TestCase): with torch._subclasses.fake_tensor.FakeCopyMode(mode): mod_copied = copy.deepcopy(m) - with enable_torch_dispatch_mode(mode): + with mode: input = torch.rand(20, 16, 50, 100, dtype=torch.half, device="cuda").to(memory_format=torch.channels_last) out = mod_copied(input) self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) self.checkType(out, "cuda", [20, 33, 24, 49]) def test_data_dependent_operator(self): - with enable_torch_dispatch_mode( - FakeTensorMode(allow_fallback_kernels=False) - ): + with FakeTensorMode(allow_fallback_kernels=False): x = torch.rand([10, 10]) self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x)) @@ -355,7 +348,8 @@ class FakeTensorTest(TestCase): @skipIfCrossRef def test_deepcopy(self): - mode = FakeTensorMode() + with FakeTensorMode() as mode: + pass mod = torch.nn.BatchNorm2d(10) with torch._subclasses.fake_tensor.FakeCopyMode(mode): mod_copied = copy.deepcopy(mod) @@ -386,7 +380,7 @@ class FakeTensorTest(TestCase): @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_new(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): a = torch.rand([16, 1]) self.checkType(a.new(10, 10), "cpu", [10, 10]) self.checkType(a.new([1, 2, 3, 4]), "cpu", [4]) @@ -527,8 +521,7 @@ class FakeTensorConverterTest(TestCase): self.assertEqual(len(converter.meta_converter.tensor_memo), 0) def test_no_active_mode(self): - mode = FakeTensorMode() - with enable_torch_dispatch_mode(mode): + with FakeTensorMode() as mode: x = torch.empty(2, 2, device="cpu") y = torch.empty(2, 2, device="cpu") @@ -538,9 +531,9 @@ class FakeTensorConverterTest(TestCase): self.assertEqual(out.device.type, "cpu") def test_separate_mode_error(self): - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): x = torch.empty(2, 2, device="cpu") - with enable_torch_dispatch_mode(FakeTensorMode()): + with FakeTensorMode(): y = torch.empty(2, 2, device="cpu") self.assertRaises(Exception, lambda: x, y) diff --git a/test/test_ops.py b/test/test_ops.py index f90ab0d550f6..b4b7aa788086 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -63,7 +63,6 @@ from torch._subclasses.fake_tensor import ( ) from torch._subclasses.fake_utils import outputs_alias_inputs -from torch.utils._python_dispatch import enable_torch_dispatch_mode import torch._prims as prims from torch._prims.context import TorchRefsMode @@ -173,8 +172,7 @@ 
class TestCommon(TestCase): @onlyNativeDeviceTypes @ops(python_ref_db) def test_python_ref_meta(self, device, dtype, op): - mode = FakeTensorMode() - with mode: + with FakeTensorMode() as mode: pass def _to_tensormeta(x): @@ -189,7 +187,7 @@ class TestCommon(TestCase): meta_sample = sample.transform(_to_tensormeta) try: - with enable_torch_dispatch_mode(mode): + with mode: meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs) except torch._subclasses.fake_tensor.UnsupportedFakeTensorException: continue @@ -1864,7 +1862,7 @@ class TestFakeTensor(TestCase): continue with context(): - with enable_torch_dispatch_mode(mode): + with mode: res_fake = op(input, *args, **kwargs) diff --git a/test/test_overrides.py b/test/test_overrides.py index 70ca676a1762..e1ecdaab0e54 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -16,9 +16,11 @@ from torch.overrides import ( get_overridable_functions, get_testing_overrides, is_tensor_method_or_property, - TorchFunctionMode + TorchFunctionMode, + _get_current_function_mode, + _get_current_function_mode_stack, ) -from torch.utils._mode_utils import find_outermost_mode, all_same_mode, all_same_mode_scope +from torch.utils._mode_utils import all_same_mode from torch.utils._pytree import tree_map Tensor = torch.Tensor @@ -1150,11 +1152,6 @@ class TestTorchFunctionMode(TestCase): self.assertEqual(torch._sparse_csr_tensor_unsafe(1, 1, 1, (1, 1)), -1) self.assertEqual(torch.as_tensor([1]), -1) - def test_enable_torch_function_mode_with_tensor_subclass(self): - x = torch.randn(1) - with torch.overrides.enable_torch_function_mode(SubTensor): - self.assertEqual(torch.mm(x, x), -1) - def test_modes_handle_first(self): class A(TorchFunctionMode): def __torch_function__(self, *args, **kwargs): @@ -1178,54 +1175,9 @@ class TestTorchFunctionMode(TestCase): self.assertEqual(torch.mm(x, x), -1) self.assertEqual(bar(x), 1) self.assertRaisesRegex( - TypeError, r'SubTensor.+MyMode', + TypeError, r'SubTensor.+TorchFunctionStackMode', lambda: self.assertEqual(torch.max(x, x))) - def test_enable_torch_function_mode_trivial(self): - class A(TorchFunctionMode): - def __torch_function__(self, *args, **kwargs): - return -40 - a = A() - with torch.overrides.enable_torch_function_mode(a): - with torch.overrides.enable_torch_function_mode(a): - self.assertEqual(bar(None), -40) - - def test_enable_torch_function_mode_replace(self): - class A(TorchFunctionMode): - def __init__(self, val): - self.val = val - - def __torch_function__(self, *args, **kwargs): - return self.val - a1 = A(-40) - a2 = A(-41) - with torch.overrides.enable_torch_function_mode(a1): - with torch.overrides.enable_torch_function_mode(a2, replace=a1): - self.assertEqual(bar(None), -41) - - def test_enable_torch_function_mode_ignore_preexisting(self): - class A(TorchFunctionMode): - def __init__(self, val): - self.val = val - - def __torch_function__(self, *args, **kwargs): - return self.val - a1 = A(-40) - a2 = A(-41) - with torch.overrides.enable_torch_function_mode(a1): - with torch.overrides.enable_torch_function_mode(a2, ignore_preexisting=True): - self.assertEqual(bar(None), -41) - - def test_ctor_no_inner(self): - class A(TorchFunctionMode): - def __torch_function__(self, *args, **kwargs): - return torch.zeros([]) - - with torch.overrides.enable_torch_function_mode(A()): - x = torch.randn((3, 4)) - - self.assertEqual(x, torch.zeros([])) - def test_with_mode(self): class ErrorA(RuntimeError): pass @@ -1270,15 +1222,24 @@ class TestTorchFunctionMode(TestCase): 
self.assertEqual(out, ["layer2", "layer1"]) - def test_error_using_same_mode(self): - class A(TorchFunctionMode): - pass + def test_nested_same_mode(self): + out = [] - x = A() - with x: - with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"): - with x: - pass + class A(TorchFunctionMode): + def __init__(self, msg): + self.msg = msg + + def __torch_function__(self, func, _, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + out.append(self.msg) + return func(*args, **kwargs) + + with A("layer1") as a: + with a: + torch.empty([]) + + self.assertEqual(out, ["layer1", "layer1"]) def test_error_using_class_method_on_mode(self): class A(TorchFunctionMode): @@ -1291,86 +1252,43 @@ class TestTorchFunctionMode(TestCase): with A(): x + x - def test_error_with_ancestor(self): + def test_restacking_with_ancestor(self): class A(TorchFunctionMode): pass - with A() as x: - pass - - with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"): - with x: - pass - - def test_restore_errors(self): - class A(TorchFunctionMode): - pass - - with self.assertRaisesRegex(RuntimeError, "does not have any ancestors. Use the standard version instead"): - with A().restore(): - pass - - x = A() with A(): - with x: + with A() as x: pass - with A(): # a different mode instance than the one above - with self.assertRaisesRegex(RuntimeError, "the current mode is not its ancestor"): - with x.restore(): - pass - - - def test_restore_ancestor_mode(self): - class A(TorchFunctionMode): - pass - - x = A() - y = A() with x: - with y: - pass + pass - z = A() - with y.restore(): - with z: - pass - - with x.restore(): - with z.restore(): - pass - - def test_find_outermost_mode(self): + def test_get_cur_mode(self): class A(TorchFunctionMode): - pass - - self.assertIsNone(find_outermost_mode([None, None])) - - x = A() - y = A() - with x: - with y: + def __torch_dispatch__(self, func, types, args=(), kwargs=None): pass - self.assertEqual(find_outermost_mode([x, y]), y) + with A() as mode1: + self.assertEqual(_get_current_function_mode(), mode1) - z = A() - with y.restore(): - with z: + with mode1: + with A() as mode2: + self.assertEqual(_get_current_function_mode(), mode2) + + + def test_get_mode_stack(self): + class A(TorchFunctionMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): pass - self.assertEqual(find_outermost_mode([z, x]), z) - i = A() + self.assertEqual(_get_current_function_mode_stack(), []) - with self.assertRaisesRegex(RuntimeError, "doesn't have ancestors set so the ordering with other modes"): - find_outermost_mode([i, x, y, z]) + with A() as mode1: + self.assertEqual(_get_current_function_mode_stack(), [mode1]) - k = A() - with k: - pass - - with self.assertRaisesRegex(RuntimeError, "don't come from the same scope"): - find_outermost_mode([k, x, y, z]) + with mode1: + with A() as mode2: + self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2]) def test_all_same_mode(self): class A(TorchFunctionMode): @@ -1382,32 +1300,6 @@ class TestTorchFunctionMode(TestCase): self.assertFalse(all_same_mode([x, None])) self.assertFalse(all_same_mode([x, y])) - def test_all_same_mode_scope(self): - class A(TorchFunctionMode): - pass - - x = A() - y = A() - z = A() - with x: - with y: - pass - - with x.restore(): - with z: - pass - - i = A() - - self.assertTrue(all_same_mode_scope([x, y], y)) - self.assertTrue(all_same_mode_scope([x, z], z)) - self.assertFalse(all_same_mode_scope([x, y, z], y)) - 
self.assertFalse(all_same_mode_scope([x, y, z], z)) - self.assertFalse(all_same_mode_scope([x, y, i], y)) - - no_ancestor = A() - self.assertFalse(all_same_mode_scope([x, y, z], no_ancestor)) - def test_reentrant_mode_idiom(self): log = [] @@ -1417,7 +1309,7 @@ class TestTorchFunctionMode(TestCase): kwargs = {} log.append(func) if func is torch.sub: - with torch.overrides.enable_torch_function_mode(self, replace=self.inner): + with self: input, other = args assert not kwargs return torch.add(input, other, alpha=-1) diff --git a/test/test_prims.py b/test/test_prims.py index 4db301df21f2..d36c8ce7156d 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -742,7 +742,7 @@ class TestDecomp(TestCase): op = torch._decomp.decomposition_table.get(torch.ops.aten.leaky_relu_backward.default) def fn0(*arg): - return _is_func_unsupported_nvfuser(mode, op, arg, {}) + return _is_func_unsupported_nvfuser(TorchRefsNvfuserCapabilityMode(), op, arg, {}) def fn1(x): x = x * 2 @@ -750,8 +750,8 @@ class TestDecomp(TestCase): x = x * 2 return x - with TorchRefsNvfuserCapabilityMode() as mode: - self.assertFalse(fn0(x, y, 0.3, False)) + self.assertFalse(fn0(x, y, 0.3, False)) + with TorchRefsNvfuserCapabilityMode(): # Autocast context has C++ level ATen calls that are hidden from # TorchRefsNvfuserCapabilityMode that works only on Python level. diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 03f88ec3c0e1..bec69117414f 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -7,11 +7,11 @@ from torch.library import Library from torch.cuda.jiterator import _create_jit_fn import unittest from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, IS_WINDOWS -from torch.utils._mode_utils import no_dispatch, find_outermost_mode, all_same_mode, all_same_mode_scope +from torch.utils._mode_utils import no_dispatch, all_same_mode from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \ log_input, capture_logs, capture_logs_with_logging_tensor_mode from torch.utils._pytree import tree_map, tree_map_only -from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode +from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack import logging @@ -730,34 +730,22 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''') res = x.index_put_(idxs, v) self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default]) - def test_enable_torch_dispatch_mode_error(self) -> None: - z = LoggingTensor(torch.empty([])) - with self.assertRaisesRegex(ValueError, "expected to get TorchDispatchMode, Tensor-like class, or None"): - with enable_torch_dispatch_mode(z): - pass - - def test_enable_torch_dispatch_mode_basic(self) -> None: + def test_torch_dispatch_mode_basic(self) -> None: with capture_logs(is_mode=True) as logs: - with enable_torch_dispatch_mode(LoggingTensorMode()): + with LoggingTensorMode(): torch.empty([]) self.assertExpectedInline('\n'.join(logs), """\ $0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""") - def test_enable_torch_dispatch_mode_unrelated_tensors(self) -> None: + def test_torch_dispatch_mode_unrelated_tensors(self) -> None: x = torch.randn([]) y = torch.randn([]) with capture_logs(is_mode=True) as logs: - with enable_torch_dispatch_mode(LoggingTensorMode()): + with LoggingTensorMode(): x + y self.assertExpectedInline('\n'.join(logs), """\ $2 = 
torch._ops.aten.add.Tensor($0, $1)""") - def test_nested_push_regular(self): - with LoggingTensorMode.push() as mode: - # This previously errored - with LoggingTensorMode(): - pass - def test_nested_push_logging_tensor_mode(self): x = torch.randn([]) y = torch.randn([]) @@ -799,7 +787,7 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""") self.assertEqual(logs1, logs2) - def test_enable_torch_dispatch_mode_subclass_priority(self) -> None: + def test_torch_dispatch_mode_subclass_priority(self) -> None: class ErrorA(RuntimeError): pass @@ -813,7 +801,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""") @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - raise ErrorA + with AMode(): + raise ErrorA class B(A): @staticmethod @@ -822,6 +811,15 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""") @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + with BMode(): + func(*args, **kwargs) + + class AMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + raise ErrorA + + class BMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): raise ErrorB a = A(torch.empty(1)) @@ -834,38 +832,41 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""") # B has precedence over A due to the subclass relationship yet # modes take precedence over arguments with self.assertRaises(ErrorA): - with enable_torch_dispatch_mode(A): + with AMode(): b + b with self.assertRaises(ErrorB): - with enable_torch_dispatch_mode(B): + with BMode(): a + a with self.assertRaises(ErrorB): - with enable_torch_dispatch_mode(B): + with BMode(): a + b - def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None: + def test_mode_with_make_subclass(self): + class SubTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem): + return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) + + class BasicMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + return func(*args, **kwargs) + + x = torch.randn(3) + with BasicMode(): + y = SubTensor(x) + self.assertIsInstance(y, SubTensor) + + def test_torch_dispatch_mode_respects_no_dispatch(self) -> None: with capture_logs(is_mode=True) as logs1: - with enable_torch_dispatch_mode(LoggingTensorMode()): + with LoggingTensorMode(): torch.ones([2, 3]) with no_dispatch(): torch.ones([2, 3]) with capture_logs(is_mode=True) as logs2: - with enable_torch_dispatch_mode(LoggingTensorMode()): + with LoggingTensorMode(): torch.ones([2, 3]) self.assertEqual(logs1, logs2) - def test_enable_torch_dispatch_mode_instance(self) -> None: - class TestMode(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - return func(*args, **kwargs) - - x = TestMode() - y = torch.tensor([2.]) - with enable_torch_dispatch_mode(x): - y + y - def test_shallow_copy_and_detach(self) -> None: seen = set() test_case = self @@ -884,103 +885,25 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""") loss = (x * x).sum() loss.backward() - def test_nested_enable_torch_dispatch_mode(self) -> None: - class A(LoggingTensorMode): - pass - - with self.assertRaisesRegex(ValueError, "there is already an active mode"): - with enable_torch_dispatch_mode(LoggingTensorMode()): - with enable_torch_dispatch_mode(A()): - pass - - # For nesting to be a noop, they need to be the same instance - with self.assertRaisesRegex(ValueError, "there is already an active mode"): - with enable_torch_dispatch_mode(LoggingTensorMode()): - with 
enable_torch_dispatch_mode(LoggingTensorMode()): - pass - - def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None: - # "nested" enable_torch_dispatch_modes are allowed if they're the same mode (same instance). - # It's the equivalent of a noop, so it will only write once to the log - x = torch.tensor([3.]) - mode = LoggingTensorMode() - with capture_logs(is_mode=True) as logs: - log_input("x", x) - with enable_torch_dispatch_mode(mode): - with enable_torch_dispatch_mode(mode): - x + x - self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = torch._ops.aten.add.Tensor($0, $0)''') - - def test_enable_torch_dispatch_mode_ignore_preexisting(self): - class A(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - raise AssertionError - - x = torch.tensor([3.]) - with capture_logs(is_mode=True) as logs: - with enable_torch_dispatch_mode(A()): - with enable_torch_dispatch_mode(LoggingTensorMode(), ignore_preexisting=True): - x + x - self.assertExpectedInline('\n'.join(logs), """\ -$1 = torch._ops.aten.add.Tensor($0, $0)""") - - def test_enable_torch_dispatch_mode_replace(self): - class A(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - raise AssertionError - - x = torch.tensor([3.]) - outer_mode = A() - with capture_logs(is_mode=True) as logs: - with enable_torch_dispatch_mode(outer_mode): - with enable_torch_dispatch_mode(LoggingTensorMode(), replace=outer_mode): - x + x - self.assertExpectedInline('\n'.join(logs), """\ -$1 = torch._ops.aten.add.Tensor($0, $0)""") - def test_exception_handling(self): class A(torch.Tensor): @staticmethod def __new__(cls, elem): return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + class AMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): if func.__name__ == 'randn.default': raise RuntimeError() - return cls(torch.zeros(())) + return A(torch.zeros(())) - with enable_torch_dispatch_mode(A): + with AMode(): try: torch.randn(()) except RuntimeError: pass self.assertTrue(isinstance(torch.zeros(()), A)) - def test_ctor_no_inner(self): - class A(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - return torch.zeros([]) - - with enable_torch_dispatch_mode(A()): - x = torch.randn((3, 4)) - - self.assertEqual(x, torch.zeros([])) - - def test_with_mode(self): - class ErrorA(RuntimeError): - pass - - class A(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - raise ErrorA() - - with self.assertRaises(ErrorA): - with A(): - torch.empty([]) - def test_with_mode_created_separately(self): class ErrorA(RuntimeError): pass @@ -1021,10 +944,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""") @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - modes = (arg.mode for arg in args + tuple(kwargs.values()) if isinstance(arg, ModeTensor)) - outermost = find_outermost_mode(modes) - with outermost.restore(): - return func(*args, **kwargs) + raise NotImplementedError("Shouldn't be here") class Mode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): @@ -1111,70 +1031,17 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""") with PoliteMode(): a.abs() - def test_disable_mode(self): - class FailEverythingMode(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - raise RuntimeError("arf") + def 
test_nesting_same_mode(self): + # If the pushed mode is the same instance as the current mode, we allow pushing an already active mode. - with FailEverythingMode() as m: - self.assertRaises(RuntimeError, lambda: torch.ones([2, 3])) - with enable_torch_dispatch_mode(None, replace=m): - torch.ones([2, 3]) + with capture_logs(is_mode=True) as logs: + with LoggingTensorMode() as reenabled: + with reenabled: + torch.empty([]) + self.assertExpectedInline('\n'.join(logs), """\ +$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) +$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""") - def test_make_wrapper_subclass_with_modes(self): - class ModeTensor(torch.Tensor): - def __new__(cls, elem, mode): - r = torch.Tensor._make_wrapper_subclass(cls, elem.shape) - r.elem = elem - r.mode = mode - return r - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - modes = (arg.mode for arg in args + tuple(kwargs.values()) if isinstance(arg, ModeTensor)) - outermost = find_outermost_mode(modes) - with outermost.restore(): - return func(*args, **kwargs) - - class Mode(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - def unwrap(e): - if isinstance(e, ModeTensor): - return e.elem - else: - return e - - def wrap(t): - if isinstance(t, torch.Tensor): - return ModeTensor(t, self) - else: - return t - - return wrap(func(*tuple(unwrap(a) for a in args), **kwargs)) - - x = torch.tensor(4.) - with Mode(): - y = x + x - z = y + y - self.assertIsInstance(y, ModeTensor) - self.assertIsInstance(z, ModeTensor) - - with Mode(): - with Mode(): - y = x + x - z = y + y - self.assertIsInstance(y, ModeTensor) - self.assertIsInstance(z, ModeTensor) - - def test_error_using_same_mode(self): - class A(TorchDispatchMode): - pass - - x = A() - with x: - with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"): - with x: - pass def test_error_using_class_method_on_mode(self): class A(TorchDispatchMode): @@ -1187,74 +1054,33 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""") with A(): x + x - def test_error_with_ancestor(self): - x = LoggingTensorMode() - with x: - pass - - with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"): - with x: + def test_get_cur_mode(self): + class A(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): pass - def test_restore_errors(self): - with self.assertRaisesRegex(RuntimeError, "does not have any ancestors. 
Use the standard version instead"): - with LoggingTensorMode().restore(): + self.assertEqual(_get_current_dispatch_mode(), None) + + with A() as mode1: + self.assertEqual(_get_current_dispatch_mode(), mode1) + + with mode1: + with A() as mode2: + self.assertEqual(_get_current_dispatch_mode(), mode2) + + def test_get_mode_stack(self): + class A(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): pass - x = LoggingTensorMode() - with LoggingTensorMode(): - with x: - pass + self.assertEqual(_get_current_dispatch_mode_stack(), []) - with LoggingTensorMode(): # a different mode instance than the one above - with self.assertRaisesRegex(RuntimeError, "the current mode is not its ancestor"): - with x.restore(): - pass + with A() as mode1: + self.assertEqual(_get_current_dispatch_mode_stack(), [mode1]) - def test_restore_ancestor_mode(self): - x = LoggingTensorMode() - y = LoggingTensorMode() - with x: - with y: - pass - - z = LoggingTensorMode() - with y.restore(): - with z: - pass - - with x.restore(): - with z.restore(): - pass - - def test_find_outermost_mode(self): - self.assertIsNone(find_outermost_mode([None, None])) - - x = LoggingTensorMode() - y = LoggingTensorMode() - with x: - with y: - pass - - self.assertEqual(find_outermost_mode([x, y]), y) - - z = LoggingTensorMode() - with y.restore(): - with z: - pass - - self.assertEqual(find_outermost_mode([z, x]), z) - i = LoggingTensorMode() - - with self.assertRaisesRegex(RuntimeError, "doesn't have ancestors set so the ordering with other modes"): - find_outermost_mode([i, x, y, z]) - - k = LoggingTensorMode() - with k: - pass - - with self.assertRaisesRegex(RuntimeError, "don't come from the same scope"): - find_outermost_mode([k, x, y, z]) + with mode1: + with A() as mode2: + self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2]) def test_all_same_mode(self): x = LoggingTensorMode() @@ -1263,29 +1089,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""") self.assertFalse(all_same_mode([x, None])) self.assertFalse(all_same_mode([x, y])) - def test_all_same_mode_scope(self): - x = LoggingTensorMode() - y = LoggingTensorMode() - z = LoggingTensorMode() - with x: - with y: - pass - - with x.restore(): - with z: - pass - - i = LoggingTensorMode() - - self.assertTrue(all_same_mode_scope([x, y], y)) - self.assertTrue(all_same_mode_scope([x, z], z)) - self.assertFalse(all_same_mode_scope([x, y, z], y)) - self.assertFalse(all_same_mode_scope([x, y, z], z)) - self.assertFalse(all_same_mode_scope([x, y, i], y)) - - no_ancestor = LoggingTensorMode() - self.assertFalse(all_same_mode_scope([x, y, z], no_ancestor)) - def test_tolist_numpy_with_torch_dispatch_mode(self) -> None: x = LoggingTensor(torch.tensor([2.0, 3.0])) with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."): @@ -1295,7 +1098,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""") with self.assertRaises(AssertionError): self.assertEqual(x, None) - def test_enable_torch_dispatch_mode_subclass_autograd_device_check(self) -> None: + def test_subclass_autograd_device_check(self) -> None: class NonWrapperSubclass(torch.Tensor): elem: torch.Tensor diff --git a/test/test_schema_check.py b/test/test_schema_check.py index 2ccf2698d75a..a18cc848cd05 100644 --- a/test/test_schema_check.py +++ b/test/test_schema_check.py @@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map from torch.testing._internal.common_utils import run_tests from torch.fx.operator_schemas import normalize_function from 
torch.testing._internal.schema_check_mode import SchemaCheckMode -from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode +from torch.utils._python_dispatch import TorchDispatchMode from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests @@ -72,16 +72,14 @@ class IncorrectAliasTensor(torch.Tensor): class TestSchemaCheck(JitTestCase): # Tests that SchemaCheckMode records operator order with grad def test_schema_check_mode_operator_order(self): - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: x = torch.rand((3, 3), requires_grad=True) x.relu().sin() self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops) # Tests that SchemaCheckMode records operator order without grad def test_schema_check_mode_operator_order_without_grad(self): - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: x = torch.rand((3, 3), requires_grad=False) x.relu().sin() self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops) @@ -91,8 +89,7 @@ class TestSchemaCheck(JitTestCase): # NB: previously requires_grad=True, but this induces a detach for # saved variable x = torch.rand((3, 3)) - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: actual = x.relu().sin() self.assertEqual([], schema_check.mutated) self.assertEqual([], schema_check.aliasing) @@ -100,8 +97,7 @@ class TestSchemaCheck(JitTestCase): # Tests that SchemaCheckMode records mutations and aliases with mutation expected def test_schema_check_mode_mutated_aliasing_mutation(self): actual = torch.rand((3, 3), requires_grad=False) - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: actual.sinh_() self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated) self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing) @@ -109,8 +105,7 @@ class TestSchemaCheck(JitTestCase): # Tests that SchemaCheckMode records mutations and aliases with resize_ def test_schema_check_mode_mutated_aliasing_resize_(self): actual = torch.rand((3, 3), requires_grad=False) - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: actual.resize_(9) self.assertEqual([('aten::resize_', 'input')], schema_check.mutated) self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing) @@ -119,8 +114,7 @@ class TestSchemaCheck(JitTestCase): def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self): actual = torch.rand((3, 3)) y = actual - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: actual.add_(y) self.assertEqual( [ @@ -140,8 +134,7 @@ class TestSchemaCheck(JitTestCase): # Tests that SchemaCheckMode records mutations and alias with as_strided def test_schema_check_mode_mutated_aliasing_as_strided(self): x = torch.rand((3, 6, 4)) - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: x.as_strided_([3, 6, 4], [9, 1, 1]) self.assertEqual( [ @@ -161,8 +154,7 @@ class 
TestSchemaCheck(JitTestCase): x = torch.arange(9.) m_actual = torch.arange(9.) e_actual = torch.zeros([9], dtype=torch.int32) - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: torch.frexp(x, out=(m_actual, e_actual)) self.assertEqual( [ @@ -183,8 +175,7 @@ class TestSchemaCheck(JitTestCase): def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self): x = torch.rand((3, 3)) actual = torch.zeros(3) - schema_check = SchemaCheckMode() - with enable_torch_dispatch_mode(schema_check): + with SchemaCheckMode() as schema_check: torch.aminmax(x, dim=0, out=[actual, actual]) self.assertEqual( [ @@ -207,7 +198,7 @@ class TestSchemaCheck(JitTestCase): def test_schema_check_mode_functionality(self): x = torch.rand((3, 3), requires_grad=True) expected = x.relu().sin() - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual = x.relu().sin() self.assertEqual(expected, actual) @@ -215,7 +206,7 @@ class TestSchemaCheck(JitTestCase): def test_schema_check_mode_functionality_default_replaced(self): x = torch.rand((3, 3), requires_grad=True) expected = x.add(x, alpha=2) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual = x.add(x, alpha=2) self.assertEqual(expected, actual) @@ -225,7 +216,7 @@ class TestSchemaCheck(JitTestCase): b = torch.rand((3, 3)) c = torch.rand((3, 3)) expected = torch.linalg.multi_dot([a, b, c]) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual = torch.linalg.multi_dot([a, b, c]) self.assertEqual(expected, actual) @@ -233,7 +224,7 @@ class TestSchemaCheck(JitTestCase): def test_schema_check_mode_functionality_wildcard_after(self): x = torch.rand((3, 3)) expected = x.chunk(6) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual = x.chunk(6) self.assertEqual(expected, actual) @@ -242,7 +233,7 @@ class TestSchemaCheck(JitTestCase): x = torch.rand((3, 5)) w = torch.rand((4)) expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True) self.assertEqual(expected, actual) @@ -251,7 +242,7 @@ class TestSchemaCheck(JitTestCase): expected = torch.rand((3, 3), requires_grad=False) actual = torch.clone(expected) expected.sinh_() - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual.sinh_() self.assertEqual(expected, actual) @@ -262,7 +253,7 @@ class TestSchemaCheck(JitTestCase): actual = torch.clone(expected) y = actual expected.add_(x) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual.add_(y) self.assertEqual(expected, actual) @@ -272,7 +263,7 @@ class TestSchemaCheck(JitTestCase): m_expected, e_expected = torch.frexp(x) m_actual = torch.arange(9.) 
e_actual = torch.zeros([9], dtype=torch.int32) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): torch.frexp(x, out=(m_actual, e_actual)) self.assertEqual(m_expected, m_actual) self.assertEqual(e_expected, e_actual) @@ -281,13 +272,13 @@ class TestSchemaCheck(JitTestCase): def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self): x = torch.rand((3, 3)) actual = torch.zeros(3) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): torch.aminmax(x, dim=0, out=[actual, actual]) self.assertEqual(torch.amax(x, dim=0), actual) # Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input def test_schema_check_mode_functionality_device_input(self): - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): x = torch.rand((3, 3), device="cpu", dtype=torch.double) y = x + x self.assertEqual(x + x, y) @@ -297,7 +288,7 @@ class TestSchemaCheck(JitTestCase): x = torch.rand((3, 3), requires_grad=True) batch = torch.nn.BatchNorm1d(3, track_running_stats=True) expected = batch(x) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual = batch(x) self.assertEqual(expected, actual) @@ -311,7 +302,7 @@ class TestSchemaCheck(JitTestCase): expected.relu_() expected = batch(expected) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual.sinh_() actual.tanh_() actual.relu_() @@ -321,7 +312,7 @@ class TestSchemaCheck(JitTestCase): # Tests that SchemaCheckMode wraps Torch.tensor with empty list input def test_schema_check_mode_empty_list_input(self): expected = torch.atleast_1d([]) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): actual = torch.atleast_1d([]) self.assertEqual(expected, actual) @@ -330,7 +321,7 @@ class TestSchemaCheck(JitTestCase): with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): x = torch.rand((3, 3)) y = torch.rand((3, 3)) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y)) # # Tests that an exception is raised for a mismatching mutation over multiple ops @@ -338,7 +329,7 @@ class TestSchemaCheck(JitTestCase): with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): x = torch.rand((3, 3)) y = torch.rand((3, 3)) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y)) # Tests that an exception is raised for a mismatching alias @@ -346,7 +337,7 @@ class TestSchemaCheck(JitTestCase): with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): x = torch.rand((3, 3), requires_grad=True) y = torch.rand((3, 3)) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2) # Tests that an exception is raised for a mismatching alias over multiple ops @@ -354,7 +345,7 @@ class TestSchemaCheck(JitTestCase): with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): x = torch.rand((3, 3), requires_grad=True) y = torch.zeros((3, 3), requires_grad=True) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2) # Tests that an exception is raised for a centered 
mismatching alias over multiple ops @@ -362,15 +353,14 @@ class TestSchemaCheck(JitTestCase): with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): x = torch.rand((3, 3), requires_grad=True) y = torch.zeros((3, 3), requires_grad=True) - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu() # Tests that an exception is raised for a centered mismatching alias over multiple ops def test_alias_check_fail_outputs_unexpectedly_aliasing(self): with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"): x = torch.rand((3, 3)) - s = SchemaCheckMode() - with enable_torch_dispatch_mode(s): + with SchemaCheckMode() as s: IncorrectAliasTensor(x).aminmax(dim=0) # Tests that is_alias_of returns as expected @@ -439,8 +429,7 @@ class TestSchemaCheck(JitTestCase): return func(*args, **kwargs) x = torch.rand((3, 3)) - schemaInfoCheck = SchemaInfoBindTestMode(self) - with enable_torch_dispatch_mode(schemaInfoCheck): + with SchemaInfoBindTestMode(self) as schemaInfoCheck: x.add(x) @@ -452,7 +441,7 @@ class TestSchemaCheckModeOpInfo(JitTestCase): if (dtype == torch.complex32): return for sample in op.sample_inputs(device, dtype, requires_grad=False): - with enable_torch_dispatch_mode(SchemaCheckMode()): + with SchemaCheckMode(): op(sample.input, *sample.args, **sample.kwargs) instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda")) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index ece160dccdb8..15fcf59fa9bb 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -918,11 +918,18 @@ def __is_forward_AD_enabled() -> _bool: ... def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ... def _reset_default_hooks() -> None: ... +def _is_torch_function_mode_enabled()-> _bool: ... def _set_torch_function_mode(cls: Any) -> None: ... -def _get_torch_function_mode() -> Any: ... +def _push_on_torch_function_stack(cls: Any) -> None: ... +def _pop_torch_function_stack() -> Any: ... +def _get_function_stack_at(idx: _int) -> Any: ... +def _len_torch_function_stack() -> _int: ... def _set_torch_dispatch_mode(cls: Any) -> None: ... -def _get_torch_dispatch_mode() -> Any: ... +def _push_on_torch_dispatch_stack(cls: Any) -> None: ... +def _pop_torch_dispatch_stack() -> Any: ... +def _get_dispatch_stack_at(idx: _int) -> Any: ... +def _len_torch_dispatch_stack() -> _int: ... class _InferenceMode(object): def __init__(self, mode: _bool) -> None: ... diff --git a/torch/_ops.py b/torch/_ops.py index 0c9478a1634a..b3ebd401ab8a 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -12,6 +12,7 @@ import torch.jit from torch import _utils_internal # Query `hasattr` only once. + _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") @@ -144,9 +145,11 @@ class PyOperator(PyOperatorABC): return inner def dispatch(self, dispatch_key, *args, **kwargs): + from torch.utils._python_dispatch import _get_current_dispatch_mode + if dispatch_key == torch._C.DispatchKey.Python: # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now. - curr_mode = type(torch._C._get_torch_dispatch_mode()) + curr_mode = type(_get_current_dispatch_mode()) assert ( curr_mode is not None ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode." 
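[Editor's note, illustrative only -- not part of the patch] The hunk above swaps the raw torch._C._get_torch_dispatch_mode() binding for _get_current_dispatch_mode(), which under the new design reads the top of the mode stack and returns None when the stack is empty. Below is a minimal sketch of the stack helpers this patch adds in torch/utils/_python_dispatch.py; NoopMode is a hypothetical mode used only for illustration.

from torch.utils._python_dispatch import (
    TorchDispatchMode,
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
)

class NoopMode(TorchDispatchMode):
    # hypothetical mode: forwards every op unchanged
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

assert _get_current_dispatch_mode() is None
with NoopMode() as outer:
    with NoopMode() as inner:
        # entry 0 is the first mode pushed (outermost); the last entry is the
        # current (innermost) mode, which _get_current_dispatch_mode returns
        assert _get_current_dispatch_mode() is inner
        assert _get_current_dispatch_mode_stack() == [outer, inner]
assert _get_current_dispatch_mode() is None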
@@ -314,9 +317,11 @@ class OpOverload(PyOperatorABC): return key def handler(*args, **kwargs): + from torch.utils._python_dispatch import _get_current_dispatch_mode + # TODO: We also need to handle tensor subclasses here # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now. - curr_mode = type(torch._C._get_torch_dispatch_mode()) + curr_mode = type(_get_current_dispatch_mode()) assert ( curr_mode is not None ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode." diff --git a/torch/_prims/context.py b/torch/_prims/context.py index e27dbba5ddcc..5b2bb9b21f22 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -185,7 +185,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode): if self.should_fallback_fn(self, func, args, kwargs): return orig_func(*args, **kwargs) # torch calls inside func should be interpreted as refs calls - with torch.overrides.enable_torch_function_mode(self, replace=self.inner): + with self: return func(*args, **kwargs) if self.strict: raise RuntimeError( @@ -202,9 +202,7 @@ def _is_node_supported_nvfuser(node): def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs): - with torch.overrides.enable_torch_function_mode( - torch_function_mode, replace=torch_function_mode.inner - ): + with torch_function_mode: gm = get_isolated_graphmodule(func, args, kwargs) supported_ops = NvfuserPrimOperatorSupport() diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 4391f5890f89..de6a851c6bde 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -16,7 +16,7 @@ from torch.fx.operator_schemas import normalize_function from torch.multiprocessing.reductions import StorageWeakRef from torch.overrides import TorchFunctionMode from torch.utils._mode_utils import no_dispatch -from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode +from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import PyTree, tree_flatten, tree_map @@ -564,7 +564,8 @@ class FakeTensor(torch.Tensor): else: assert fake_mode is arg.fake_mode, "Mixing modes NYI" - with enable_torch_dispatch_mode(fake_mode): + assert fake_mode is not None + with fake_mode: # type: ignore[attr-defined] return func(*args, **kwargs) @staticmethod @@ -623,7 +624,7 @@ class FakeTensor(torch.Tensor): # We keep one instantiation of `fake_tensor_converter` active -# for the duration of `with torch_enable_mode(FakeTensorMode)`. +# for the duration of `with FakeTensorMode()`. # This allows accurate storage aliasing across invocation of # different operators. 
While this will keep all freshly allocated # tensors alive during `FakeTensorMode`, there will no be no @@ -779,7 +780,7 @@ class FakeTensorMode(TorchDispatchMode): # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` return None - with self.restore(): + with self: if func in meta_table: r = meta_table[func](*args, **kwargs) return r @@ -801,7 +802,7 @@ class FakeTensorMode(TorchDispatchMode): and len(flat_arg_fake_tensors) != 0 and hasattr(func, "prim_meta_impl") ): - with self.restore(): + with self: return func.prim_meta_impl(*args, **kwargs) if has_symbolic_sizes: diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index edfbab91c700..5c16e4f58447 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -585,34 +585,6 @@ static PyObject* python_exit_dual_level( END_HANDLE_TH_ERRORS } -static PyObject* set_torch_dispatch_mode(PyObject* _unused, PyObject* arg) { - HANDLE_TH_ERRORS - if (arg == Py_None) { - c10::impl::TorchDispatchModeTLS::set_state(nullptr); - } else { - Py_INCREF(arg); - c10::impl::TorchDispatchModeTLS::set_state( - std::make_shared(arg, getPyInterpreter())); - } - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -static PyObject* get_torch_dispatch_mode( - PyObject* _unused, - PyObject* _unused2) { - HANDLE_TH_ERRORS - const auto& mode = c10::impl::TorchDispatchModeTLS::get_state(); - if (!mode) { - Py_RETURN_NONE; - } else { - auto* r = mode->ptr(getPyInterpreter()); - Py_INCREF(r); - return r; - } - END_HANDLE_TH_ERRORS -} - static PyObject* set_torch_function_mode(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (arg == Py_None) { @@ -623,24 +595,136 @@ static PyObject* set_torch_function_mode(PyObject* _unused, PyObject* arg) { std::make_shared(arg, getPyInterpreter())); } Py_RETURN_NONE; - END_HANDLE_TH_ERRORS + END_HANDLE_TH_ERRORS; } -static PyObject* get_torch_function_mode( +static PyObject* is_torch_function_mode_enabled( PyObject* _unused, PyObject* _unused2) { HANDLE_TH_ERRORS - const auto& mode = at::impl::PythonTorchFunctionTLS::get_mode(); - if (!mode) { - Py_RETURN_NONE; + if (at::impl::function_mode_enabled()) { + Py_RETURN_TRUE; } else { - auto* r = mode->ptr(getPyInterpreter()); - Py_INCREF(r); - return r; + Py_RETURN_FALSE; } END_HANDLE_TH_ERRORS } +static PyObject* push_on_torch_function_stack( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + if (arg != Py_None) { + Py_INCREF(arg); + at::impl::PythonTorchFunctionTLS::push_onto_stack( + std::make_shared(arg, getPyInterpreter())); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* pop_torch_function_stack( + PyObject* _unused, + PyObject* _unused2) { + HANDLE_TH_ERRORS + const auto& mode = at::impl::PythonTorchFunctionTLS::pop_stack(); + auto* r = mode->ptr(getPyInterpreter()); + Py_INCREF(r); + return r; + END_HANDLE_TH_ERRORS +} + +static PyObject* get_function_stack_at( + PyObject* _unused, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({"get_stack_at(int64_t level)"}); + + ParsedArgs<1> parsed_args; + auto _r = parser.parse(args, kwargs, parsed_args); + + auto idx = _r.toInt64(0); + const auto& mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); + auto* r = mode->ptr(getPyInterpreter()); + Py_INCREF(r); + return r; + END_HANDLE_TH_ERRORS +} + +static PyObject* len_torch_function_stack( + PyObject* _unused, + PyObject* _unused2) { + HANDLE_TH_ERRORS + const auto len = at::impl::PythonTorchFunctionTLS::stack_len(); + return 
utils::wrap(static_cast(len)); + END_HANDLE_TH_ERRORS +} + +static PyObject* set_torch_dispatch_mode(PyObject* _unused, PyObject* arg) { + HANDLE_TH_ERRORS + if (arg == Py_None) { + c10::impl::TorchDispatchModeTLS::set_mode(nullptr); + } else { + Py_INCREF(arg); + c10::impl::TorchDispatchModeTLS::set_mode( + std::make_shared(arg, getPyInterpreter())); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS; +} + +static PyObject* push_on_torch_dispatch_stack( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + if (arg != Py_None) { + Py_INCREF(arg); + c10::impl::TorchDispatchModeTLS::push_onto_stack( + std::make_shared(arg, getPyInterpreter())); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* pop_torch_dispatch_stack( + PyObject* _unused, + PyObject* _unused2) { + HANDLE_TH_ERRORS + const auto& mode = c10::impl::TorchDispatchModeTLS::pop_stack(); + auto* r = mode->ptr(getPyInterpreter()); + Py_INCREF(r); + return r; + END_HANDLE_TH_ERRORS +} + +static PyObject* get_dispatch_stack_at( + PyObject* _unused, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + static PythonArgParser parser({"get_stack_at(int64_t level)"}); + + ParsedArgs<1> parsed_args; + auto _r = parser.parse(args, kwargs, parsed_args); + + auto idx = _r.toInt64(0); + const auto& mode = c10::impl::TorchDispatchModeTLS::get_stack_at(idx); + auto* r = mode->ptr(getPyInterpreter()); + Py_INCREF(r); + return r; + END_HANDLE_TH_ERRORS +} + +static PyObject* len_torch_dispatch_stack( + PyObject* _unused, + PyObject* _unused2) { + HANDLE_TH_ERRORS + const auto len = c10::impl::TorchDispatchModeTLS::stack_len(); + return utils::wrap(static_cast(len)); + END_HANDLE_TH_ERRORS +} + // autograd methods on torch._C static PyMethodDef methods[] = { // NOLINT {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr}, @@ -685,10 +769,44 @@ static PyMethodDef methods[] = { // NOLINT castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr}, - {"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr}, - {"_get_torch_dispatch_mode", get_torch_dispatch_mode, METH_NOARGS, nullptr}, + {"_is_torch_function_mode_enabled", + is_torch_function_mode_enabled, + METH_NOARGS, + nullptr}, {"_set_torch_function_mode", set_torch_function_mode, METH_O, nullptr}, - {"_get_torch_function_mode", get_torch_function_mode, METH_NOARGS, nullptr}, + {"_push_on_torch_function_stack", + push_on_torch_function_stack, + METH_O, + nullptr}, + {"_pop_torch_function_stack", + pop_torch_function_stack, + METH_NOARGS, + nullptr}, + {"_get_function_stack_at", + castPyCFunctionWithKeywords(get_function_stack_at), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_len_torch_function_stack", + len_torch_function_stack, + METH_NOARGS, + nullptr}, + {"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr}, + {"_push_on_torch_dispatch_stack", + push_on_torch_dispatch_stack, + METH_O, + nullptr}, + {"_pop_torch_dispatch_stack", + pop_torch_dispatch_stack, + METH_NOARGS, + nullptr}, + {"_get_dispatch_stack_at", + castPyCFunctionWithKeywords(get_dispatch_stack_at), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_len_torch_dispatch_stack", + len_torch_dispatch_stack, + METH_NOARGS, + nullptr}, {nullptr, nullptr, 0, nullptr}}; PyMethodDef* python_functions() { diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 18da2de55cb5..744269936cdd 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -293,7 +293,7 @@ auto 
handle_torch_function_no_python_arg_parser( torch_function_name == TorchFunctionName::TorchFunction; auto get_mode = [&]() { return is_torch_function ? at::impl::PythonTorchFunctionTLS::get_mode() - : c10::impl::TorchDispatchModeTLS::get_state(); + : c10::impl::TorchDispatchModeTLS::get_mode(); }; const auto& maybe_mode = get_mode(); diff --git a/torch/csrc/utils/python_torch_function_mode.h b/torch/csrc/utils/python_torch_function_mode.h index 1ab703dc5032..5faf75778469 100644 --- a/torch/csrc/utils/python_torch_function_mode.h +++ b/torch/csrc/utils/python_torch_function_mode.h @@ -19,7 +19,7 @@ struct StashTorchFunctionModeGuard { } private: - std::shared_ptr old_mode_ = nullptr; + std::shared_ptr old_mode_; }; } // namespace overrides diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h index abc4c4e56fe8..81b219f71c09 100644 --- a/torch/csrc/utils/torch_dispatch_mode.h +++ b/torch/csrc/utils/torch_dispatch_mode.h @@ -8,16 +8,15 @@ namespace torch_dispatch_mode { struct StashTorchDispatchModeGuard { public: StashTorchDispatchModeGuard() { - saved_ = c10::impl::TorchDispatchModeTLS::get_state(); - c10::impl::TorchDispatchModeTLS::set_state(nullptr); + c10::impl::TorchDispatchModeTLS::swap_mode(saved_mode_); } ~StashTorchDispatchModeGuard() { - c10::impl::TorchDispatchModeTLS::set_state(saved_); + c10::impl::TorchDispatchModeTLS::set_mode(std::move(saved_mode_)); } private: - std::shared_ptr saved_; + std::shared_ptr saved_mode_; }; } // namespace torch_dispatch_mode diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 67ec854ce1a2..86999dc7e58e 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -19,7 +19,7 @@ from dataclasses import dataclass import weakref import operator -from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode +from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode from torch._subclasses import FakeTensor from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat from torch.fx import Proxy @@ -167,9 +167,9 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer): def maybe_disable_fake_tensor_mode(): # TODO: figure out if this API generally makes sense and bake it into the # library - mb_fake_mode = torch._C._get_torch_dispatch_mode() + mb_fake_mode = _get_current_dispatch_mode() if isinstance(mb_fake_mode, FakeTensorMode): - return enable_torch_dispatch_mode(mb_fake_mode.inner, replace=mb_fake_mode) + return _pop_mode_temporarily() else: return nullcontext() @@ -432,7 +432,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode): @contextmanager def restore(self): with self.sym_mode.enable(True): - with super().restore(): + with self: yield def inner_torch_dispatch(self, func, types, args=(), kwargs=None): @@ -637,12 +637,7 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"): def get_torch_dispatch_modes(): - modes = [torch._C._get_torch_dispatch_mode()] - if modes[-1] is None: - return list() - while modes[-1].inner is not None: - modes.append(modes[-1].inner) - return modes + return torch.utils._python_dispatch._get_current_dispatch_mode_stack() @contextlib.contextmanager diff --git a/torch/overrides.py b/torch/overrides.py index eca5249f3077..541d8e2c0750 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -32,9 +32,9 @@ import contextlib import torch from torch._C import ( _has_torch_function, 
_has_torch_function_unary, - _has_torch_function_variadic, _add_docstr, _set_torch_function_mode, _get_torch_function_mode) - -from torch.utils._mode_utils import _enable_mode, _ModeInfo, _restore_mode + _has_torch_function_variadic, _add_docstr, + _push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack, + _set_torch_function_mode, _is_torch_function_mode_enabled) __all__ = [ "get_ignored_functions", @@ -1510,8 +1510,8 @@ def handle_torch_function( types = tuple(map(type, overloaded_args)) # Check for __torch_function__ mode. - mode = _get_torch_function_mode() - if mode is not None: + if _is_torch_function_mode_enabled(): + mode = _get_current_function_mode() # NB: unlike on tensors, modes are instances with _no_torch_function_mode(): result = mode.__torch_function__(public_api, types, args, kwargs) @@ -1541,8 +1541,8 @@ def handle_torch_function( "no implementation found for '{}' on types that implement " '__torch_function__: {}' ).format(func_name, [type(arg) for arg in overloaded_args]) - if mode is not None: - msg += f" nor in mode {mode}" + if _is_torch_function_mode_enabled(): + msg += f" nor in mode {_get_current_function_mode()}" raise TypeError(msg) has_torch_function = _add_docstr( @@ -1766,53 +1766,7 @@ def is_tensor_like(inp): """ return type(inp) is torch.Tensor or hasattr(type(inp), "__torch_function__") - -def _wrap_torch_function(f): - @functools.wraps(f) - def wrapped(self, *args, **kwargs): - if isinstance(f, classmethod): - raise RuntimeError("TorchFunctionMode's torch_function function " + - "should be a normal method not a class method") - inner = getattr(self, "inner", None) - - with enable_torch_function_mode(inner): - return f(self, *args, **kwargs) - return wrapped - - -# Implementation note: I had a choice about how much of mode stacks -# to implement in Python versus in C++. At time of writing, I did not care -# too much about implementation efficiency; however, I do care about making it -# hard for users to implement modes in the wrong way. In the end, it turned -# out to be possible to implement mode stacks entirely from userland, with the -# C++ API providing only _get_torch_function_mode() and -# _set_torch_function_mode(), so I opted to provide some unsafe C++ bindings and -# have the bulk of the logic for managing the stack in Python, which helped -# simplify the C++ API surface. It would also have been valid to build in the -# notion of mode stack directly into C++ but in this design it's substantially -# more difficult to interact with TorchFunctionModeMeta. -class TorchFunctionModeMeta(type): - """ - Metaclass for :class:`TorchFunctionMode`; it does two things: - - * Adds an implicit ``inner`` kwarg to ``__init__``, to - allow the modes to be chained together to form a stack. - - * Reenables the inner mode, so that by default PyTorch API calls - will compositionally proceed to the next mode on the stack. - - The default behavior for the second bullet is important, as it is easy to - accidentally write ``__torch_function__`` implementations that are not - compositional, and the wrapping here makes the obvious code do the - right thing (aka, this is why there is a metaclass). 
- """ - def __new__(metacls, name, bases, dct): - if '__torch_function__' in dct: - dct['__torch_function__'] = _wrap_torch_function(dct['__torch_function__']) - return super().__new__(metacls, name, bases, dct) - - -class TorchFunctionMode(metaclass=TorchFunctionModeMeta): +class TorchFunctionMode: """ A ``TorchFunctionMode`` allows you to override the meaning of all ``__torch_function__`` overrideable functions within a dynamic scope, @@ -1852,24 +1806,11 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta): raise NotImplementedError() def __enter__(self): - old = _get_torch_function_mode() - if hasattr(self, "inner"): - raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version or use restore") - else: - self.inner = old - if old is None: - self.ancestors = set() - else: - self.ancestors = self.inner.ancestors.union({self.inner}) - _set_torch_function_mode(self) + _push_mode(self) return self def __exit__(self, exc_type, exc_val, exc_tb): - _set_torch_function_mode(self.inner) - - @contextlib.contextmanager - def restore(self): - return _restore_mode(self, mode_info=_TorchFunctionModeInfo()) + _pop_mode() @classmethod def push(cls, *args, **kwargs): @@ -1877,6 +1818,50 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta): instance = cls(*args, **kwargs) return instance + +def _get_current_function_mode(): + stack_len = _len_torch_function_stack() + return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None + + +def _get_current_function_mode_stack(): + stack_len = _len_torch_function_stack() + return [_get_function_stack_at(i) for i in range(stack_len)] + +def _push_mode(mode): + if _len_torch_function_stack() == 0: + _set_torch_function_mode(_TorchFunctionStackMode()) + _push_on_torch_function_stack(mode) + + +def _pop_mode(): + old = _pop_torch_function_stack() + if _len_torch_function_stack() == 0: + _set_torch_function_mode(None) + return old + + +@contextlib.contextmanager +def _pop_mode_temporarily(): + old = _pop_mode() + try: + yield old + finally: + _push_mode(old) + +# a helper "mode" used by the torch_function push helper method. 
This is the only mode that will ever +# be active at the C++ level and it will run the current mode +class _TorchFunctionStackMode: + def __torch_function__(self, func, types, args=(), kwargs=None): + with _pop_mode_temporarily() as old: + if _len_torch_function_stack() > 0: + _set_torch_function_mode(self) + # we can't check the type of __torch_function__ here but this is sufficient for checking it's a classmethod + if old.__torch_function__.__self__ is type(old): + raise RuntimeError("TorchFunctionMode's torch_function function " + + "should be a normal method not a class method") + return old.__torch_function__(func, types, args, kwargs) + class BaseTorchFunctionMode(TorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: @@ -1889,57 +1874,14 @@ class BaseTorchFunctionMode(TorchFunctionMode): # library code though, e.g., in handle_torch_function @contextlib.contextmanager def _no_torch_function_mode() -> Iterator[None]: - old = _get_torch_function_mode() _set_torch_function_mode(None) try: yield finally: - _set_torch_function_mode(old) + if _len_torch_function_stack() > 0: + _set_torch_function_mode(_TorchFunctionStackMode()) -class _TorchFunctionModeInfo(_ModeInfo): - def __init__(self): - super().__init__(mode_name="torch_function", mode_class=TorchFunctionMode) - - def get_mode(self): - return _get_torch_function_mode() - - def set_mode(self, mode): - return _set_torch_function_mode(mode) - - -@contextlib.contextmanager -def enable_torch_function_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]: - """ - Context manager that sets the current :class:`TorchFunctionMode`; see the - class for more information on what modes are. This function is - non-compositional; if there is already an existing mode, it will raise an - error; prefer using ``with MyMode():`` if your ``__torch_function__`` - implementation can defer to an inner mode. - - This function is safe to use inside a ``__torch_function__`` mode handler, - as the mode is guaranteed to be disabled in this context. You can use - this context manager to reinstate the mode so that calls to overridable - APIs recursively call back into your mode handler (this can easily cause - infinite loops, so use with care!) - - Args: - mode (:class:`TorchFunctionMode`, Tensor-like class or None): the - mode to set as current mode. If you pass a Tensor-like class, - it will be treated as a non-compositional mode with no state, - which is convenient if you have an existing tensor subclass - that you'd like to apply globally in a quick and dirty way. - Passing None will disable the current mode. - replace (:class:`TorchFunctionMode` or Tensor-like class): the - mode to replace. You can use this argument to change the mode in - a situation where you know what the current mode is (and you are - intentionally overwriting it.) If you don't know what the current - mode is, use ``ignore_preexisting`` instead. - ignore_preexisting (bool): if True, ignore any preexisting mode - and overwrite it with the passed mode. 
- """ - return _enable_mode(mode, _TorchFunctionModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting) - class enable_reentrant_dispatch(): def __enter__(self): self._raii_guard = torch._C._RestorePythonTLSSnapshot() diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index dadc9675fe8b..0eaab2e1796d 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -1,10 +1,11 @@ import torch from torch import Tensor import itertools + +from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten from functools import partial -from torch.utils._mode_utils import no_dispatch -from torch.utils._python_dispatch import enable_torch_dispatch_mode +from torch.utils._mode_utils import no_dispatch, all_same_mode import torch.autograd.forward_ad as fwAD from typing import Callable import re @@ -99,7 +100,7 @@ def is_inplace(func): return name[-1] == '_' -def generate_cct(autograd_view_consistency=True): +def generate_cct_and_mode(autograd_view_consistency=True): # This function returns a new class CompositeCompliantTensor # The two arguments control the behaviour described below. @@ -116,7 +117,7 @@ def generate_cct(autograd_view_consistency=True): __torch_function__ = torch._C._disabled_torch_function_impl @staticmethod - def __new__(cls, elem, *args, **kwargs): + def __new__(cls, elem, mode, *args, **kwargs): assert type(elem) is not cls, \ "Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported" @@ -148,6 +149,8 @@ def generate_cct(autograd_view_consistency=True): # Ref: https://github.com/albanD/subclass_zoo/issues/21 torch._C._set_conj(r, r.elem.is_conj()) torch._C._set_neg(r, r.elem.is_neg()) + + r.mode = mode return r def __repr__(self): @@ -155,11 +158,20 @@ def generate_cct(autograd_view_consistency=True): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + all_args = tree_flatten(args)[0] + tree_flatten(kwargs)[0] + modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor)) + if not all_same_mode(modes): + raise RuntimeError("Multiple CompositeCompliantTensorModes NYI") + with modes[0]: + return func(*args, **kwargs) + + class CompositeCompliantTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): def unwrap(e): return e.elem if isinstance(e, CompositeCompliantTensor) else e def wrap(e): - return CompositeCompliantTensor(e) if isinstance(e, torch.Tensor) else e + return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e if func == torch.ops.aten._local_scalar_dense.default: raise RuntimeError( @@ -224,13 +236,13 @@ def generate_cct(autograd_view_consistency=True): # For each CompositeCompliantTensor t, we check that t and t.elem # have consistent metadata. If they don't have consistent metadata, # that means the operator did something fishy. 
- check = partial(check_metadata_consistency, CCT=cls) + check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor) tree_map(check, args) tree_map(check, kwargs) tree_map(check, rs) return rs - return CompositeCompliantTensor + return CompositeCompliantTensor, CompositeCompliantTensorMode() def is_tensorlist(lst): if not isinstance(lst, list) and not isinstance(lst, tuple): @@ -251,12 +263,12 @@ def maybe_map(fn, should_map, arg): return fn(arg) if should_map else arg -def wrap(arg, CCT): - # CCT: CompositeCompliantTensor class which is generated using generate_cct +def wrap(arg, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode if isinstance(arg, torch.Tensor): - return CCT(arg) + return CCT(arg, cct_mode) if is_tensorlist(arg): - return [CCT(a) for a in arg] + return [CCT(a, cct_mode) for a in arg] raise RuntimeError("wrap assumes that the input can be wrapped") @@ -270,14 +282,14 @@ def wrap(arg, CCT): # [A, 1, B] # NB: Yes, this is exponential. No, we don't care too much because PyTorch ops # don't accept that many input Tensors. -def generate_subclass_choices(flat_args, CCT): - # CCT: CompositeCompliantTensor class which is generated using generate_cct +def generate_subclass_choices(flat_args, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args] subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes] for which_args_are_wrapped in itertools.product(*subclass_options): - result = [maybe_map(partial(wrap, CCT=CCT), should_wrap_arg, arg) + result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg) for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)] yield result, which_args_are_wrapped @@ -285,11 +297,11 @@ def generate_subclass_choices(flat_args, CCT): # For an operation f(*args, **kwargs), each Tensor argument may either be # a regular Tensor or a Tensor Subclass. This iterator iterates through # all of those options. -def generate_subclass_choices_args_kwargs(args, kwargs, CCT): - # CCT: CompositeCompliantTensor class which is generated using generate_cct +def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): + # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode flat_kwargs, spec = tree_flatten(kwargs) flat_args_kwargs = list(args) + list(flat_kwargs) - for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT): + for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode): new_args = choice[:len(args)] new_kwargs = tree_unflatten(choice[len(args):], spec) which_args_are_wrapped = debug_metadata[:len(args)] @@ -320,9 +332,9 @@ def raise_composite_compliance_error(err, additional_info=''): # If some composite operation does any non-compliant behavior, # CompositeCompliantTensor will raise an error. 
def check_all_permutations(op, args, kwargs, assert_equal_fn): - CCT = generate_cct() + CCT, cct_mode = generate_cct_and_mode() expected = op(*args, **kwargs) - for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT): + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice try: @@ -365,17 +377,17 @@ def check_all_permutations(op, args, kwargs, assert_equal_fn): # Composite does any non-compliant behavior, # CompositeCompliantTensor will raise an error. def check_with_mode(op, args, kwargs, assert_equal_fn): - CCT = generate_cct() + CCT, cct_mode = generate_cct_and_mode() def wrap(e): - return CCT(e) if isinstance(e, torch.Tensor) else e + return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e expected = op(*args, **kwargs) args = tree_map(wrap, args) kwargs = tree_map(wrap, kwargs) try: - with enable_torch_dispatch_mode(CCT): + with cct_mode: actual = op(*args, **kwargs) # see NOTE: [What errors are Composite Compliance trying to catch?] except RuntimeError as err: @@ -429,11 +441,11 @@ def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradch def check_backward_formula(op: Callable, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None, assert_equal_fn=None): - CCT = generate_cct() + CCT, cct_mode = generate_cct_and_mode() expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper) - for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT): + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice leaf_tensors = gather_leaf_tensors(new_args, new_kwargs) assert len(leaf_tensors) > 0 @@ -460,7 +472,7 @@ def check_backward_formula(op: Callable, args, kwargs, # NB: ones, not ones_like, so we get a regular Tensor here grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results] - for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT): + for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode): try: actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads, allow_unused=True, retain_graph=True) @@ -486,7 +498,7 @@ def check_backward_formula(op: Callable, args, kwargs, # this means we can apply check_forward_ad_formula to things that aren't OpInfos # while debugging. def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None): - CCT = generate_cct(autograd_view_consistency=False) + CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False) def maybe_tangent(t): assert type(t) is not CCT @@ -529,11 +541,11 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, expected_tangents = tree_map(lambda x: x.tangent, expected) # Permutations of arg and kwargs in CCT. - for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT): + for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice # Permutations tangent arg and tangent kwargs in CCT. 
- for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT): + for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode): new_tang_args, new_tang_kwargs, \ which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py index 1520cf622c7e..f9098c6d7ef4 100644 --- a/torch/utils/_mode_utils.py +++ b/torch/utils/_mode_utils.py @@ -1,125 +1,13 @@ import torch -from typing import Iterator, TypeVar -from dataclasses import dataclass +from typing import TypeVar from contextlib import contextmanager T = TypeVar('T') -# This file has all the logic to dedupe logic between torch dispatch and -# torch function modes -# -# Specifically, it has the helper functions for enable_ and push_X_mode and the -# ModeInfo class, which is extended by each where they are different - - -# in order to dedupe the logic between TorchDispatchMode and TorchFunctionMode, this -# is a container to hold all the differences between the modes. Then functions like -# _enable_mode are able to use this container to call functions or get correctly -# formatted names -@dataclass -class _ModeInfo: - mode_name: str - mode_class: type # the class related to the mode that's allowed to be passed in - - def mode_class_name(self): - return self.mode_class.__name__ - - def get_mode(self): - """gets the current mode for this type of mode""" - raise NotImplementedError() - - def set_mode(self, mode): - """ - set mode to for this type of mode. Note that no checks are done on this, it's the unsafe - version where checks are assumed to have been already done by the helper function - """ - raise NotImplementedError() - - -# shared version of enable_torch_function/enable_torch_dispatch_mode in order to deduplicate the code. -# The differences between the modes are captured by `mode_info` and then queried when they're -# needed during the function's invocation -def _enable_mode(mode: T, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[T]: - if not ( - mode is None or - isinstance(mode, mode_info.mode_class) or - (isinstance(mode, type) and not issubclass(mode, mode_info.mode_class)) - ): - raise ValueError(f'expected to get {mode_info.mode_class_name()}, Tensor-like class, ' - f'or None as an argument got {type(mode)} instead') - old = mode_info.get_mode() - if old is mode: - yield mode # type: ignore[misc] - return - if old is not None and not ignore_preexisting and old is not replace: - if isinstance(mode, mode_info.mode_class): - help_text = 'Use `with Mode():` instead.' - else: - help_text = ( - 'If you intended to completely override the preexisting mode, ' - 'pass ignore_preexisting=True. This can result in unexpected ' - 'behavior; please consider rewriting your mode to be a subclass ' - f'of {mode_info.mode_class_name()} to make it compositional!' - ) - raise ValueError( - f'Attempted to enable_{mode_info.mode_name}_mode, but there is already an ' - f'active mode {old}. {help_text}' - ) - # NB: we don't require TorchFunctionMode/TorchDispatchMode since this is intended to also - # let you directly pass a Tensor subclass type to "mode-ify" it. 
- if mode is not None: - required_fn = "__" + mode_info.mode_name + "__" - if not hasattr(mode, required_fn): - raise ValueError( - f'The argument passed to enable_{mode_info.mode_name}_mode must implement {required_fn}' - ) - mode_info.set_mode(mode) - try: - yield mode # type: ignore[misc] - finally: - mode_info.set_mode(old) - - -def _restore_mode(mode, mode_info: _ModeInfo): - if not hasattr(mode, "ancestors"): - raise RuntimeError(f"{mode} does not have any ancestors. Use the standard version instead of restore") - old = mode_info.get_mode() - if old is not None and old not in mode.ancestors: - raise RuntimeError(f"{mode} is not valid in the current state because the current mode is not its ancestor") - mode_info.set_mode(mode) - try: - yield mode - finally: - mode_info.set_mode(old) - - -# To help with non-lexical scoping, it will error if all the modes are from different scopes or haven't been used -def find_outermost_mode(modes): - outermost = None - for mode in modes: - if mode is not None: - if not hasattr(mode, "ancestors"): - raise RuntimeError(f"{mode}, doesn't have ancestors set so the ordering with other modes is unclear") - if outermost is None: - outermost = mode - elif mode not in outermost.ancestors and outermost not in mode.ancestors: - raise RuntimeError(f"modes {mode} and {outermost} are not compatible because they " - "don't come from the same scope") - elif outermost in mode.ancestors: - outermost = mode - return outermost - - # returns if all are the same mode def all_same_mode(modes): return all(tuple(mode == modes[0] for mode in modes)) -# returns if all modes are from the current scope, ``cur_mode`` -def all_same_mode_scope(modes, cur_mode): - if not hasattr(cur_mode, "ancestors"): - return False - return all(tuple(mode == cur_mode or mode in cur_mode.ancestors for mode in modes)) - @contextmanager def no_dispatch(): guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index a13dfa16ae0e..08ef67d7333f 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -1,115 +1,17 @@ import contextlib -from typing import Iterator, Set -import functools import warnings -from torch.utils._mode_utils import _enable_mode, _ModeInfo, _restore_mode -from torch._C import _get_torch_dispatch_mode, _set_torch_dispatch_mode -from dataclasses import dataclass - - -@dataclass -class TorchDispatchModeInfo(_ModeInfo): - def __init__(self): - super().__init__(mode_name="torch_dispatch", mode_class=TorchDispatchMode) - - def get_mode(self): - return _get_torch_dispatch_mode() - - def set_mode(self, mode): - return _set_torch_dispatch_mode(mode) +from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\ + _pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, _set_torch_dispatch_mode # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it: -# - We need a better user-facing api for torch._C._DisableTorchDispatch that +# - We need a better user-facing api for _DisableTorchDispatch that # is able to selectively disable __torch_dispatch__ of a particular class. 
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor) # - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694) -@contextlib.contextmanager -def enable_torch_dispatch_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]: - """ - Context manager that causes all pytorch operators to dispatch to the passed-in - type's __torch_dispatch__ function, including operations that accept no tensors - but return a tensor. - This function is non-compositional; if there is already an existing mode, - it will raise an error - - This function is safe to use inside a ``__torch_dispatch__`` mode handler, - as the mode is guaranteed to be disabled in this context. You can use - this context manager to reinstate the mode so that calls to overridable - APIs recursively call back into your mode handler (this can easily cause - infinite loops, so use with care!) - - enable_torch_dispatch_mode is affected by _DisableTorchDispatch. - - Args: - mode (:class:`TorchDispatchMode`, Tensor-like class, or None): the - mode to set as current mode. If you pass a Tensor-like class, - it will be treated as a non-compositional mode with no state, - which is convenient if you have an existing tensor subclass - that you'd like to apply globally in a quick and dirty way. - Passing None will disable the current mode. - replace (:class:`TorchDispatchMode` or Tensor-like class): the - mode to replace. You can use this argument to change the mode in - a situation where you know what the current mode is (and you are - intentionally overwriting it.) If you don't know what the current - mode is, use ``ignore_preexisting`` instead. - ignore_preexisting (bool): if True, ignore any preexisting mode - and overwrite it with the passed mode. - """ - - return _enable_mode(mode, mode_info=TorchDispatchModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting) - - -def _wrap_torch_dispatch(f): - @functools.wraps(f) - def wrapped(self, *args, **kwargs): - if isinstance(f, classmethod): - raise RuntimeError("TorchDispatchMode's torch_dispatch function " + - "should be a normal method not a class method") - inner = getattr(self, "inner", None) - - with enable_torch_dispatch_mode(inner): - return f(self, *args, **kwargs) - return wrapped - - -# Implementation note, since this is based on TorchFunctionMode, this had the -# same dilemma: I had a choice about how much of mode stacks -# to implement in Python versus in C++. At time of writing, I did not care -# too much about implementation efficiency; however, I do care about making it -# hard for users to implement modes in the wrong way. In the end, it turned -# out to be possible to implement mode stacks entirely from userland, with the -# C++ API providing only _get_torch_dispatch_mode() and -# _set_torch_dispatch_mode(), so I opted to provide some unsafe C++ bindings and -# have the bulk of the logic for managing the stack in Python, which helped -# simplify the C++ API surface. It would also have been valid to build in the -# notion of mode stack directly into C++ but in this design it's substantially -# more difficult to interact with TorchDispatchModeMeta. - -class TorchDispatchModeMeta(type): - """ - Metaclass for :class:`TorchDispatchMode`; it does two things: - - * Adds an implicit ``inner`` kwarg to ``__init__``, to - allow the modes to be chained together to form a stack. - - * Reenables the inner mode, so that by default PyTorch API calls - will compositionally proceed to the next mode on the stack. 
- - The default behavior for the second bullet is important, as it is easy to - accidentally write ``_wrap_torch_dispatch`` implementations that are not - compositional, and the wrapping here makes the obvious code do the - right thing (aka, this is why there is a metaclass). - """ - def __new__(metacls, name, bases, dct): - if '__torch_dispatch__' in dct: - dct['__torch_dispatch__'] = _wrap_torch_dispatch(dct['__torch_dispatch__']) - return super().__new__(metacls, name, bases, dct) - - -class TorchDispatchMode(metaclass=TorchDispatchModeMeta): +class TorchDispatchMode: """ A ``TorchDispatchMode`` allows you to override the meaning of all ``__torch_dispatch__`` overrideable functions within a dynamic scope, @@ -139,32 +41,15 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta): ``__torch_dispatch__(self)`` to make PyTorch API self-referential (beware of infinite loops, in this case!) """ - # Force metaclass to generate constructor at the base of the hierarchy - def __init__(self): - self.ancestors: Set[TorchDispatchMode] - def __torch_dispatch__(self, func, types, args=(), kwargs=None): raise NotImplementedError() def __enter__(self): - old = _get_torch_dispatch_mode() - if hasattr(self, "inner"): - raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version or use restore") - else: - self.inner = old - if old is None: - self.ancestors = set() - else: - self.ancestors = self.inner.ancestors.union({self.inner}) - _set_torch_dispatch_mode(self) + _push_mode(self) return self def __exit__(self, exc_type, exc_val, exc_tb): - _set_torch_dispatch_mode(self.inner) - - @contextlib.contextmanager - def restore(self): - return _restore_mode(self, mode_info=TorchDispatchModeInfo()) + _pop_mode() @classmethod def push(cls, *args, **kwargs): @@ -172,6 +57,48 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta): instance = cls(*args, **kwargs) return instance +def _get_current_dispatch_mode(): + stack_len = _len_torch_dispatch_stack() + return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None + + +def _get_current_dispatch_mode_stack(): + stack_len = _len_torch_dispatch_stack() + return [_get_dispatch_stack_at(i) for i in range(stack_len)] + +def _push_mode(mode): + if _len_torch_dispatch_stack() == 0: + _set_torch_dispatch_mode(_TorchDispatchStackMode()) + _push_on_torch_dispatch_stack(mode) + + +def _pop_mode(): + old = _pop_torch_dispatch_stack() + if _len_torch_dispatch_stack() == 0: + _set_torch_dispatch_mode(None) + return old + + +@contextlib.contextmanager +def _pop_mode_temporarily(): + old = _pop_mode() + try: + yield old + finally: + _push_mode(old) + +# a helper "mode" used by the torch dispatch push helper method. This is the only mode that will ever +# be active at the C++ level and it will run the current mode +class _TorchDispatchStackMode: + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + with _pop_mode_temporarily() as old: + if _len_torch_dispatch_stack() > 0: + _set_torch_dispatch_mode(self) + # we can't check the type of __torch_dispatch__ here but this is sufficient for checking it's a classmethod + if old.__torch_dispatch__.__self__ is type(old): + raise RuntimeError(f"{type(old)}'s torch_dispatch function " + + "should be a normal method not a class method") + return old.__torch_dispatch__(func, types, args, kwargs) class BaseTorchDispatchMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None):
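[Editor's note, illustrative only -- not part of the patch] With the metaclass and `inner`-chaining machinery removed, nesting is handled by the single _TorchDispatchStackMode trampoline above: an op first reaches the innermost mode, and calling func() from inside __torch_dispatch__ re-enters the trampoline, which hands the call to the next mode down the stack and finally to the real kernel once the stack is exhausted. A small sketch of that composition follows; Tag is a hypothetical mode used only for illustration, and the expected log order is inferred from the trampoline logic rather than quoted from the patch.

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class Tag(TorchDispatchMode):
    # hypothetical mode: records its name, then redispatches to the next mode
    def __init__(self, name, log):
        self.name = name
        self.log = log

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.log.append(self.name)
        return func(*args, **(kwargs or {}))

x = torch.ones(2)   # created outside the modes so only the add below is traced
log = []
with Tag("outer", log), Tag("inner", log):
    x + x
# Expected: log == ["inner", "outer"] -- the innermost mode sees the add first,
# and its func() call falls through to the outer mode, then to the real kernel.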