[Modes] remove enable and rewrite mode stack (squashed) (#84774)

Based on @ezyang's suggestion, the mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch just takes the top mode in the stack, re-enables itself (if we aren't at the end of the mode stack), and runs the top mode's torch_{dispatch|function}.

This maintains the invariant that, in the middle of a mode's torch dispatch, the mode itself is not active. It changes the function the user has to call to see what the current mode is (it no longer queries the C++ state; it's Python only), but it also lets the user easily see the entire mode stack.
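
After this PR the current mode and the full stack are queried from Python; the diff adds private, underscore-prefixed helpers in `torch.utils._python_dispatch` (and analogous ones in `torch.overrides` for torch function modes). A minimal sketch mirroring the new tests, with `Noop` as a made-up example mode:

```python
from torch.utils._python_dispatch import (
    TorchDispatchMode,
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
)

class Noop(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

with Noop() as m1:
    assert _get_current_dispatch_mode() is m1
    with Noop() as m2:
        # outermost mode first, innermost (current) mode last
        assert _get_current_dispatch_mode_stack() == [m1, m2]
```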

Removes `enable_torch_dispatch_mode` and `.restore()`, since neither makes sense in this new setup.
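
For existing call sites the migration is mechanical, as the test changes below show; here is a small illustrative example using `FakeTensorMode` (a sketch, not text from this PR):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Before this PR:
#   with enable_torch_dispatch_mode(FakeTensorMode()):
#       x = torch.rand([4])

# After this PR, the mode instance itself is the context manager:
with FakeTensorMode():
    x = torch.rand([4])  # x is a FakeTensor while the mode is active
```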

### Background
Why do we want this? A pretty common pattern was that users had to do something like

```python
## PRE-PR UX
def f(mode):
  with mode.restore():  # user needs to understand this restore thing?
    ...

with Mode() as m:
  pass
f(m)
```

Many users were getting errors from forgetting to call `.restore` or from forgetting the (admittedly weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
  with mode:
    ...
f(Mode())
```

### Technical Details
With the old mode stack, we basically had a linked list, so a mode could only be used once and had a fixed parent. In the new design, the mode stack is just a Python list that we push onto and pop from. Only one mode is ever active at the C++ level, and it runs the next mode in the Python list. The modes no longer carry any state themselves.
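
Conceptually, the single C++-level mode behaves something like this toy model (a self-contained sketch with made-up names `ToyMode`, `dispatch`, and `_mode_stack`; not the actual implementation):

```python
_mode_stack = []  # the Python-level mode stack; plain list, newest mode last

class ToyMode:
    """Stand-in for a user mode: no per-instance stack state, reusable."""
    def __enter__(self):
        _mode_stack.append(self)            # push on enter
        return self

    def __exit__(self, *exc):
        assert _mode_stack.pop() is self    # pop on exit

    def handler(self, func, args, kwargs):
        return func(*args, **kwargs)

def dispatch(func, *args, **kwargs):
    # What the single C++-level mode conceptually does: take the top mode off
    # the stack (so it is *not* active inside its own handler), run it, and
    # put it back afterwards. If the stack is empty, just call the op.
    if not _mode_stack:
        return func(*args, **kwargs)
    top = _mode_stack.pop()
    try:
        return top.handler(func, args, kwargs)
    finally:
        _mode_stack.append(top)

# Usage: with ToyMode(): dispatch(print, "hello")  # runs under the mode
```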
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
samdow
2022-09-26 16:42:07 -04:00
committed by PyTorch MergeBot
parent a0be0ca161
commit 18d8c548f4
28 changed files with 666 additions and 999 deletions

View File

@ -18,6 +18,26 @@ void PythonTorchFunctionTLS::swap_mode(std::shared_ptr<c10::SafePyObject>& mode)
pythonTorchFunctionState.mode_.swap(mode);
}
void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
pythonTorchFunctionState.stack_.push_back(std::move(mode));
}
const std::shared_ptr<SafePyObject> PythonTorchFunctionTLS::pop_stack() {
TORCH_CHECK(pythonTorchFunctionState.stack_.size() > 0, "trying to pop from empty mode stack");
const auto out = pythonTorchFunctionState.stack_.back();
pythonTorchFunctionState.stack_.pop_back();
return out;
}
const std::shared_ptr<SafePyObject>& PythonTorchFunctionTLS::get_stack_at(int64_t idx) {
TORCH_CHECK(idx < static_cast<int64_t>(pythonTorchFunctionState.stack_.size()), "Tried to get stack at idx that's too big");
return pythonTorchFunctionState.stack_[idx];
}
int64_t PythonTorchFunctionTLS::stack_len() {
return pythonTorchFunctionState.stack_.size();
}
void PythonTorchFunctionTLS::set_disabled(bool disabled) {
pythonTorchFunctionState.disabled_ = disabled;
}
@ -34,5 +54,9 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
return pythonTorchFunctionState;
}
bool function_mode_enabled() {
return static_cast<bool>(PythonTorchFunctionTLS::get_mode());
}
} // namespace impl
} // namespace at

View File

@ -14,13 +14,28 @@ struct TORCH_API PythonTorchFunctionTLS {
static const std::shared_ptr<c10::SafePyObject>& get_mode();
static void swap_mode(std::shared_ptr<c10::SafePyObject>&);
static void set_state(const PythonTorchFunctionTLS& state);
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject> pop_stack();
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
static int64_t stack_len();
static const PythonTorchFunctionTLS& get_state();
static void set_state(const PythonTorchFunctionTLS& state);
private:
// The mode TLS is split into
// - disabled_, which says whether or not to disable all torch function
// modes
// - mode_, which is the C++ mode, that can only be the mode handling mode
// or null
// - stack_, which is a vector of modes representing the stack of user
// defined modes
bool disabled_;
std::shared_ptr<c10::SafePyObject> mode_;
std::shared_ptr<c10::SafePyObject> mode_ = nullptr;
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};
TORCH_API bool function_mode_enabled();
} // namespace impl
} // namespace at

View File

@ -56,7 +56,7 @@ class TORCH_API ThreadLocalState {
AutogradState autograd_tls_;
// TLS for enable_torch_dispatch_mode
std::shared_ptr<SafePyObject> torch_dispatch_mode_state_;
c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
// TLS for enable_python_dispatcher
c10::impl::PyInterpreter* python_dispatcher_state_;

View File

@ -52,7 +52,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
// If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
const auto& maybe_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_state();
const auto& maybe_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_mode();
if (maybe_torch_dispatch_mode_state) {
maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
return;

View File

@ -547,7 +547,7 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach_core(
bool allow_tensor_metadata_change) const {
c10::intrusive_ptr<TensorImpl> r;
const auto& maybe_torch_dispatch_mode_state =
c10::impl::TorchDispatchModeTLS::get_state();
c10::impl::TorchDispatchModeTLS::get_mode();
// TODO: do we have to exclude after Python dispatch key set?
if (maybe_torch_dispatch_mode_state &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {

View File

@ -6,32 +6,85 @@
namespace c10 {
namespace impl {
thread_local std::shared_ptr<SafePyObject> torchDispatchModeState;
thread_local TorchDispatchModeTLS torchDispatchModeState;
void TorchDispatchModeTLS::set_state(std::shared_ptr<SafePyObject> state) {
if (state) {
// MODE
void TorchDispatchModeTLS::set_mode(std::shared_ptr<SafePyObject> mode) {
if (mode) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
} else {
TorchDispatchModeTLS::reset_state();
TorchDispatchModeTLS::reset_mode();
}
torchDispatchModeState = std::move(state);
torchDispatchModeState.mode_ = std::move(mode);
}
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_state() {
return torchDispatchModeState;
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_mode() {
return torchDispatchModeState.mode_;
}
void TorchDispatchModeTLS::reset_state() {
torchDispatchModeState.reset();
void TorchDispatchModeTLS::reset_mode() {
torchDispatchModeState.mode_.reset();
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
void TorchDispatchModeTLS::swap_mode(std::shared_ptr<SafePyObject>& mode) {
if (mode) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, true);
} else {
c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
c10::impl::tls_set_dispatch_key_included(
DispatchKey::PythonTLSSnapshot, false);
}
torchDispatchModeState.mode_.swap(mode);
}
// STACK
void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
torchDispatchModeState.stack_.push_back(std::move(mode));
}
const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
TORCH_CHECK(
torchDispatchModeState.stack_.size() > 0,
"trying to pop from empty mode stack");
const std::shared_ptr<SafePyObject> out =
torchDispatchModeState.stack_.back();
torchDispatchModeState.stack_.pop_back();
return out;
}
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
int64_t idx) {
TORCH_CHECK(
idx < static_cast<int64_t>(torchDispatchModeState.stack_.size()),
"Tried to get stack at idx that's too big");
return torchDispatchModeState.stack_[idx];
}
int64_t TorchDispatchModeTLS::stack_len() {
return torchDispatchModeState.stack_.size();
}
// STATE
const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
return torchDispatchModeState;
}
void TorchDispatchModeTLS::set_state(const TorchDispatchModeTLS& state) {
torchDispatchModeState = state;
}
// UTIL
bool dispatch_mode_enabled() {
return static_cast<bool>(c10::impl::TorchDispatchModeTLS::get_state());
return static_cast<bool>(c10::impl::TorchDispatchModeTLS::get_mode());
}
} // namespace impl

View File

@ -9,9 +9,27 @@ namespace c10 {
namespace impl {
struct C10_API TorchDispatchModeTLS {
static void set_state(std::shared_ptr<SafePyObject> state);
static const std::shared_ptr<SafePyObject>& get_state();
static void reset_state();
static void set_mode(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject>& get_mode();
static void reset_mode();
static void swap_mode(std::shared_ptr<SafePyObject>& mode);
static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
static const std::shared_ptr<SafePyObject> pop_stack();
static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
static int64_t stack_len();
static const TorchDispatchModeTLS& get_state();
static void set_state(const TorchDispatchModeTLS& state);
private:
// The mode TLS is split into
// - mode_, which is the C++ mode, that can only be the mode handling mode
// or null
// - stack_, which is a vector of modes representing the stack of user
// defined modes
std::shared_ptr<c10::SafePyObject> mode_;
std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};
C10_API bool dispatch_mode_enabled();

View File

@ -4,7 +4,7 @@ from torch._ops import PyOperator
from torch.utils._pytree import tree_flatten
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode
from torch.fx.experimental.proxy_tensor import track_tensor_tree
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
@ -29,15 +29,6 @@ def suspend_mode(mode):
finally:
torch._C._set_torch_dispatch_mode(mode)
@contextmanager
def enable_mode(mode):
curr_mode = torch._C._get_torch_dispatch_mode()
torch._C._set_torch_dispatch_mode(mode)
try:
yield
finally:
torch._C._set_torch_dispatch_mode(curr_mode)
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
def _unwrap_proxy(e):
@ -104,7 +95,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
mode = torch._C._get_torch_dispatch_mode()
mode = _get_current_dispatch_mode()
assert (mode is None), "Mode should never be enabled for CPU key"
if pred:
return true_fn(*operands)
@ -125,7 +116,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
mode = torch._C._get_torch_dispatch_mode()
mode = _get_current_dispatch_mode()
assert (mode is not None), "Mode should always be enabled for python fallback key"
with suspend_mode(mode):
res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)

View File

@ -3,8 +3,8 @@
from collections import defaultdict
from torch import Tensor
import torch.autograd
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch._decomp import decomposition_table
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils._mode_utils import no_dispatch
@ -384,17 +384,16 @@ class TestDecomp(TestCase):
saved_precision = self.precision
saved_rel_tol = self.rel_tol
test_case = self
class DecompCrossRefMode(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
class DecompCrossRefMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
with no_dispatch():
return cls._torch_dispatch(func, types, args, kwargs)
return self._torch_dispatch(func, types, args, kwargs)
@classmethod
def _torch_dispatch(cls, func, types, args=(), kwargs=None):
self.precision = saved_precision
self.rel_tol = saved_rel_tol
def _torch_dispatch(self, func, types, args=(), kwargs=None):
test_case.precision = saved_precision
test_case.rel_tol = saved_rel_tol
called.add(func)
all_called[func] += 1
@ -441,14 +440,14 @@ class TestDecomp(TestCase):
assert type(orig) == type(decomp)
assert orig == decomp
continue
op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
op_assert_ref(test_case, func, test_dtype, i, orig, decomp, ref, args, kwargs)
else:
for orig, decomp in zip(real_out, decomp_out):
if not isinstance(orig, torch.Tensor):
assert type(orig) == type(decomp)
assert orig == decomp
continue
op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)
op_assert_equal(test_case, func, test_dtype, orig, decomp, args, kwargs)
return real_out_unflat
@ -487,7 +486,7 @@ class TestDecomp(TestCase):
# explicit clearing is necessary as I will create a fresh mode
# for each region
decomposed.clear()
with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
with DecompCrossRefMode(), enable_python_dispatcher():
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
if aten_name in decomposition_names:
check_decomposed(aten_name)
@ -496,7 +495,7 @@ class TestDecomp(TestCase):
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
decomposed.clear()
with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
with DecompCrossRefMode(), enable_python_dispatcher():
decomp_vjp_fn(cotangents)
if not run_all:
check_decomposed(op.aten_backward_name)
@ -505,7 +504,7 @@ class TestDecomp(TestCase):
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
decomposed.clear()
with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
with DecompCrossRefMode(), enable_python_dispatcher():
func(*args, **kwargs)
if not run_all:
check_decomposed(aten_name)

View File

@ -12,7 +12,6 @@ from torch._subclasses.fake_tensor import (
DynamicOutputShapeException,
)
from torch.testing import FileCheck
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch import nn
import unittest
import torch._prims as prims
@ -27,10 +26,9 @@ class FakeTensorTest(TestCase):
self.assertEqual(list(t.size()), size)
def test_basic(self):
mode = FakeTensorMode()
x = torch.empty(2, 2, device="cpu")
y = torch.empty(4, 2, 2, device="cpu")
with enable_torch_dispatch_mode(mode):
with FakeTensorMode() as mode:
x = mode.from_tensor(x)
y = mode.from_tensor(y)
z = x + y
@ -39,7 +37,7 @@ class FakeTensorTest(TestCase):
self.assertTrue(isinstance(z, FakeTensor))
def test_parameter_instantiation(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.rand([4])
y = torch.nn.parameter.Parameter(x)
self.assertTrue(isinstance(y, torch.nn.Parameter))
@ -52,14 +50,14 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_index_cuda_with_cpu(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.rand([2048], device='cuda')
out = x[torch.zeros([36], dtype=torch.int64)]
self.checkType(out, "cuda", [36])
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_shape_take_not_device(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.empty(1, device="cpu")
y = torch.empty(8, 8, device="cuda")
out = x.resize_as_(y)
@ -69,8 +67,7 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_zero_dim(self):
mode = FakeTensorMode()
with enable_torch_dispatch_mode(mode):
with FakeTensorMode() as mode:
x = torch.tensor(0.)
y = torch.rand([4, 4], device="cuda")
out = x + y
@ -79,8 +76,7 @@ class FakeTensorTest(TestCase):
self.assertTrue(isinstance(out, FakeTensor))
def test_nan_to_num(self):
mode = FakeTensorMode()
with enable_torch_dispatch_mode(mode):
with FakeTensorMode():
for dtype in [torch.float16, torch.float32]:
x = torch.rand([4], dtype=dtype)
y = torch.nan_to_num(x, nan=None)
@ -90,9 +86,8 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_throw(self):
mode = FakeTensorMode()
x = torch.tensor(0.) # TODO: tensor() errors
with enable_torch_dispatch_mode(mode):
with FakeTensorMode() as mode:
x_conv = mode.from_tensor(x)
y = torch.rand([4, 4], device="cuda")
z = torch.rand([4, 4], device="cpu")
@ -100,7 +95,7 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_type_as(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.rand([16, 1], device="cpu")
y = torch.rand([4, 4], device="cuda")
out = x.type_as(y)
@ -110,12 +105,12 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_setitem(self):
for device in ["cpu", "cuda"]:
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.rand([16, 1], device=device)
x[..., 0] = 0
def test_fake_dispatch_keys(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.rand([4])
f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
f.run(torch._C._dispatch_key_set(x))
@ -127,14 +122,14 @@ class FakeTensorTest(TestCase):
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))
def test_constructor(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.rand([4, 4], device="cpu")
self.assertTrue(isinstance(x, FakeTensor))
self.assertTrue(x.device.type == "cpu")
def test_mode(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
y = torch.rand([4], device="cpu")
out = y + y
@ -142,7 +137,7 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_non_kwarg_device(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.rand([16, 1], device="cpu")
y = x.to(torch.device("cpu"))
self.assertIs(x, y)
@ -153,7 +148,7 @@ class FakeTensorTest(TestCase):
x = torch.rand([4, 4])
with self.assertRaisesRegex(Exception, "non-Fake Tensor inputs"):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
y = x[0]
def test_fake_grad_copy(self):
@ -168,7 +163,7 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_like_constructor(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.rand([4, 4])
y = torch.ones_like(x)
self.assertTrue(isinstance(y, FakeTensor))
@ -178,7 +173,7 @@ class FakeTensorTest(TestCase):
self.assertEqual(z.device.type, "cuda")
def test_binary_op_type_promotion(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.empty([2, 2], dtype=torch.float)
y = torch.empty([2, 2], dtype=torch.int64)
out = x / y
@ -186,14 +181,14 @@ class FakeTensorTest(TestCase):
self.assertEqual(out.device.type, "cpu")
def test_from_numpy(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.tensor(np.zeros([4, 4]))
self.checkType(x, "cpu", [4, 4])
def test_randperm(self):
x = torch.randperm(10)
y = torch.randperm(5, device="cpu")
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x1 = torch.randperm(10)
prims.utils.compare_tensor_meta(x, x1)
y1 = torch.randperm(5, device="cpu")
@ -202,21 +197,21 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_cpu_fallback(self):
with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=False)):
with FakeTensorMode(allow_fallback_kernels=False):
filters = torch.randn(8, 4, 3, 3).cuda()
inputs = torch.randn(1, 4, 5, 5).cuda()
out = torch.nn.functional.conv2d(inputs, filters, padding=1)
self.assertEqual(out.device.type, "cuda")
self.assertEqual(list(out.size()), [1, 8, 5, 5])
with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=True)):
with FakeTensorMode(allow_fallback_kernels=True):
# intentionally bad inputs
filters = torch.randn(8, 20, 3, 3).cuda()
inputs = torch.randn(1, 7, 10, 5).cuda()
with self.assertRaises(RuntimeError):
torch.nn.functional.conv2d(inputs, filters, padding=1)
with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=True)):
with FakeTensorMode(allow_fallback_kernels=True):
filters = torch.randn(8, 4, 3, 3).cuda()
inputs = torch.randn(1, 4, 5, 5).cuda()
@ -295,7 +290,7 @@ class FakeTensorTest(TestCase):
)
mode = FakeTensorMode()
for i, context in enumerate([contextlib.nullcontext, lambda: enable_torch_dispatch_mode(mode)]):
for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
with context():
inps = (
torch.randn([92, 8, 2048]).cuda(),
@ -336,16 +331,14 @@ class FakeTensorTest(TestCase):
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
mod_copied = copy.deepcopy(m)
with enable_torch_dispatch_mode(mode):
with mode:
input = torch.rand(20, 16, 50, 100, dtype=torch.half, device="cuda").to(memory_format=torch.channels_last)
out = mod_copied(input)
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
self.checkType(out, "cuda", [20, 33, 24, 49])
def test_data_dependent_operator(self):
with enable_torch_dispatch_mode(
FakeTensorMode(allow_fallback_kernels=False)
):
with FakeTensorMode(allow_fallback_kernels=False):
x = torch.rand([10, 10])
self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))
@ -355,7 +348,8 @@ class FakeTensorTest(TestCase):
@skipIfCrossRef
def test_deepcopy(self):
mode = FakeTensorMode()
with FakeTensorMode() as mode:
pass
mod = torch.nn.BatchNorm2d(10)
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
mod_copied = copy.deepcopy(mod)
@ -386,7 +380,7 @@ class FakeTensorTest(TestCase):
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_new(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
a = torch.rand([16, 1])
self.checkType(a.new(10, 10), "cpu", [10, 10])
self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
@ -527,8 +521,7 @@ class FakeTensorConverterTest(TestCase):
self.assertEqual(len(converter.meta_converter.tensor_memo), 0)
def test_no_active_mode(self):
mode = FakeTensorMode()
with enable_torch_dispatch_mode(mode):
with FakeTensorMode() as mode:
x = torch.empty(2, 2, device="cpu")
y = torch.empty(2, 2, device="cpu")
@ -538,9 +531,9 @@ class FakeTensorConverterTest(TestCase):
self.assertEqual(out.device.type, "cpu")
def test_separate_mode_error(self):
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
x = torch.empty(2, 2, device="cpu")
with enable_torch_dispatch_mode(FakeTensorMode()):
with FakeTensorMode():
y = torch.empty(2, 2, device="cpu")
self.assertRaises(Exception, lambda: x, y)

View File

@ -63,7 +63,6 @@ from torch._subclasses.fake_tensor import (
)
from torch._subclasses.fake_utils import outputs_alias_inputs
from torch.utils._python_dispatch import enable_torch_dispatch_mode
import torch._prims as prims
from torch._prims.context import TorchRefsMode
@ -173,8 +172,7 @@ class TestCommon(TestCase):
@onlyNativeDeviceTypes
@ops(python_ref_db)
def test_python_ref_meta(self, device, dtype, op):
mode = FakeTensorMode()
with mode:
with FakeTensorMode() as mode:
pass
def _to_tensormeta(x):
@ -189,7 +187,7 @@ class TestCommon(TestCase):
meta_sample = sample.transform(_to_tensormeta)
try:
with enable_torch_dispatch_mode(mode):
with mode:
meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
continue
@ -1864,7 +1862,7 @@ class TestFakeTensor(TestCase):
continue
with context():
with enable_torch_dispatch_mode(mode):
with mode:
res_fake = op(input, *args, **kwargs)

View File

@ -16,9 +16,11 @@ from torch.overrides import (
get_overridable_functions,
get_testing_overrides,
is_tensor_method_or_property,
TorchFunctionMode
TorchFunctionMode,
_get_current_function_mode,
_get_current_function_mode_stack,
)
from torch.utils._mode_utils import find_outermost_mode, all_same_mode, all_same_mode_scope
from torch.utils._mode_utils import all_same_mode
from torch.utils._pytree import tree_map
Tensor = torch.Tensor
@ -1150,11 +1152,6 @@ class TestTorchFunctionMode(TestCase):
self.assertEqual(torch._sparse_csr_tensor_unsafe(1, 1, 1, (1, 1)), -1)
self.assertEqual(torch.as_tensor([1]), -1)
def test_enable_torch_function_mode_with_tensor_subclass(self):
x = torch.randn(1)
with torch.overrides.enable_torch_function_mode(SubTensor):
self.assertEqual(torch.mm(x, x), -1)
def test_modes_handle_first(self):
class A(TorchFunctionMode):
def __torch_function__(self, *args, **kwargs):
@ -1178,54 +1175,9 @@ class TestTorchFunctionMode(TestCase):
self.assertEqual(torch.mm(x, x), -1)
self.assertEqual(bar(x), 1)
self.assertRaisesRegex(
TypeError, r'SubTensor.+MyMode',
TypeError, r'SubTensor.+TorchFunctionStackMode',
lambda: self.assertEqual(torch.max(x, x)))
def test_enable_torch_function_mode_trivial(self):
class A(TorchFunctionMode):
def __torch_function__(self, *args, **kwargs):
return -40
a = A()
with torch.overrides.enable_torch_function_mode(a):
with torch.overrides.enable_torch_function_mode(a):
self.assertEqual(bar(None), -40)
def test_enable_torch_function_mode_replace(self):
class A(TorchFunctionMode):
def __init__(self, val):
self.val = val
def __torch_function__(self, *args, **kwargs):
return self.val
a1 = A(-40)
a2 = A(-41)
with torch.overrides.enable_torch_function_mode(a1):
with torch.overrides.enable_torch_function_mode(a2, replace=a1):
self.assertEqual(bar(None), -41)
def test_enable_torch_function_mode_ignore_preexisting(self):
class A(TorchFunctionMode):
def __init__(self, val):
self.val = val
def __torch_function__(self, *args, **kwargs):
return self.val
a1 = A(-40)
a2 = A(-41)
with torch.overrides.enable_torch_function_mode(a1):
with torch.overrides.enable_torch_function_mode(a2, ignore_preexisting=True):
self.assertEqual(bar(None), -41)
def test_ctor_no_inner(self):
class A(TorchFunctionMode):
def __torch_function__(self, *args, **kwargs):
return torch.zeros([])
with torch.overrides.enable_torch_function_mode(A()):
x = torch.randn((3, 4))
self.assertEqual(x, torch.zeros([]))
def test_with_mode(self):
class ErrorA(RuntimeError):
pass
@ -1270,15 +1222,24 @@ class TestTorchFunctionMode(TestCase):
self.assertEqual(out, ["layer2", "layer1"])
def test_error_using_same_mode(self):
class A(TorchFunctionMode):
pass
def test_nested_same_mode(self):
out = []
x = A()
with x:
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
with x:
pass
class A(TorchFunctionMode):
def __init__(self, msg):
self.msg = msg
def __torch_function__(self, func, _, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
out.append(self.msg)
return func(*args, **kwargs)
with A("layer1") as a:
with a:
torch.empty([])
self.assertEqual(out, ["layer1", "layer1"])
def test_error_using_class_method_on_mode(self):
class A(TorchFunctionMode):
@ -1291,86 +1252,43 @@ class TestTorchFunctionMode(TestCase):
with A():
x + x
def test_error_with_ancestor(self):
def test_restacking_with_ancestor(self):
class A(TorchFunctionMode):
pass
with A() as x:
pass
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
with x:
pass
def test_restore_errors(self):
class A(TorchFunctionMode):
pass
with self.assertRaisesRegex(RuntimeError, "does not have any ancestors. Use the standard version instead"):
with A().restore():
pass
x = A()
with A():
with x:
with A() as x:
pass
with A(): # a different mode instance than the one above
with self.assertRaisesRegex(RuntimeError, "the current mode is not its ancestor"):
with x.restore():
pass
def test_restore_ancestor_mode(self):
class A(TorchFunctionMode):
pass
x = A()
y = A()
with x:
with y:
pass
pass
z = A()
with y.restore():
with z:
pass
with x.restore():
with z.restore():
pass
def test_find_outermost_mode(self):
def test_get_cur_mode(self):
class A(TorchFunctionMode):
pass
self.assertIsNone(find_outermost_mode([None, None]))
x = A()
y = A()
with x:
with y:
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
pass
self.assertEqual(find_outermost_mode([x, y]), y)
with A() as mode1:
self.assertEqual(_get_current_function_mode(), mode1)
z = A()
with y.restore():
with z:
with mode1:
with A() as mode2:
self.assertEqual(_get_current_function_mode(), mode2)
def test_get_mode_stack(self):
class A(TorchFunctionMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
pass
self.assertEqual(find_outermost_mode([z, x]), z)
i = A()
self.assertEqual(_get_current_function_mode_stack(), [])
with self.assertRaisesRegex(RuntimeError, "doesn't have ancestors set so the ordering with other modes"):
find_outermost_mode([i, x, y, z])
with A() as mode1:
self.assertEqual(_get_current_function_mode_stack(), [mode1])
k = A()
with k:
pass
with self.assertRaisesRegex(RuntimeError, "don't come from the same scope"):
find_outermost_mode([k, x, y, z])
with mode1:
with A() as mode2:
self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])
def test_all_same_mode(self):
class A(TorchFunctionMode):
@ -1382,32 +1300,6 @@ class TestTorchFunctionMode(TestCase):
self.assertFalse(all_same_mode([x, None]))
self.assertFalse(all_same_mode([x, y]))
def test_all_same_mode_scope(self):
class A(TorchFunctionMode):
pass
x = A()
y = A()
z = A()
with x:
with y:
pass
with x.restore():
with z:
pass
i = A()
self.assertTrue(all_same_mode_scope([x, y], y))
self.assertTrue(all_same_mode_scope([x, z], z))
self.assertFalse(all_same_mode_scope([x, y, z], y))
self.assertFalse(all_same_mode_scope([x, y, z], z))
self.assertFalse(all_same_mode_scope([x, y, i], y))
no_ancestor = A()
self.assertFalse(all_same_mode_scope([x, y, z], no_ancestor))
def test_reentrant_mode_idiom(self):
log = []
@ -1417,7 +1309,7 @@ class TestTorchFunctionMode(TestCase):
kwargs = {}
log.append(func)
if func is torch.sub:
with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
with self:
input, other = args
assert not kwargs
return torch.add(input, other, alpha=-1)

View File

@ -742,7 +742,7 @@ class TestDecomp(TestCase):
op = torch._decomp.decomposition_table.get(torch.ops.aten.leaky_relu_backward.default)
def fn0(*arg):
return _is_func_unsupported_nvfuser(mode, op, arg, {})
return _is_func_unsupported_nvfuser(TorchRefsNvfuserCapabilityMode(), op, arg, {})
def fn1(x):
x = x * 2
@ -750,8 +750,8 @@ class TestDecomp(TestCase):
x = x * 2
return x
with TorchRefsNvfuserCapabilityMode() as mode:
self.assertFalse(fn0(x, y, 0.3, False))
self.assertFalse(fn0(x, y, 0.3, False))
with TorchRefsNvfuserCapabilityMode():
# Autocast context has C++ level ATen calls that are hidden from
# TorchRefsNvfuserCapabilityMode that works only on Python level.

View File

@ -7,11 +7,11 @@ from torch.library import Library
from torch.cuda.jiterator import _create_jit_fn
import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, IS_WINDOWS
from torch.utils._mode_utils import no_dispatch, find_outermost_mode, all_same_mode, all_same_mode_scope
from torch.utils._mode_utils import no_dispatch, all_same_mode
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
log_input, capture_logs, capture_logs_with_logging_tensor_mode
from torch.utils._pytree import tree_map, tree_map_only
from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack
import logging
@ -730,34 +730,22 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
res = x.index_put_(idxs, v)
self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])
def test_enable_torch_dispatch_mode_error(self) -> None:
z = LoggingTensor(torch.empty([]))
with self.assertRaisesRegex(ValueError, "expected to get TorchDispatchMode, Tensor-like class, or None"):
with enable_torch_dispatch_mode(z):
pass
def test_enable_torch_dispatch_mode_basic(self) -> None:
def test_torch_dispatch_mode_basic(self) -> None:
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(LoggingTensorMode()):
with LoggingTensorMode():
torch.empty([])
self.assertExpectedInline('\n'.join(logs), """\
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
def test_enable_torch_dispatch_mode_unrelated_tensors(self) -> None:
def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
x = torch.randn([])
y = torch.randn([])
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(LoggingTensorMode()):
with LoggingTensorMode():
x + y
self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)""")
def test_nested_push_regular(self):
with LoggingTensorMode.push() as mode:
# This previously errored
with LoggingTensorMode():
pass
def test_nested_push_logging_tensor_mode(self):
x = torch.randn([])
y = torch.randn([])
@ -799,7 +787,7 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
self.assertEqual(logs1, logs2)
def test_enable_torch_dispatch_mode_subclass_priority(self) -> None:
def test_torch_dispatch_mode_subclass_priority(self) -> None:
class ErrorA(RuntimeError):
pass
@ -813,7 +801,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise ErrorA
with AMode():
raise ErrorA
class B(A):
@staticmethod
@ -822,6 +811,15 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with BMode():
func(*args, **kwargs)
class AMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA
class BMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorB
a = A(torch.empty(1))
@ -834,38 +832,41 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
# B has precedence over A due to the subclass relationship yet
# modes take precedence over arguments
with self.assertRaises(ErrorA):
with enable_torch_dispatch_mode(A):
with AMode():
b + b
with self.assertRaises(ErrorB):
with enable_torch_dispatch_mode(B):
with BMode():
a + a
with self.assertRaises(ErrorB):
with enable_torch_dispatch_mode(B):
with BMode():
a + b
def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None:
def test_mode_with_make_subclass(self):
class SubTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
class BasicMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return func(*args, **kwargs)
x = torch.randn(3)
with BasicMode():
y = SubTensor(x)
self.assertIsInstance(y, SubTensor)
def test_torch_dispatch_mode_respects_no_dispatch(self) -> None:
with capture_logs(is_mode=True) as logs1:
with enable_torch_dispatch_mode(LoggingTensorMode()):
with LoggingTensorMode():
torch.ones([2, 3])
with no_dispatch():
torch.ones([2, 3])
with capture_logs(is_mode=True) as logs2:
with enable_torch_dispatch_mode(LoggingTensorMode()):
with LoggingTensorMode():
torch.ones([2, 3])
self.assertEqual(logs1, logs2)
def test_enable_torch_dispatch_mode_instance(self) -> None:
class TestMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
x = TestMode()
y = torch.tensor([2.])
with enable_torch_dispatch_mode(x):
y + y
def test_shallow_copy_and_detach(self) -> None:
seen = set()
test_case = self
@ -884,103 +885,25 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
loss = (x * x).sum()
loss.backward()
def test_nested_enable_torch_dispatch_mode(self) -> None:
class A(LoggingTensorMode):
pass
with self.assertRaisesRegex(ValueError, "there is already an active mode"):
with enable_torch_dispatch_mode(LoggingTensorMode()):
with enable_torch_dispatch_mode(A()):
pass
# For nesting to be a noop, they need to be the same instance
with self.assertRaisesRegex(ValueError, "there is already an active mode"):
with enable_torch_dispatch_mode(LoggingTensorMode()):
with enable_torch_dispatch_mode(LoggingTensorMode()):
pass
def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
# "nested" enable_torch_dispatch_modes are allowed if they're the same mode (same instance).
# It's the equivalent of a noop, so it will only write once to the log
x = torch.tensor([3.])
mode = LoggingTensorMode()
with capture_logs(is_mode=True) as logs:
log_input("x", x)
with enable_torch_dispatch_mode(mode):
with enable_torch_dispatch_mode(mode):
x + x
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.add.Tensor($0, $0)''')
def test_enable_torch_dispatch_mode_ignore_preexisting(self):
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise AssertionError
x = torch.tensor([3.])
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(A()):
with enable_torch_dispatch_mode(LoggingTensorMode(), ignore_preexisting=True):
x + x
self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, $0)""")
def test_enable_torch_dispatch_mode_replace(self):
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise AssertionError
x = torch.tensor([3.])
outer_mode = A()
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(outer_mode):
with enable_torch_dispatch_mode(LoggingTensorMode(), replace=outer_mode):
x + x
self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, $0)""")
def test_exception_handling(self):
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
class AMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if func.__name__ == 'randn.default':
raise RuntimeError()
return cls(torch.zeros(()))
return A(torch.zeros(()))
with enable_torch_dispatch_mode(A):
with AMode():
try:
torch.randn(())
except RuntimeError:
pass
self.assertTrue(isinstance(torch.zeros(()), A))
def test_ctor_no_inner(self):
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return torch.zeros([])
with enable_torch_dispatch_mode(A()):
x = torch.randn((3, 4))
self.assertEqual(x, torch.zeros([]))
def test_with_mode(self):
class ErrorA(RuntimeError):
pass
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA()
with self.assertRaises(ErrorA):
with A():
torch.empty([])
def test_with_mode_created_separately(self):
class ErrorA(RuntimeError):
pass
@ -1021,10 +944,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
modes = (arg.mode for arg in args + tuple(kwargs.values()) if isinstance(arg, ModeTensor))
outermost = find_outermost_mode(modes)
with outermost.restore():
return func(*args, **kwargs)
raise NotImplementedError("Shouldn't be here")
class Mode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
@ -1111,70 +1031,17 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
with PoliteMode():
a.abs()
def test_disable_mode(self):
class FailEverythingMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise RuntimeError("arf")
def test_nesting_same_mode(self):
# If the pushed mode is the same instance as the current mode, we allow pushing an already active mode.
with FailEverythingMode() as m:
self.assertRaises(RuntimeError, lambda: torch.ones([2, 3]))
with enable_torch_dispatch_mode(None, replace=m):
torch.ones([2, 3])
with capture_logs(is_mode=True) as logs:
with LoggingTensorMode() as reenabled:
with reenabled:
torch.empty([])
self.assertExpectedInline('\n'.join(logs), """\
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
def test_make_wrapper_subclass_with_modes(self):
class ModeTensor(torch.Tensor):
def __new__(cls, elem, mode):
r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
r.elem = elem
r.mode = mode
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
modes = (arg.mode for arg in args + tuple(kwargs.values()) if isinstance(arg, ModeTensor))
outermost = find_outermost_mode(modes)
with outermost.restore():
return func(*args, **kwargs)
class Mode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
def unwrap(e):
if isinstance(e, ModeTensor):
return e.elem
else:
return e
def wrap(t):
if isinstance(t, torch.Tensor):
return ModeTensor(t, self)
else:
return t
return wrap(func(*tuple(unwrap(a) for a in args), **kwargs))
x = torch.tensor(4.)
with Mode():
y = x + x
z = y + y
self.assertIsInstance(y, ModeTensor)
self.assertIsInstance(z, ModeTensor)
with Mode():
with Mode():
y = x + x
z = y + y
self.assertIsInstance(y, ModeTensor)
self.assertIsInstance(z, ModeTensor)
def test_error_using_same_mode(self):
class A(TorchDispatchMode):
pass
x = A()
with x:
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
with x:
pass
def test_error_using_class_method_on_mode(self):
class A(TorchDispatchMode):
@ -1187,74 +1054,33 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
with A():
x + x
def test_error_with_ancestor(self):
x = LoggingTensorMode()
with x:
pass
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
with x:
def test_get_cur_mode(self):
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
pass
def test_restore_errors(self):
with self.assertRaisesRegex(RuntimeError, "does not have any ancestors. Use the standard version instead"):
with LoggingTensorMode().restore():
self.assertEqual(_get_current_dispatch_mode(), None)
with A() as mode1:
self.assertEqual(_get_current_dispatch_mode(), mode1)
with mode1:
with A() as mode2:
self.assertEqual(_get_current_dispatch_mode(), mode2)
def test_get_mode_stack(self):
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
pass
x = LoggingTensorMode()
with LoggingTensorMode():
with x:
pass
self.assertEqual(_get_current_dispatch_mode_stack(), [])
with LoggingTensorMode(): # a different mode instance than the one above
with self.assertRaisesRegex(RuntimeError, "the current mode is not its ancestor"):
with x.restore():
pass
with A() as mode1:
self.assertEqual(_get_current_dispatch_mode_stack(), [mode1])
def test_restore_ancestor_mode(self):
x = LoggingTensorMode()
y = LoggingTensorMode()
with x:
with y:
pass
z = LoggingTensorMode()
with y.restore():
with z:
pass
with x.restore():
with z.restore():
pass
def test_find_outermost_mode(self):
self.assertIsNone(find_outermost_mode([None, None]))
x = LoggingTensorMode()
y = LoggingTensorMode()
with x:
with y:
pass
self.assertEqual(find_outermost_mode([x, y]), y)
z = LoggingTensorMode()
with y.restore():
with z:
pass
self.assertEqual(find_outermost_mode([z, x]), z)
i = LoggingTensorMode()
with self.assertRaisesRegex(RuntimeError, "doesn't have ancestors set so the ordering with other modes"):
find_outermost_mode([i, x, y, z])
k = LoggingTensorMode()
with k:
pass
with self.assertRaisesRegex(RuntimeError, "don't come from the same scope"):
find_outermost_mode([k, x, y, z])
with mode1:
with A() as mode2:
self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2])
def test_all_same_mode(self):
x = LoggingTensorMode()
@ -1263,29 +1089,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
self.assertFalse(all_same_mode([x, None]))
self.assertFalse(all_same_mode([x, y]))
def test_all_same_mode_scope(self):
x = LoggingTensorMode()
y = LoggingTensorMode()
z = LoggingTensorMode()
with x:
with y:
pass
with x.restore():
with z:
pass
i = LoggingTensorMode()
self.assertTrue(all_same_mode_scope([x, y], y))
self.assertTrue(all_same_mode_scope([x, z], z))
self.assertFalse(all_same_mode_scope([x, y, z], y))
self.assertFalse(all_same_mode_scope([x, y, z], z))
self.assertFalse(all_same_mode_scope([x, y, i], y))
no_ancestor = LoggingTensorMode()
self.assertFalse(all_same_mode_scope([x, y, z], no_ancestor))
def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
x = LoggingTensor(torch.tensor([2.0, 3.0]))
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
@ -1295,7 +1098,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
with self.assertRaises(AssertionError):
self.assertEqual(x, None)
def test_enable_torch_dispatch_mode_subclass_autograd_device_check(self) -> None:
def test_subclass_autograd_device_check(self) -> None:
class NonWrapperSubclass(torch.Tensor):
elem: torch.Tensor

View File

@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map
from torch.testing._internal.common_utils import run_tests
from torch.fx.operator_schemas import normalize_function
from torch.testing._internal.schema_check_mode import SchemaCheckMode
from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
from torch.utils._python_dispatch import TorchDispatchMode
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
@ -72,16 +72,14 @@ class IncorrectAliasTensor(torch.Tensor):
class TestSchemaCheck(JitTestCase):
# Tests that SchemaCheckMode records operator order with grad
def test_schema_check_mode_operator_order(self):
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
x = torch.rand((3, 3), requires_grad=True)
x.relu().sin()
self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)
# Tests that SchemaCheckMode records operator order without grad
def test_schema_check_mode_operator_order_without_grad(self):
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
x = torch.rand((3, 3), requires_grad=False)
x.relu().sin()
self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)
@ -91,8 +89,7 @@ class TestSchemaCheck(JitTestCase):
# NB: previously requires_grad=True, but this induces a detach for
# saved variable
x = torch.rand((3, 3))
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
actual = x.relu().sin()
self.assertEqual([], schema_check.mutated)
self.assertEqual([], schema_check.aliasing)
@ -100,8 +97,7 @@ class TestSchemaCheck(JitTestCase):
# Tests that SchemaCheckMode records mutations and aliases with mutation expected
def test_schema_check_mode_mutated_aliasing_mutation(self):
actual = torch.rand((3, 3), requires_grad=False)
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
actual.sinh_()
self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)
@ -109,8 +105,7 @@ class TestSchemaCheck(JitTestCase):
# Tests that SchemaCheckMode records mutations and aliases with resize_
def test_schema_check_mode_mutated_aliasing_resize_(self):
actual = torch.rand((3, 3), requires_grad=False)
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
actual.resize_(9)
self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)
@ -119,8 +114,7 @@ class TestSchemaCheck(JitTestCase):
def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
actual = torch.rand((3, 3))
y = actual
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
actual.add_(y)
self.assertEqual(
[
@ -140,8 +134,7 @@ class TestSchemaCheck(JitTestCase):
# Tests that SchemaCheckMode records mutations and alias with as_strided
def test_schema_check_mode_mutated_aliasing_as_strided(self):
x = torch.rand((3, 6, 4))
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
x.as_strided_([3, 6, 4], [9, 1, 1])
self.assertEqual(
[
@ -161,8 +154,7 @@ class TestSchemaCheck(JitTestCase):
x = torch.arange(9.)
m_actual = torch.arange(9.)
e_actual = torch.zeros([9], dtype=torch.int32)
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
torch.frexp(x, out=(m_actual, e_actual))
self.assertEqual(
[
@ -183,8 +175,7 @@ class TestSchemaCheck(JitTestCase):
def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
x = torch.rand((3, 3))
actual = torch.zeros(3)
schema_check = SchemaCheckMode()
with enable_torch_dispatch_mode(schema_check):
with SchemaCheckMode() as schema_check:
torch.aminmax(x, dim=0, out=[actual, actual])
self.assertEqual(
[
@ -207,7 +198,7 @@ class TestSchemaCheck(JitTestCase):
def test_schema_check_mode_functionality(self):
x = torch.rand((3, 3), requires_grad=True)
expected = x.relu().sin()
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual = x.relu().sin()
self.assertEqual(expected, actual)
@ -215,7 +206,7 @@ class TestSchemaCheck(JitTestCase):
def test_schema_check_mode_functionality_default_replaced(self):
x = torch.rand((3, 3), requires_grad=True)
expected = x.add(x, alpha=2)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual = x.add(x, alpha=2)
self.assertEqual(expected, actual)
@ -225,7 +216,7 @@ class TestSchemaCheck(JitTestCase):
b = torch.rand((3, 3))
c = torch.rand((3, 3))
expected = torch.linalg.multi_dot([a, b, c])
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual = torch.linalg.multi_dot([a, b, c])
self.assertEqual(expected, actual)
@ -233,7 +224,7 @@ class TestSchemaCheck(JitTestCase):
def test_schema_check_mode_functionality_wildcard_after(self):
x = torch.rand((3, 3))
expected = x.chunk(6)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual = x.chunk(6)
self.assertEqual(expected, actual)
@ -242,7 +233,7 @@ class TestSchemaCheck(JitTestCase):
x = torch.rand((3, 5))
w = torch.rand((4))
expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
self.assertEqual(expected, actual)
@ -251,7 +242,7 @@ class TestSchemaCheck(JitTestCase):
expected = torch.rand((3, 3), requires_grad=False)
actual = torch.clone(expected)
expected.sinh_()
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual.sinh_()
self.assertEqual(expected, actual)
@ -262,7 +253,7 @@ class TestSchemaCheck(JitTestCase):
actual = torch.clone(expected)
y = actual
expected.add_(x)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual.add_(y)
self.assertEqual(expected, actual)
@ -272,7 +263,7 @@ class TestSchemaCheck(JitTestCase):
m_expected, e_expected = torch.frexp(x)
m_actual = torch.arange(9.)
e_actual = torch.zeros([9], dtype=torch.int32)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
torch.frexp(x, out=(m_actual, e_actual))
self.assertEqual(m_expected, m_actual)
self.assertEqual(e_expected, e_actual)
@ -281,13 +272,13 @@ class TestSchemaCheck(JitTestCase):
def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
x = torch.rand((3, 3))
actual = torch.zeros(3)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
torch.aminmax(x, dim=0, out=[actual, actual])
self.assertEqual(torch.amax(x, dim=0), actual)
# Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input
def test_schema_check_mode_functionality_device_input(self):
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
x = torch.rand((3, 3), device="cpu", dtype=torch.double)
y = x + x
self.assertEqual(x + x, y)
@ -297,7 +288,7 @@ class TestSchemaCheck(JitTestCase):
x = torch.rand((3, 3), requires_grad=True)
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
expected = batch(x)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual = batch(x)
self.assertEqual(expected, actual)
@ -311,7 +302,7 @@ class TestSchemaCheck(JitTestCase):
expected.relu_()
expected = batch(expected)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual.sinh_()
actual.tanh_()
actual.relu_()
@ -321,7 +312,7 @@ class TestSchemaCheck(JitTestCase):
# Tests that SchemaCheckMode wraps Torch.tensor with empty list input
def test_schema_check_mode_empty_list_input(self):
expected = torch.atleast_1d([])
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
actual = torch.atleast_1d([])
self.assertEqual(expected, actual)
@ -330,7 +321,7 @@ class TestSchemaCheck(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
x = torch.rand((3, 3))
y = torch.rand((3, 3))
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))
# # Tests that an exception is raised for a mismatching mutation over multiple ops
@ -338,7 +329,7 @@ class TestSchemaCheck(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
x = torch.rand((3, 3))
y = torch.rand((3, 3))
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))
# Tests that an exception is raised for a mismatching alias
@ -346,7 +337,7 @@ class TestSchemaCheck(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
x = torch.rand((3, 3), requires_grad=True)
y = torch.rand((3, 3))
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)
# Tests that an exception is raised for a mismatching alias over multiple ops
@ -354,7 +345,7 @@ class TestSchemaCheck(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
x = torch.rand((3, 3), requires_grad=True)
y = torch.zeros((3, 3), requires_grad=True)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)
# Tests that an exception is raised for a centered mismatching alias over multiple ops
@ -362,15 +353,14 @@ class TestSchemaCheck(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
x = torch.rand((3, 3), requires_grad=True)
y = torch.zeros((3, 3), requires_grad=True)
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()
# Tests that an exception is raised for a centered mismatching alias over multiple ops
def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
x = torch.rand((3, 3))
s = SchemaCheckMode()
with enable_torch_dispatch_mode(s):
with SchemaCheckMode() as s:
IncorrectAliasTensor(x).aminmax(dim=0)
# Tests that is_alias_of returns as expected
@ -439,8 +429,7 @@ class TestSchemaCheck(JitTestCase):
return func(*args, **kwargs)
x = torch.rand((3, 3))
schemaInfoCheck = SchemaInfoBindTestMode(self)
with enable_torch_dispatch_mode(schemaInfoCheck):
with SchemaInfoBindTestMode(self) as schemaInfoCheck:
x.add(x)
@ -452,7 +441,7 @@ class TestSchemaCheckModeOpInfo(JitTestCase):
if (dtype == torch.complex32):
return
for sample in op.sample_inputs(device, dtype, requires_grad=False):
with enable_torch_dispatch_mode(SchemaCheckMode()):
with SchemaCheckMode():
op(sample.input, *sample.args, **sample.kwargs)
instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))

View File

@ -918,11 +918,18 @@ def __is_forward_AD_enabled() -> _bool: ...
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_default_hooks() -> None: ...
def _is_torch_function_mode_enabled() -> _bool: ...
def _set_torch_function_mode(cls: Any) -> None: ...
def _get_torch_function_mode() -> Any: ...
def _push_on_torch_function_stack(cls: Any) -> None: ...
def _pop_torch_function_stack() -> Any: ...
def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...
def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _get_torch_dispatch_mode() -> Any: ...
def _push_on_torch_dispatch_stack(cls: Any) -> None: ...
def _pop_torch_dispatch_stack() -> Any: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...
class _InferenceMode(object):
def __init__(self, mode: _bool) -> None: ...
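The stubs above only declare the raw private bindings; a minimal sketch of how they compose into a stack query follows (the helper names here are illustrative only — the supported wrappers live in `torch.utils._python_dispatch` and `torch.overrides`):

```python
import torch

def current_dispatch_mode():
    # top of the Python-level mode stack, or None when no mode is active
    n = torch._C._len_torch_dispatch_stack()
    return torch._C._get_dispatch_stack_at(n - 1) if n > 0 else None

def dispatch_mode_stack():
    # bottom-to-top snapshot of every mode currently pushed
    return [torch._C._get_dispatch_stack_at(i)
            for i in range(torch._C._len_torch_dispatch_stack())]
```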

View File

@ -12,6 +12,7 @@ import torch.jit
from torch import _utils_internal
# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
@ -144,9 +145,11 @@ class PyOperator(PyOperatorABC):
return inner
def dispatch(self, dispatch_key, *args, **kwargs):
from torch.utils._python_dispatch import _get_current_dispatch_mode
if dispatch_key == torch._C.DispatchKey.Python:
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
curr_mode = type(torch._C._get_torch_dispatch_mode())
curr_mode = type(_get_current_dispatch_mode())
assert (
curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
@ -314,9 +317,11 @@ class OpOverload(PyOperatorABC):
return key
def handler(*args, **kwargs):
from torch.utils._python_dispatch import _get_current_dispatch_mode
# TODO: We also need to handle tensor subclasses here
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
curr_mode = type(torch._C._get_torch_dispatch_mode())
curr_mode = type(_get_current_dispatch_mode())
assert (
curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."

View File

@ -185,7 +185,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
if self.should_fallback_fn(self, func, args, kwargs):
return orig_func(*args, **kwargs)
# torch calls inside func should be interpreted as refs calls
with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
with self:
return func(*args, **kwargs)
if self.strict:
raise RuntimeError(
@ -202,9 +202,7 @@ def _is_node_supported_nvfuser(node):
def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
with torch.overrides.enable_torch_function_mode(
torch_function_mode, replace=torch_function_mode.inner
):
with torch_function_mode:
gm = get_isolated_graphmodule(func, args, kwargs)
supported_ops = NvfuserPrimOperatorSupport()

View File

@ -16,7 +16,7 @@ from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import PyTree, tree_flatten, tree_map
@ -564,7 +564,8 @@ class FakeTensor(torch.Tensor):
else:
assert fake_mode is arg.fake_mode, "Mixing modes NYI"
with enable_torch_dispatch_mode(fake_mode):
assert fake_mode is not None
with fake_mode: # type: ignore[attr-defined]
return func(*args, **kwargs)
@staticmethod
@ -623,7 +624,7 @@ class FakeTensor(torch.Tensor):
# We keep one instantiation of `fake_tensor_converter` active
# for the duration of `with torch_enable_mode(FakeTensorMode)`.
# for the duration of `with FakeTensorMode()`.
# This allows accurate storage aliasing across invocation of
# different operators. While this will keep all freshly allocated
# tensors alive during `FakeTensorMode`, there will be no
@ -779,7 +780,7 @@ class FakeTensorMode(TorchDispatchMode):
# We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1`
return None
with self.restore():
with self:
if func in meta_table:
r = meta_table[func](*args, **kwargs)
return r
@ -801,7 +802,7 @@ class FakeTensorMode(TorchDispatchMode):
and len(flat_arg_fake_tensors) != 0
and hasattr(func, "prim_meta_impl")
):
with self.restore():
with self:
return func.prim_meta_impl(*args, **kwargs)
if has_symbolic_sizes:
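The `with self:` pattern above relies on a mode being popped while its own handler runs, so re-entering it for a hand-written decomposition is just a context manager. A sketch under that assumption (the `DecomposeAddmmMode` class is hypothetical, not part of this change):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class DecomposeAddmmMode(TorchDispatchMode):  # hypothetical example mode
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.addmm.default:
            bias, mat1, mat2 = args[:3]  # ignoring beta/alpha for brevity
            # re-enable this mode so the decomposed mm/add also run under it
            with self:
                return torch.mm(mat1, mat2) + bias
        return func(*args, **kwargs)

with DecomposeAddmmMode():
    out = torch.addmm(torch.zeros(2, 2), torch.ones(2, 3), torch.ones(3, 2))
```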

View File

@ -585,34 +585,6 @@ static PyObject* python_exit_dual_level(
END_HANDLE_TH_ERRORS
}
static PyObject* set_torch_dispatch_mode(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (arg == Py_None) {
c10::impl::TorchDispatchModeTLS::set_state(nullptr);
} else {
Py_INCREF(arg);
c10::impl::TorchDispatchModeTLS::set_state(
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* get_torch_dispatch_mode(
PyObject* _unused,
PyObject* _unused2) {
HANDLE_TH_ERRORS
const auto& mode = c10::impl::TorchDispatchModeTLS::get_state();
if (!mode) {
Py_RETURN_NONE;
} else {
auto* r = mode->ptr(getPyInterpreter());
Py_INCREF(r);
return r;
}
END_HANDLE_TH_ERRORS
}
static PyObject* set_torch_function_mode(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (arg == Py_None) {
@ -623,24 +595,136 @@ static PyObject* set_torch_function_mode(PyObject* _unused, PyObject* arg) {
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
END_HANDLE_TH_ERRORS;
}
static PyObject* get_torch_function_mode(
static PyObject* is_torch_function_mode_enabled(
PyObject* _unused,
PyObject* _unused2) {
HANDLE_TH_ERRORS
const auto& mode = at::impl::PythonTorchFunctionTLS::get_mode();
if (!mode) {
Py_RETURN_NONE;
if (at::impl::function_mode_enabled()) {
Py_RETURN_TRUE;
} else {
auto* r = mode->ptr(getPyInterpreter());
Py_INCREF(r);
return r;
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* push_on_torch_function_stack(
PyObject* _unused,
PyObject* arg) {
HANDLE_TH_ERRORS
if (arg != Py_None) {
Py_INCREF(arg);
at::impl::PythonTorchFunctionTLS::push_onto_stack(
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* pop_torch_function_stack(
PyObject* _unused,
PyObject* _unused2) {
HANDLE_TH_ERRORS
const auto& mode = at::impl::PythonTorchFunctionTLS::pop_stack();
auto* r = mode->ptr(getPyInterpreter());
Py_INCREF(r);
return r;
END_HANDLE_TH_ERRORS
}
static PyObject* get_function_stack_at(
PyObject* _unused,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser({"get_stack_at(int64_t level)"});
ParsedArgs<1> parsed_args;
auto _r = parser.parse(args, kwargs, parsed_args);
auto idx = _r.toInt64(0);
const auto& mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
auto* r = mode->ptr(getPyInterpreter());
Py_INCREF(r);
return r;
END_HANDLE_TH_ERRORS
}
static PyObject* len_torch_function_stack(
PyObject* _unused,
PyObject* _unused2) {
HANDLE_TH_ERRORS
const auto len = at::impl::PythonTorchFunctionTLS::stack_len();
return utils::wrap(static_cast<int64_t>(len));
END_HANDLE_TH_ERRORS
}
static PyObject* set_torch_dispatch_mode(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (arg == Py_None) {
c10::impl::TorchDispatchModeTLS::set_mode(nullptr);
} else {
Py_INCREF(arg);
c10::impl::TorchDispatchModeTLS::set_mode(
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS;
}
static PyObject* push_on_torch_dispatch_stack(
PyObject* _unused,
PyObject* arg) {
HANDLE_TH_ERRORS
if (arg != Py_None) {
Py_INCREF(arg);
c10::impl::TorchDispatchModeTLS::push_onto_stack(
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* pop_torch_dispatch_stack(
PyObject* _unused,
PyObject* _unused2) {
HANDLE_TH_ERRORS
const auto& mode = c10::impl::TorchDispatchModeTLS::pop_stack();
auto* r = mode->ptr(getPyInterpreter());
Py_INCREF(r);
return r;
END_HANDLE_TH_ERRORS
}
static PyObject* get_dispatch_stack_at(
PyObject* _unused,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser({"get_stack_at(int64_t level)"});
ParsedArgs<1> parsed_args;
auto _r = parser.parse(args, kwargs, parsed_args);
auto idx = _r.toInt64(0);
const auto& mode = c10::impl::TorchDispatchModeTLS::get_stack_at(idx);
auto* r = mode->ptr(getPyInterpreter());
Py_INCREF(r);
return r;
END_HANDLE_TH_ERRORS
}
static PyObject* len_torch_dispatch_stack(
PyObject* _unused,
PyObject* _unused2) {
HANDLE_TH_ERRORS
const auto len = c10::impl::TorchDispatchModeTLS::stack_len();
return utils::wrap(static_cast<int64_t>(len));
END_HANDLE_TH_ERRORS
}
// autograd methods on torch._C
static PyMethodDef methods[] = { // NOLINT
{"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
@ -685,10 +769,44 @@ static PyMethodDef methods[] = { // NOLINT
castPyCFunctionWithKeywords(python_exit_dual_level),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr},
{"_get_torch_dispatch_mode", get_torch_dispatch_mode, METH_NOARGS, nullptr},
{"_is_torch_function_mode_enabled",
is_torch_function_mode_enabled,
METH_NOARGS,
nullptr},
{"_set_torch_function_mode", set_torch_function_mode, METH_O, nullptr},
{"_get_torch_function_mode", get_torch_function_mode, METH_NOARGS, nullptr},
{"_push_on_torch_function_stack",
push_on_torch_function_stack,
METH_O,
nullptr},
{"_pop_torch_function_stack",
pop_torch_function_stack,
METH_NOARGS,
nullptr},
{"_get_function_stack_at",
castPyCFunctionWithKeywords(get_function_stack_at),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_len_torch_function_stack",
len_torch_function_stack,
METH_NOARGS,
nullptr},
{"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr},
{"_push_on_torch_dispatch_stack",
push_on_torch_dispatch_stack,
METH_O,
nullptr},
{"_pop_torch_dispatch_stack",
pop_torch_dispatch_stack,
METH_NOARGS,
nullptr},
{"_get_dispatch_stack_at",
castPyCFunctionWithKeywords(get_dispatch_stack_at),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_len_torch_dispatch_stack",
len_torch_dispatch_stack,
METH_NOARGS,
nullptr},
{nullptr, nullptr, 0, nullptr}};
PyMethodDef* python_functions() {

View File

@ -293,7 +293,7 @@ auto handle_torch_function_no_python_arg_parser(
torch_function_name == TorchFunctionName::TorchFunction;
auto get_mode = [&]() {
return is_torch_function ? at::impl::PythonTorchFunctionTLS::get_mode()
: c10::impl::TorchDispatchModeTLS::get_state();
: c10::impl::TorchDispatchModeTLS::get_mode();
};
const auto& maybe_mode = get_mode();

View File

@ -19,7 +19,7 @@ struct StashTorchFunctionModeGuard {
}
private:
std::shared_ptr<c10::SafePyObject> old_mode_ = nullptr;
std::shared_ptr<c10::SafePyObject> old_mode_;
};
} // namespace overrides

View File

@ -8,16 +8,15 @@ namespace torch_dispatch_mode {
struct StashTorchDispatchModeGuard {
public:
StashTorchDispatchModeGuard() {
saved_ = c10::impl::TorchDispatchModeTLS::get_state();
c10::impl::TorchDispatchModeTLS::set_state(nullptr);
c10::impl::TorchDispatchModeTLS::swap_mode(saved_mode_);
}
~StashTorchDispatchModeGuard() {
c10::impl::TorchDispatchModeTLS::set_state(saved_);
c10::impl::TorchDispatchModeTLS::set_mode(std::move(saved_mode_));
}
private:
std::shared_ptr<at::SafePyObject> saved_;
std::shared_ptr<at::SafePyObject> saved_mode_;
};
} // namespace torch_dispatch_mode

View File

@ -19,7 +19,7 @@ from dataclasses import dataclass
import weakref
import operator
from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
from torch._subclasses import FakeTensor
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat
from torch.fx import Proxy
@ -167,9 +167,9 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
def maybe_disable_fake_tensor_mode():
# TODO: figure out if this API generally makes sense and bake it into the
# library
mb_fake_mode = torch._C._get_torch_dispatch_mode()
mb_fake_mode = _get_current_dispatch_mode()
if isinstance(mb_fake_mode, FakeTensorMode):
return enable_torch_dispatch_mode(mb_fake_mode.inner, replace=mb_fake_mode)
return _pop_mode_temporarily()
else:
return nullcontext()
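`_pop_mode_temporarily()` used above pops the innermost mode for the duration of a `with` block and pushes it back on exit; a small sketch (the `NoisyMode` class is made up for illustration):

```python
import torch
from torch.utils._python_dispatch import (
    TorchDispatchMode, _get_current_dispatch_mode, _pop_mode_temporarily)

class NoisyMode(TorchDispatchMode):  # hypothetical example mode
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("intercepted", func)
        return func(*args, **(kwargs or {}))

with NoisyMode() as m:
    with _pop_mode_temporarily() as popped:
        assert popped is m and _get_current_dispatch_mode() is None
        torch.ones(2)  # runs without NoisyMode intercepting
    assert _get_current_dispatch_mode() is m  # pushed back on exit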
@ -432,7 +432,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
@contextmanager
def restore(self):
with self.sym_mode.enable(True):
with super().restore():
with self:
yield
def inner_torch_dispatch(self, func, types, args=(), kwargs=None):
@ -637,12 +637,7 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
def get_torch_dispatch_modes():
modes = [torch._C._get_torch_dispatch_mode()]
if modes[-1] is None:
return list()
while modes[-1].inner is not None:
modes.append(modes[-1].inner)
return modes
return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
@contextlib.contextmanager

View File

@ -32,9 +32,9 @@ import contextlib
import torch
from torch._C import (
_has_torch_function, _has_torch_function_unary,
_has_torch_function_variadic, _add_docstr, _set_torch_function_mode, _get_torch_function_mode)
from torch.utils._mode_utils import _enable_mode, _ModeInfo, _restore_mode
_has_torch_function_variadic, _add_docstr,
_push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
_set_torch_function_mode, _is_torch_function_mode_enabled)
__all__ = [
"get_ignored_functions",
@ -1510,8 +1510,8 @@ def handle_torch_function(
types = tuple(map(type, overloaded_args))
# Check for __torch_function__ mode.
mode = _get_torch_function_mode()
if mode is not None:
if _is_torch_function_mode_enabled():
mode = _get_current_function_mode()
# NB: unlike on tensors, modes are instances
with _no_torch_function_mode():
result = mode.__torch_function__(public_api, types, args, kwargs)
@ -1541,8 +1541,8 @@ def handle_torch_function(
"no implementation found for '{}' on types that implement "
'__torch_function__: {}'
).format(func_name, [type(arg) for arg in overloaded_args])
if mode is not None:
msg += f" nor in mode {mode}"
if _is_torch_function_mode_enabled():
msg += f" nor in mode {_get_current_function_mode()}"
raise TypeError(msg)
has_torch_function = _add_docstr(
@ -1766,53 +1766,7 @@ def is_tensor_like(inp):
"""
return type(inp) is torch.Tensor or hasattr(type(inp), "__torch_function__")
def _wrap_torch_function(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if isinstance(f, classmethod):
raise RuntimeError("TorchFunctionMode's torch_function function " +
"should be a normal method not a class method")
inner = getattr(self, "inner", None)
with enable_torch_function_mode(inner):
return f(self, *args, **kwargs)
return wrapped
# Implementation note: I had a choice about how much of mode stacks
# to implement in Python versus in C++. At time of writing, I did not care
# too much about implementation efficiency; however, I do care about making it
# hard for users to implement modes in the wrong way. In the end, it turned
# out to be possible to implement mode stacks entirely from userland, with the
# C++ API providing only _get_torch_function_mode() and
# _set_torch_function_mode(), so I opted to provide some unsafe C++ bindings and
# have the bulk of the logic for managing the stack in Python, which helped
# simplify the C++ API surface. It would also have been valid to build in the
# notion of mode stack directly into C++ but in this design it's substantially
# more difficult to interact with TorchFunctionModeMeta.
class TorchFunctionModeMeta(type):
"""
Metaclass for :class:`TorchFunctionMode`; it does two things:
* Adds an implicit ``inner`` kwarg to ``__init__``, to
allow the modes to be chained together to form a stack.
* Reenables the inner mode, so that by default PyTorch API calls
will compositionally proceed to the next mode on the stack.
The default behavior for the second bullet is important, as it is easy to
accidentally write ``__torch_function__`` implementations that are not
compositional, and the wrapping here makes the obvious code do the
right thing (aka, this is why there is a metaclass).
"""
def __new__(metacls, name, bases, dct):
if '__torch_function__' in dct:
dct['__torch_function__'] = _wrap_torch_function(dct['__torch_function__'])
return super().__new__(metacls, name, bases, dct)
class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
class TorchFunctionMode:
"""
A ``TorchFunctionMode`` allows you to override the meaning of all
``__torch_function__`` overrideable functions within a dynamic scope,
@ -1852,24 +1806,11 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
raise NotImplementedError()
def __enter__(self):
old = _get_torch_function_mode()
if hasattr(self, "inner"):
raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version or use restore")
else:
self.inner = old
if old is None:
self.ancestors = set()
else:
self.ancestors = self.inner.ancestors.union({self.inner})
_set_torch_function_mode(self)
_push_mode(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
_set_torch_function_mode(self.inner)
@contextlib.contextmanager
def restore(self):
return _restore_mode(self, mode_info=_TorchFunctionModeInfo())
_pop_mode()
@classmethod
def push(cls, *args, **kwargs):
@ -1877,6 +1818,50 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
instance = cls(*args, **kwargs)
return instance
def _get_current_function_mode():
stack_len = _len_torch_function_stack()
return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None
def _get_current_function_mode_stack():
stack_len = _len_torch_function_stack()
return [_get_function_stack_at(i) for i in range(stack_len)]
def _push_mode(mode):
if _len_torch_function_stack() == 0:
_set_torch_function_mode(_TorchFunctionStackMode())
_push_on_torch_function_stack(mode)
def _pop_mode():
old = _pop_torch_function_stack()
if _len_torch_function_stack() == 0:
_set_torch_function_mode(None)
return old
@contextlib.contextmanager
def _pop_mode_temporarily():
old = _pop_mode()
try:
yield old
finally:
_push_mode(old)
# a helper "mode" used by the torch_function push helper method. This is the only mode that will ever
# be active at the C++ level and it will run the current mode
class _TorchFunctionStackMode:
def __torch_function__(self, func, types, args=(), kwargs=None):
with _pop_mode_temporarily() as old:
if _len_torch_function_stack() > 0:
_set_torch_function_mode(self)
# we can't check the type of __torch_function__ here but this is sufficient for checking it's a classmethod
if old.__torch_function__.__self__ is type(old):
raise RuntimeError("TorchFunctionMode's torch_function function " +
"should be a normal method not a class method")
return old.__torch_function__(func, types, args, kwargs)
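A small sketch of what this delegation buys in practice (the `Tag` mode below is hypothetical): while a mode's `__torch_function__` runs, that mode is popped, so a plain `func(*args, **kwargs)` call falls through to the next mode down the stack and finally to the real implementation.

```python
import torch
from torch.overrides import TorchFunctionMode

class Tag(TorchFunctionMode):  # hypothetical example mode
    def __init__(self, name):
        self.name = name

    def __torch_function__(self, func, types, args=(), kwargs=None):
        print(f"{self.name} sees {func.__name__}")
        return func(*args, **(kwargs or {}))

x = torch.ones(2)
with Tag("outer"):
    with Tag("inner"):
        torch.add(x, 1)
# prints "inner sees add" then "outer sees add"
```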
class BaseTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
@ -1889,57 +1874,14 @@ class BaseTorchFunctionMode(TorchFunctionMode):
# library code though, e.g., in handle_torch_function
@contextlib.contextmanager
def _no_torch_function_mode() -> Iterator[None]:
old = _get_torch_function_mode()
_set_torch_function_mode(None)
try:
yield
finally:
_set_torch_function_mode(old)
if _len_torch_function_stack() > 0:
_set_torch_function_mode(_TorchFunctionStackMode())
class _TorchFunctionModeInfo(_ModeInfo):
def __init__(self):
super().__init__(mode_name="torch_function", mode_class=TorchFunctionMode)
def get_mode(self):
return _get_torch_function_mode()
def set_mode(self, mode):
return _set_torch_function_mode(mode)
@contextlib.contextmanager
def enable_torch_function_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
"""
Context manager that sets the current :class:`TorchFunctionMode`; see the
class for more information on what modes are. This function is
non-compositional; if there is already an existing mode, it will raise an
error; prefer using ``with MyMode():`` if your ``__torch_function__``
implementation can defer to an inner mode.
This function is safe to use inside a ``__torch_function__`` mode handler,
as the mode is guaranteed to be disabled in this context. You can use
this context manager to reinstate the mode so that calls to overridable
APIs recursively call back into your mode handler (this can easily cause
infinite loops, so use with care!)
Args:
mode (:class:`TorchFunctionMode`, Tensor-like class or None): the
mode to set as current mode. If you pass a Tensor-like class,
it will be treated as a non-compositional mode with no state,
which is convenient if you have an existing tensor subclass
that you'd like to apply globally in a quick and dirty way.
Passing None will disable the current mode.
replace (:class:`TorchFunctionMode` or Tensor-like class): the
mode to replace. You can use this argument to change the mode in
a situation where you know what the current mode is (and you are
intentionally overwriting it.) If you don't know what the current
mode is, use ``ignore_preexisting`` instead.
ignore_preexisting (bool): if True, ignore any preexisting mode
and overwrite it with the passed mode.
"""
return _enable_mode(mode, _TorchFunctionModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting)
class enable_reentrant_dispatch():
def __enter__(self):
self._raii_guard = torch._C._RestorePythonTLSSnapshot()

View File

@ -1,10 +1,11 @@
import torch
from torch import Tensor
import itertools
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from functools import partial
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch.utils._mode_utils import no_dispatch, all_same_mode
import torch.autograd.forward_ad as fwAD
from typing import Callable
import re
@ -99,7 +100,7 @@ def is_inplace(func):
return name[-1] == '_'
def generate_cct(autograd_view_consistency=True):
def generate_cct_and_mode(autograd_view_consistency=True):
# This function returns a new class CompositeCompliantTensor
# The two arguments control the behaviour described below.
@ -116,7 +117,7 @@ def generate_cct(autograd_view_consistency=True):
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def __new__(cls, elem, *args, **kwargs):
def __new__(cls, elem, mode, *args, **kwargs):
assert type(elem) is not cls, \
"Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported"
@ -148,6 +149,8 @@ def generate_cct(autograd_view_consistency=True):
# Ref: https://github.com/albanD/subclass_zoo/issues/21
torch._C._set_conj(r, r.elem.is_conj())
torch._C._set_neg(r, r.elem.is_neg())
r.mode = mode
return r
def __repr__(self):
@ -155,11 +158,20 @@ def generate_cct(autograd_view_consistency=True):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
all_args = tree_flatten(args)[0] + tree_flatten(kwargs)[0]
modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
if not all_same_mode(modes):
raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
with modes[0]:
return func(*args, **kwargs)
class CompositeCompliantTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
def unwrap(e):
return e.elem if isinstance(e, CompositeCompliantTensor) else e
def wrap(e):
return CompositeCompliantTensor(e) if isinstance(e, torch.Tensor) else e
return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e
if func == torch.ops.aten._local_scalar_dense.default:
raise RuntimeError(
@ -224,13 +236,13 @@ def generate_cct(autograd_view_consistency=True):
# For each CompositeCompliantTensor t, we check that t and t.elem
# have consistent metadata. If they don't have consistent metadata,
# that means the operator did something fishy.
check = partial(check_metadata_consistency, CCT=cls)
check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
tree_map(check, args)
tree_map(check, kwargs)
tree_map(check, rs)
return rs
return CompositeCompliantTensor
return CompositeCompliantTensor, CompositeCompliantTensorMode()
def is_tensorlist(lst):
if not isinstance(lst, list) and not isinstance(lst, tuple):
@ -251,12 +263,12 @@ def maybe_map(fn, should_map, arg):
return fn(arg) if should_map else arg
def wrap(arg, CCT):
# CCT: CompositeCompliantTensor class which is generated using generate_cct
def wrap(arg, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
if isinstance(arg, torch.Tensor):
return CCT(arg)
return CCT(arg, cct_mode)
if is_tensorlist(arg):
return [CCT(a) for a in arg]
return [CCT(a, cct_mode) for a in arg]
raise RuntimeError("wrap assumes that the input can be wrapped")
@ -270,14 +282,14 @@ def wrap(arg, CCT):
# [A, 1, B]
# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops
# don't accept that many input Tensors.
def generate_subclass_choices(flat_args, CCT):
# CCT: CompositeCompliantTensor class which is generated using generate_cct
def generate_subclass_choices(flat_args, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args]
subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes]
for which_args_are_wrapped in itertools.product(*subclass_options):
result = [maybe_map(partial(wrap, CCT=CCT), should_wrap_arg, arg)
result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg)
for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)]
yield result, which_args_are_wrapped
@ -285,11 +297,11 @@ def generate_subclass_choices(flat_args, CCT):
# For an operation f(*args, **kwargs), each Tensor argument may either be
# a regular Tensor or a Tensor Subclass. This iterator iterates through
# all of those options.
def generate_subclass_choices_args_kwargs(args, kwargs, CCT):
# CCT: CompositeCompliantTensor class which is generated using generate_cct
def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
flat_kwargs, spec = tree_flatten(kwargs)
flat_args_kwargs = list(args) + list(flat_kwargs)
for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT):
for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode):
new_args = choice[:len(args)]
new_kwargs = tree_unflatten(choice[len(args):], spec)
which_args_are_wrapped = debug_metadata[:len(args)]
@ -320,9 +332,9 @@ def raise_composite_compliance_error(err, additional_info=''):
# If some composite operation does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_all_permutations(op, args, kwargs, assert_equal_fn):
CCT = generate_cct()
CCT, cct_mode = generate_cct_and_mode()
expected = op(*args, **kwargs)
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
try:
@ -365,17 +377,17 @@ def check_all_permutations(op, args, kwargs, assert_equal_fn):
# Composite does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_with_mode(op, args, kwargs, assert_equal_fn):
CCT = generate_cct()
CCT, cct_mode = generate_cct_and_mode()
def wrap(e):
return CCT(e) if isinstance(e, torch.Tensor) else e
return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e
expected = op(*args, **kwargs)
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
try:
with enable_torch_dispatch_mode(CCT):
with cct_mode:
actual = op(*args, **kwargs)
# see NOTE: [What errors are Composite Compliance trying to catch?]
except RuntimeError as err:
@ -429,11 +441,11 @@ def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradch
def check_backward_formula(op: Callable, args, kwargs,
output_process_fn_grad=None,
gradcheck_wrapper=None, assert_equal_fn=None):
CCT = generate_cct()
CCT, cct_mode = generate_cct_and_mode()
expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
leaf_tensors = gather_leaf_tensors(new_args, new_kwargs)
assert len(leaf_tensors) > 0
@ -460,7 +472,7 @@ def check_backward_formula(op: Callable, args, kwargs,
# NB: ones, not ones_like, so we get a regular Tensor here
grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
for r in flat_diff_results]
for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT):
for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode):
try:
actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads,
allow_unused=True, retain_graph=True)
@ -486,7 +498,7 @@ def check_backward_formula(op: Callable, args, kwargs,
# this means we can apply check_forward_ad_formula to things that aren't OpInfos
# while debugging.
def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None):
CCT = generate_cct(autograd_view_consistency=False)
CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False)
def maybe_tangent(t):
assert type(t) is not CCT
@ -529,11 +541,11 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None,
expected_tangents = tree_map(lambda x: x.tangent, expected)
# Permutations of arg and kwargs in CCT.
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
# Permutations tangent arg and tangent kwargs in CCT.
for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT):
for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode):
new_tang_args, new_tang_kwargs, \
which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice

View File

@ -1,125 +1,13 @@
import torch
from typing import Iterator, TypeVar
from dataclasses import dataclass
from typing import TypeVar
from contextlib import contextmanager
T = TypeVar('T')
# This file has all the logic to dedupe logic between torch dispatch and
# torch function modes
#
# Specifically, it has the helper functions for enable_ and push_X_mode and the
# ModeInfo class, which is extended by each where they are different
# in order to dedupe the logic between TorchDispatchMode and TorchFunctionMode, this
# is a container to hold all the differences between the modes. Then functions like
# _enable_mode are able to use this container to call functions or get correctly
# formatted names
@dataclass
class _ModeInfo:
mode_name: str
mode_class: type # the class related to the mode that's allowed to be passed in
def mode_class_name(self):
return self.mode_class.__name__
def get_mode(self):
"""gets the current mode for this type of mode"""
raise NotImplementedError()
def set_mode(self, mode):
"""
set the mode for this type of mode. Note that no checks are done on this, it's the unsafe
version where checks are assumed to have been already done by the helper function
"""
raise NotImplementedError()
# shared version of enable_torch_function/enable_torch_dispatch_mode in order to deduplicate the code.
# The differences between the modes are captured by `mode_info` and then queried when they're
# needed during the function's invocation
def _enable_mode(mode: T, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[T]:
if not (
mode is None or
isinstance(mode, mode_info.mode_class) or
(isinstance(mode, type) and not issubclass(mode, mode_info.mode_class))
):
raise ValueError(f'expected to get {mode_info.mode_class_name()}, Tensor-like class, '
f'or None as an argument got {type(mode)} instead')
old = mode_info.get_mode()
if old is mode:
yield mode # type: ignore[misc]
return
if old is not None and not ignore_preexisting and old is not replace:
if isinstance(mode, mode_info.mode_class):
help_text = 'Use `with Mode():` instead.'
else:
help_text = (
'If you intended to completely override the preexisting mode, '
'pass ignore_preexisting=True. This can result in unexpected '
'behavior; please consider rewriting your mode to be a subclass '
f'of {mode_info.mode_class_name()} to make it compositional!'
)
raise ValueError(
f'Attempted to enable_{mode_info.mode_name}_mode, but there is already an '
f'active mode {old}. {help_text}'
)
# NB: we don't require TorchFunctionMode/TorchDispatchMode since this is intended to also
# let you directly pass a Tensor subclass type to "mode-ify" it.
if mode is not None:
required_fn = "__" + mode_info.mode_name + "__"
if not hasattr(mode, required_fn):
raise ValueError(
f'The argument passed to enable_{mode_info.mode_name}_mode must implement {required_fn}'
)
mode_info.set_mode(mode)
try:
yield mode # type: ignore[misc]
finally:
mode_info.set_mode(old)
def _restore_mode(mode, mode_info: _ModeInfo):
if not hasattr(mode, "ancestors"):
raise RuntimeError(f"{mode} does not have any ancestors. Use the standard version instead of restore")
old = mode_info.get_mode()
if old is not None and old not in mode.ancestors:
raise RuntimeError(f"{mode} is not valid in the current state because the current mode is not its ancestor")
mode_info.set_mode(mode)
try:
yield mode
finally:
mode_info.set_mode(old)
# To help with non-lexical scoping, it will error if all the modes are from different scopes or haven't been used
def find_outermost_mode(modes):
outermost = None
for mode in modes:
if mode is not None:
if not hasattr(mode, "ancestors"):
raise RuntimeError(f"{mode}, doesn't have ancestors set so the ordering with other modes is unclear")
if outermost is None:
outermost = mode
elif mode not in outermost.ancestors and outermost not in mode.ancestors:
raise RuntimeError(f"modes {mode} and {outermost} are not compatible because they "
"don't come from the same scope")
elif outermost in mode.ancestors:
outermost = mode
return outermost
# returns if all are the same mode
def all_same_mode(modes):
return all(tuple(mode == modes[0] for mode in modes))
# returns if all modes are from the current scope, ``cur_mode``
def all_same_mode_scope(modes, cur_mode):
if not hasattr(cur_mode, "ancestors"):
return False
return all(tuple(mode == cur_mode or mode in cur_mode.ancestors for mode in modes))
@contextmanager
def no_dispatch():
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]

View File

@ -1,115 +1,17 @@
import contextlib
from typing import Iterator, Set
import functools
import warnings
from torch.utils._mode_utils import _enable_mode, _ModeInfo, _restore_mode
from torch._C import _get_torch_dispatch_mode, _set_torch_dispatch_mode
from dataclasses import dataclass
@dataclass
class TorchDispatchModeInfo(_ModeInfo):
def __init__(self):
super().__init__(mode_name="torch_dispatch", mode_class=TorchDispatchMode)
def get_mode(self):
return _get_torch_dispatch_mode()
def set_mode(self, mode):
return _set_torch_dispatch_mode(mode)
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, _set_torch_dispatch_mode
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for torch._C._DisableTorchDispatch that
# - We need a better user-facing api for _DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
@contextlib.contextmanager
def enable_torch_dispatch_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
"""
Context manager that causes all pytorch operators to dispatch to the passed-in
type's __torch_dispatch__ function, including operations that accept no tensors
but return a tensor.
This function is non-compositional; if there is already an existing mode,
it will raise an error
This function is safe to use inside a ``__torch_dispatch__`` mode handler,
as the mode is guaranteed to be disabled in this context. You can use
this context manager to reinstate the mode so that calls to overridable
APIs recursively call back into your mode handler (this can easily cause
infinite loops, so use with care!)
enable_torch_dispatch_mode is affected by _DisableTorchDispatch.
Args:
mode (:class:`TorchDispatchMode`, Tensor-like class, or None): the
mode to set as current mode. If you pass a Tensor-like class,
it will be treated as a non-compositional mode with no state,
which is convenient if you have an existing tensor subclass
that you'd like to apply globally in a quick and dirty way.
Passing None will disable the current mode.
replace (:class:`TorchDispatchMode` or Tensor-like class): the
mode to replace. You can use this argument to change the mode in
a situation where you know what the current mode is (and you are
intentionally overwriting it.) If you don't know what the current
mode is, use ``ignore_preexisting`` instead.
ignore_preexisting (bool): if True, ignore any preexisting mode
and overwrite it with the passed mode.
"""
return _enable_mode(mode, mode_info=TorchDispatchModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting)
def _wrap_torch_dispatch(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if isinstance(f, classmethod):
raise RuntimeError("TorchDispatchMode's torch_dispatch function " +
"should be a normal method not a class method")
inner = getattr(self, "inner", None)
with enable_torch_dispatch_mode(inner):
return f(self, *args, **kwargs)
return wrapped
# Implementation note, since this is based on TorchFunctionMode, this had the
# same dilemma: I had a choice about how much of mode stacks
# to implement in Python versus in C++. At time of writing, I did not care
# too much about implementation efficiency; however, I do care about making it
# hard for users to implement modes in the wrong way. In the end, it turned
# out to be possible to implement mode stacks entirely from userland, with the
# C++ API providing only _get_torch_dispatch_mode() and
# _set_torch_dispatch_mode(), so I opted to provide some unsafe C++ bindings and
# have the bulk of the logic for managing the stack in Python, which helped
# simplify the C++ API surface. It would also have been valid to build in the
# notion of mode stack directly into C++ but in this design it's substantially
# more difficult to interact with TorchDispatchModeMeta.
class TorchDispatchModeMeta(type):
"""
Metaclass for :class:`TorchDispatchMode`; it does two things:
* Adds an implicit ``inner`` kwarg to ``__init__``, to
allow the modes to be chained together to form a stack.
* Reenables the inner mode, so that by default PyTorch API calls
will compositionally proceed to the next mode on the stack.
The default behavior for the second bullet is important, as it is easy to
accidentally write ``_wrap_torch_dispatch`` implementations that are not
compositional, and the wrapping here makes the obvious code do the
right thing (aka, this is why there is a metaclass).
"""
def __new__(metacls, name, bases, dct):
if '__torch_dispatch__' in dct:
dct['__torch_dispatch__'] = _wrap_torch_dispatch(dct['__torch_dispatch__'])
return super().__new__(metacls, name, bases, dct)
class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
class TorchDispatchMode:
"""
A ``TorchDispatchMode`` allows you to override the meaning of all
``__torch_dispatch__`` overrideable functions within a dynamic scope,
@ -139,32 +41,15 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
``__torch_dispatch__(self)`` to make PyTorch
API self-referential (beware of infinite loops, in this case!)
"""
# Force metaclass to generate constructor at the base of the hierarchy
def __init__(self):
self.ancestors: Set[TorchDispatchMode]
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise NotImplementedError()
def __enter__(self):
old = _get_torch_dispatch_mode()
if hasattr(self, "inner"):
raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version or use restore")
else:
self.inner = old
if old is None:
self.ancestors = set()
else:
self.ancestors = self.inner.ancestors.union({self.inner})
_set_torch_dispatch_mode(self)
_push_mode(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
_set_torch_dispatch_mode(self.inner)
@contextlib.contextmanager
def restore(self):
return _restore_mode(self, mode_info=TorchDispatchModeInfo())
_pop_mode()
@classmethod
def push(cls, *args, **kwargs):
@ -172,6 +57,48 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
instance = cls(*args, **kwargs)
return instance
def _get_current_dispatch_mode():
stack_len = _len_torch_dispatch_stack()
return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None
def _get_current_dispatch_mode_stack():
stack_len = _len_torch_dispatch_stack()
return [_get_dispatch_stack_at(i) for i in range(stack_len)]
def _push_mode(mode):
if _len_torch_dispatch_stack() == 0:
_set_torch_dispatch_mode(_TorchDispatchStackMode())
_push_on_torch_dispatch_stack(mode)
def _pop_mode():
old = _pop_torch_dispatch_stack()
if _len_torch_dispatch_stack() == 0:
_set_torch_dispatch_mode(None)
return old
@contextlib.contextmanager
def _pop_mode_temporarily():
old = _pop_mode()
try:
yield old
finally:
_push_mode(old)
# a helper "mode" used by the torch dispatch push helper method. This is the only mode that will ever
# be active at the C++ level and it will run the current mode
class _TorchDispatchStackMode:
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
with _pop_mode_temporarily() as old:
if _len_torch_dispatch_stack() > 0:
_set_torch_dispatch_mode(self)
# we can't check the type of __torch_dispatch__ here but this is sufficient for checking it's a classmethod
if old.__torch_dispatch__.__self__ is type(old):
raise RuntimeError(f"{type(old)}'s torch_dispatch function " +
"should be a normal method not a class method")
return old.__torch_dispatch__(func, types, args, kwargs)
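The dispatch-side stack nests and can be inspected the same way; a brief sketch under those assumptions (the classes `A` and `B` are illustrative only):

```python
import torch
from torch.utils._python_dispatch import (
    TorchDispatchMode, _get_current_dispatch_mode_stack)

class A(TorchDispatchMode):  # hypothetical example modes
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

class B(A):
    pass

with A() as a:
    with B() as b:
        assert _get_current_dispatch_mode_stack() == [a, b]  # bottom to top
        torch.rand(3)  # seen by B first, falls through to A, then to the kernel
```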
class BaseTorchDispatchMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):