Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, the mode stack now has "one true mode", which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch simply takes the top mode in the stack, re-enables itself (if we aren't at the end of the mode stack), and runs the top mode's torch_{dispatch|function}. This maintains the invariant that, in the middle of a mode's torch dispatch, the mode itself is not active. It changes the function the user has to call to see what the current mode is (it no longer queries the C++ state; it is Python only), but it also lets the user easily see the entire mode stack.

Removes `enable_torch_dispatch_mode` and `.restore()`, since neither makes sense in this new setup.

### Background

Why do we want this? A pretty common pattern that kept coming up was that users had to write something like:

```python
## PRE-PR UX
def f(mode):
    with mode.restore():  # user needs to understand this restore thing?
        ...

with Mode() as m:
    pass
f(m)
```

Many users hit errors from forgetting to call `.restore`, or from forgetting the (admittedly weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like ordinary context managers and just write:

```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
    with mode:
        ...

f(Mode())
```

### Technical Details

With the old mode stack, we essentially had a linked list, so a mode instance could only be used once and had a fixed parent. In the new design, the mode stack is just a Python list that we push to and pop from. Only one mode is ever active at the C++ level, and it runs the next mode in the Python list. The modes no longer carry any stack state themselves.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
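A minimal illustrative sketch of the user-facing behavior this describes (`LogMode` is a made-up example mode; `_get_current_dispatch_mode` and `_get_current_dispatch_mode_stack` are the Python-level queries this PR switches callers to, as seen in the test changes below):

```python
import torch
from torch.utils._python_dispatch import (
    TorchDispatchMode,
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
)

class LogMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"{type(self).__name__} saw {func}")
        # While this handler runs, this mode is no longer the active mode,
        # so redispatching here does not re-enter it.
        return func(*args, **kwargs)

# Modes are plain context managers now; no empty "instantiation" block, no .restore().
with LogMode() as outer:
    with LogMode() as inner:
        assert _get_current_dispatch_mode() is inner
        assert _get_current_dispatch_mode_stack() == [outer, inner]
        torch.ones(2)  # handled by inner first, then redispatched to outer
```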
@@ -18,6 +18,26 @@ void PythonTorchFunctionTLS::swap_mode(std::shared_ptr<c10::SafePyObject>& mode)
  pythonTorchFunctionState.mode_.swap(mode);
}

void PythonTorchFunctionTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
  pythonTorchFunctionState.stack_.push_back(std::move(mode));
}

const std::shared_ptr<SafePyObject> PythonTorchFunctionTLS::pop_stack() {
  TORCH_CHECK(pythonTorchFunctionState.stack_.size() > 0, "trying to pop from empty mode stack");
  const auto out = pythonTorchFunctionState.stack_.back();
  pythonTorchFunctionState.stack_.pop_back();
  return out;
}

const std::shared_ptr<SafePyObject>& PythonTorchFunctionTLS::get_stack_at(int64_t idx) {
  TORCH_CHECK(idx < static_cast<int64_t>(pythonTorchFunctionState.stack_.size()), "Tried to get stack at idx that's too big");
  return pythonTorchFunctionState.stack_[idx];
}

int64_t PythonTorchFunctionTLS::stack_len() {
  return pythonTorchFunctionState.stack_.size();
}

void PythonTorchFunctionTLS::set_disabled(bool disabled) {
  pythonTorchFunctionState.disabled_ = disabled;
}
@@ -34,5 +54,9 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() {
  return pythonTorchFunctionState;
}

bool function_mode_enabled() {
  return static_cast<bool>(PythonTorchFunctionTLS::get_mode());
}

} // namespace impl
} // namespace at

@@ -14,13 +14,28 @@ struct TORCH_API PythonTorchFunctionTLS {
  static const std::shared_ptr<c10::SafePyObject>& get_mode();
  static void swap_mode(std::shared_ptr<c10::SafePyObject>&);

  static void set_state(const PythonTorchFunctionTLS& state);
  static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
  static const std::shared_ptr<SafePyObject> pop_stack();
  static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
  static int64_t stack_len();

  static const PythonTorchFunctionTLS& get_state();
  static void set_state(const PythonTorchFunctionTLS& state);

 private:
  // The mode TLS is split into
  //   - disabled_, which says whether or not to disable all torch function
  //     modes
  //   - mode_, which is the C++ mode, that can only be the mode handling mode
  //     or null
  //   - stack_, which is a vector of modes representing the stack of user
  //     defined modes
  bool disabled_;
  std::shared_ptr<c10::SafePyObject> mode_;
  std::shared_ptr<c10::SafePyObject> mode_ = nullptr;
  std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};

TORCH_API bool function_mode_enabled();

} // namespace impl
} // namespace at

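On the Python side, the torch_function half of this stack is surfaced through helpers in `torch.overrides` (they are imported in the test changes further down in this diff). A rough usage sketch, with a made-up `NoisyMode`:

```python
import torch
from torch.overrides import (
    TorchFunctionMode,
    _get_current_function_mode,
    _get_current_function_mode_stack,
)

class NoisyMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        return func(*args, **kwargs)

assert _get_current_function_mode_stack() == []

with NoisyMode() as m1:
    with NoisyMode() as m2:
        # The innermost mode is the "current" one; the stack lists the
        # outermost mode first.
        assert _get_current_function_mode() is m2
        assert _get_current_function_mode_stack() == [m1, m2]
        torch.add(torch.ones(1), torch.ones(1))
```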
@@ -56,7 +56,7 @@ class TORCH_API ThreadLocalState {
  AutogradState autograd_tls_;

  // TLS for enable_torch_dispatch_mode
  std::shared_ptr<SafePyObject> torch_dispatch_mode_state_;
  c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;

  // TLS for enable_python_dispatcher
  c10::impl::PyInterpreter* python_dispatcher_state_;

@@ -52,7 +52,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {

  // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
  const auto& maybe_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_state();
  const auto& maybe_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_mode();
  if (maybe_torch_dispatch_mode_state) {
    maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
    return;

@@ -547,7 +547,7 @@ c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach_core(
    bool allow_tensor_metadata_change) const {
  c10::intrusive_ptr<TensorImpl> r;
  const auto& maybe_torch_dispatch_mode_state =
      c10::impl::TorchDispatchModeTLS::get_state();
      c10::impl::TorchDispatchModeTLS::get_mode();
  // TODO: do we have to exclude after Python dispatch key set?
  if (maybe_torch_dispatch_mode_state &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {

@@ -6,32 +6,85 @@
namespace c10 {
namespace impl {

thread_local std::shared_ptr<SafePyObject> torchDispatchModeState;
thread_local TorchDispatchModeTLS torchDispatchModeState;

void TorchDispatchModeTLS::set_state(std::shared_ptr<SafePyObject> state) {
  if (state) {
// MODE
void TorchDispatchModeTLS::set_mode(std::shared_ptr<SafePyObject> mode) {
  if (mode) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, true);
  } else {
    TorchDispatchModeTLS::reset_state();
    TorchDispatchModeTLS::reset_mode();
  }
  torchDispatchModeState = std::move(state);
  torchDispatchModeState.mode_ = std::move(mode);
}

const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_state() {
  return torchDispatchModeState;
const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_mode() {
  return torchDispatchModeState.mode_;
}

void TorchDispatchModeTLS::reset_state() {
  torchDispatchModeState.reset();
void TorchDispatchModeTLS::reset_mode() {
  torchDispatchModeState.mode_.reset();
  c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::PythonTLSSnapshot, false);
}

void TorchDispatchModeTLS::swap_mode(std::shared_ptr<SafePyObject>& mode) {
  if (mode) {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, true);
  } else {
    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonTLSSnapshot, false);
  }
  torchDispatchModeState.mode_.swap(mode);
}

// STACK
void TorchDispatchModeTLS::push_onto_stack(std::shared_ptr<SafePyObject> mode) {
  torchDispatchModeState.stack_.push_back(std::move(mode));
}

const std::shared_ptr<SafePyObject> TorchDispatchModeTLS::pop_stack() {
  TORCH_CHECK(
      torchDispatchModeState.stack_.size() > 0,
      "trying to pop from empty mode stack");
  const std::shared_ptr<SafePyObject> out =
      torchDispatchModeState.stack_.back();
  torchDispatchModeState.stack_.pop_back();
  return out;
}

const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_stack_at(
    int64_t idx) {
  TORCH_CHECK(
      idx < static_cast<int64_t>(torchDispatchModeState.stack_.size()),
      "Tried to get stack at idx that's too big");
  return torchDispatchModeState.stack_[idx];
}

int64_t TorchDispatchModeTLS::stack_len() {
  return torchDispatchModeState.stack_.size();
}

// STATE

const TorchDispatchModeTLS& TorchDispatchModeTLS::get_state() {
  return torchDispatchModeState;
}

void TorchDispatchModeTLS::set_state(const TorchDispatchModeTLS& state) {
  torchDispatchModeState = state;
}

// UTIL

bool dispatch_mode_enabled() {
  return static_cast<bool>(c10::impl::TorchDispatchModeTLS::get_state());
  return static_cast<bool>(c10::impl::TorchDispatchModeTLS::get_mode());
}

} // namespace impl

@@ -9,9 +9,27 @@ namespace c10 {
namespace impl {

struct C10_API TorchDispatchModeTLS {
  static void set_state(std::shared_ptr<SafePyObject> state);
  static const std::shared_ptr<SafePyObject>& get_state();
  static void reset_state();
  static void set_mode(std::shared_ptr<SafePyObject> mode);
  static const std::shared_ptr<SafePyObject>& get_mode();
  static void reset_mode();
  static void swap_mode(std::shared_ptr<SafePyObject>& mode);

  static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
  static const std::shared_ptr<SafePyObject> pop_stack();
  static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
  static int64_t stack_len();

  static const TorchDispatchModeTLS& get_state();
  static void set_state(const TorchDispatchModeTLS& state);

 private:
  // The mode TLS is split into
  //   - mode_, which is the C++ mode, that can only be the mode handling mode
  //     or null
  //   - stack_, which is a vector of modes representing the stack of user
  //     defined modes
  std::shared_ptr<c10::SafePyObject> mode_;
  std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
};

C10_API bool dispatch_mode_enabled();

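A conceptual model of the split between `mode_` and `stack_` described in the comments above: the single C++-level handler pops the next user mode off the Python-side list, stays enabled only if more modes remain, runs the user mode's handler, and restores the stack on the way out. The snippet below is a toy model in plain Python, not the actual PyTorch internals; the restore-in-`finally` detail is an assumption made for illustration.

```python
user_mode_stack = []      # stands in for TorchDispatchModeTLS::stack_
handler_enabled = False   # stands in for whether TorchDispatchModeTLS::mode_ is set

def dispatch_through_modes(func, args=(), kwargs=None):
    """What the single C++-level mode's handler conceptually does."""
    global handler_enabled
    kwargs = kwargs or {}
    mode = user_mode_stack.pop()                 # take the top user mode
    handler_enabled = len(user_mode_stack) > 0   # re-enable only if modes remain
    try:
        # The popped mode is not on the stack while it runs, so a redispatch
        # inside its handler cannot re-enter it.
        return mode.torch_dispatch(func, args, kwargs)
    finally:
        user_mode_stack.append(mode)             # restore on the way out
        handler_enabled = True

class ShoutMode:
    def torch_dispatch(self, func, args, kwargs):
        print("intercepted", func.__name__)
        return func(*args, **kwargs)

user_mode_stack.append(ShoutMode())
handler_enabled = True
print(dispatch_through_modes(max, (1, 2)))  # prints "intercepted max", then 2
```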
@@ -4,7 +4,7 @@ from torch._ops import PyOperator
from torch.utils._pytree import tree_flatten
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot
import torch.utils._pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode
from torch.fx.experimental.proxy_tensor import track_tensor_tree
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

@@ -29,15 +29,6 @@ def suspend_mode(mode):
    finally:
        torch._C._set_torch_dispatch_mode(mode)

@contextmanager
def enable_mode(mode):
    curr_mode = torch._C._get_torch_dispatch_mode()
    torch._C._set_torch_dispatch_mode(mode)
    try:
        yield
    finally:
        torch._C._set_torch_dispatch_mode(curr_mode)


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    def _unwrap_proxy(e):
@@ -104,7 +95,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):

@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
    mode = torch._C._get_torch_dispatch_mode()
    mode = _get_current_dispatch_mode()
    assert (mode is None), "Mode should never be enabled for CPU key"
    if pred:
        return true_fn(*operands)
@@ -125,7 +116,7 @@ def cond_autograd(pred, true_fn, false_fn, *operands):

@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
    mode = torch._C._get_torch_dispatch_mode()
    mode = _get_current_dispatch_mode()
    assert (mode is not None), "Mode should always be enabled for python fallback key"
    with suspend_mode(mode):
        res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)

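With the `enable_mode` helper deleted above, call sites that need to (re)enter a mode simply use the mode instance as a context manager; re-entering the currently active instance is allowed (see `test_nesting_same_mode` later in this diff). A small illustrative sketch, with a made-up `PassthroughMode`:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class PassthroughMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

mode = PassthroughMode()
with mode:
    # Previously this kind of re-entry went through enable_mode()/.restore();
    # now the same instance can be pushed again while it is the current mode.
    with mode:
        torch.ones(2)
```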
@ -3,8 +3,8 @@
|
||||
from collections import defaultdict
|
||||
from torch import Tensor
|
||||
import torch.autograd
|
||||
from torch.utils._python_dispatch import enable_torch_dispatch_mode
|
||||
from torch._decomp import decomposition_table
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
@ -384,17 +384,16 @@ class TestDecomp(TestCase):
|
||||
|
||||
saved_precision = self.precision
|
||||
saved_rel_tol = self.rel_tol
|
||||
test_case = self
|
||||
|
||||
class DecompCrossRefMode(torch.Tensor):
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
class DecompCrossRefMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
with no_dispatch():
|
||||
return cls._torch_dispatch(func, types, args, kwargs)
|
||||
return self._torch_dispatch(func, types, args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def _torch_dispatch(cls, func, types, args=(), kwargs=None):
|
||||
self.precision = saved_precision
|
||||
self.rel_tol = saved_rel_tol
|
||||
def _torch_dispatch(self, func, types, args=(), kwargs=None):
|
||||
test_case.precision = saved_precision
|
||||
test_case.rel_tol = saved_rel_tol
|
||||
|
||||
called.add(func)
|
||||
all_called[func] += 1
|
||||
@ -441,14 +440,14 @@ class TestDecomp(TestCase):
|
||||
assert type(orig) == type(decomp)
|
||||
assert orig == decomp
|
||||
continue
|
||||
op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
|
||||
op_assert_ref(test_case, func, test_dtype, i, orig, decomp, ref, args, kwargs)
|
||||
else:
|
||||
for orig, decomp in zip(real_out, decomp_out):
|
||||
if not isinstance(orig, torch.Tensor):
|
||||
assert type(orig) == type(decomp)
|
||||
assert orig == decomp
|
||||
continue
|
||||
op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)
|
||||
op_assert_equal(test_case, func, test_dtype, orig, decomp, args, kwargs)
|
||||
|
||||
return real_out_unflat
|
||||
|
||||
@ -487,7 +486,7 @@ class TestDecomp(TestCase):
|
||||
# explicit clearing is necessary as I will create a fresh mode
|
||||
# for each region
|
||||
decomposed.clear()
|
||||
with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
|
||||
with DecompCrossRefMode(), enable_python_dispatcher():
|
||||
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
|
||||
if aten_name in decomposition_names:
|
||||
check_decomposed(aten_name)
|
||||
@ -496,7 +495,7 @@ class TestDecomp(TestCase):
|
||||
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
|
||||
|
||||
decomposed.clear()
|
||||
with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
|
||||
with DecompCrossRefMode(), enable_python_dispatcher():
|
||||
decomp_vjp_fn(cotangents)
|
||||
if not run_all:
|
||||
check_decomposed(op.aten_backward_name)
|
||||
@ -505,7 +504,7 @@ class TestDecomp(TestCase):
|
||||
args = [sample_input.input] + list(sample_input.args)
|
||||
kwargs = sample_input.kwargs
|
||||
decomposed.clear()
|
||||
with enable_torch_dispatch_mode(DecompCrossRefMode), enable_python_dispatcher():
|
||||
with DecompCrossRefMode(), enable_python_dispatcher():
|
||||
func(*args, **kwargs)
|
||||
if not run_all:
|
||||
check_decomposed(aten_name)
|
||||
|
@ -12,7 +12,6 @@ from torch._subclasses.fake_tensor import (
|
||||
DynamicOutputShapeException,
|
||||
)
|
||||
from torch.testing import FileCheck
|
||||
from torch.utils._python_dispatch import enable_torch_dispatch_mode
|
||||
from torch import nn
|
||||
import unittest
|
||||
import torch._prims as prims
|
||||
@ -27,10 +26,9 @@ class FakeTensorTest(TestCase):
|
||||
self.assertEqual(list(t.size()), size)
|
||||
|
||||
def test_basic(self):
|
||||
mode = FakeTensorMode()
|
||||
x = torch.empty(2, 2, device="cpu")
|
||||
y = torch.empty(4, 2, 2, device="cpu")
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with FakeTensorMode() as mode:
|
||||
x = mode.from_tensor(x)
|
||||
y = mode.from_tensor(y)
|
||||
z = x + y
|
||||
@ -39,7 +37,7 @@ class FakeTensorTest(TestCase):
|
||||
self.assertTrue(isinstance(z, FakeTensor))
|
||||
|
||||
def test_parameter_instantiation(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([4])
|
||||
y = torch.nn.parameter.Parameter(x)
|
||||
self.assertTrue(isinstance(y, torch.nn.Parameter))
|
||||
@ -52,14 +50,14 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_index_cuda_with_cpu(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([2048], device='cuda')
|
||||
out = x[torch.zeros([36], dtype=torch.int64)]
|
||||
self.checkType(out, "cuda", [36])
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_shape_take_not_device(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.empty(1, device="cpu")
|
||||
y = torch.empty(8, 8, device="cuda")
|
||||
out = x.resize_as_(y)
|
||||
@ -69,8 +67,7 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_zero_dim(self):
|
||||
mode = FakeTensorMode()
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with FakeTensorMode() as mode:
|
||||
x = torch.tensor(0.)
|
||||
y = torch.rand([4, 4], device="cuda")
|
||||
out = x + y
|
||||
@ -79,8 +76,7 @@ class FakeTensorTest(TestCase):
|
||||
self.assertTrue(isinstance(out, FakeTensor))
|
||||
|
||||
def test_nan_to_num(self):
|
||||
mode = FakeTensorMode()
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with FakeTensorMode():
|
||||
for dtype in [torch.float16, torch.float32]:
|
||||
x = torch.rand([4], dtype=dtype)
|
||||
y = torch.nan_to_num(x, nan=None)
|
||||
@ -90,9 +86,8 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_throw(self):
|
||||
mode = FakeTensorMode()
|
||||
x = torch.tensor(0.) # TODO: tensor() errors
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with FakeTensorMode() as mode:
|
||||
x_conv = mode.from_tensor(x)
|
||||
y = torch.rand([4, 4], device="cuda")
|
||||
z = torch.rand([4, 4], device="cpu")
|
||||
@ -100,7 +95,7 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_type_as(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([16, 1], device="cpu")
|
||||
y = torch.rand([4, 4], device="cuda")
|
||||
out = x.type_as(y)
|
||||
@ -110,12 +105,12 @@ class FakeTensorTest(TestCase):
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_setitem(self):
|
||||
for device in ["cpu", "cuda"]:
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([16, 1], device=device)
|
||||
x[..., 0] = 0
|
||||
|
||||
def test_fake_dispatch_keys(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([4])
|
||||
f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
|
||||
f.run(torch._C._dispatch_key_set(x))
|
||||
@ -127,14 +122,14 @@ class FakeTensorTest(TestCase):
|
||||
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))
|
||||
|
||||
def test_constructor(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([4, 4], device="cpu")
|
||||
|
||||
self.assertTrue(isinstance(x, FakeTensor))
|
||||
self.assertTrue(x.device.type == "cpu")
|
||||
|
||||
def test_mode(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
y = torch.rand([4], device="cpu")
|
||||
out = y + y
|
||||
|
||||
@ -142,7 +137,7 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_non_kwarg_device(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([16, 1], device="cpu")
|
||||
y = x.to(torch.device("cpu"))
|
||||
self.assertIs(x, y)
|
||||
@ -153,7 +148,7 @@ class FakeTensorTest(TestCase):
|
||||
x = torch.rand([4, 4])
|
||||
|
||||
with self.assertRaisesRegex(Exception, "non-Fake Tensor inputs"):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
y = x[0]
|
||||
|
||||
def test_fake_grad_copy(self):
|
||||
@ -168,7 +163,7 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_like_constructor(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([4, 4])
|
||||
y = torch.ones_like(x)
|
||||
self.assertTrue(isinstance(y, FakeTensor))
|
||||
@ -178,7 +173,7 @@ class FakeTensorTest(TestCase):
|
||||
self.assertEqual(z.device.type, "cuda")
|
||||
|
||||
def test_binary_op_type_promotion(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.empty([2, 2], dtype=torch.float)
|
||||
y = torch.empty([2, 2], dtype=torch.int64)
|
||||
out = x / y
|
||||
@ -186,14 +181,14 @@ class FakeTensorTest(TestCase):
|
||||
self.assertEqual(out.device.type, "cpu")
|
||||
|
||||
def test_from_numpy(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.tensor(np.zeros([4, 4]))
|
||||
self.checkType(x, "cpu", [4, 4])
|
||||
|
||||
def test_randperm(self):
|
||||
x = torch.randperm(10)
|
||||
y = torch.randperm(5, device="cpu")
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x1 = torch.randperm(10)
|
||||
prims.utils.compare_tensor_meta(x, x1)
|
||||
y1 = torch.randperm(5, device="cpu")
|
||||
@ -202,21 +197,21 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_cpu_fallback(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=False)):
|
||||
with FakeTensorMode(allow_fallback_kernels=False):
|
||||
filters = torch.randn(8, 4, 3, 3).cuda()
|
||||
inputs = torch.randn(1, 4, 5, 5).cuda()
|
||||
out = torch.nn.functional.conv2d(inputs, filters, padding=1)
|
||||
self.assertEqual(out.device.type, "cuda")
|
||||
self.assertEqual(list(out.size()), [1, 8, 5, 5])
|
||||
|
||||
with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=True)):
|
||||
with FakeTensorMode(allow_fallback_kernels=True):
|
||||
# intentionally bad inputs
|
||||
filters = torch.randn(8, 20, 3, 3).cuda()
|
||||
inputs = torch.randn(1, 7, 10, 5).cuda()
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.nn.functional.conv2d(inputs, filters, padding=1)
|
||||
|
||||
with enable_torch_dispatch_mode(FakeTensorMode(allow_fallback_kernels=True)):
|
||||
with FakeTensorMode(allow_fallback_kernels=True):
|
||||
filters = torch.randn(8, 4, 3, 3).cuda()
|
||||
inputs = torch.randn(1, 4, 5, 5).cuda()
|
||||
|
||||
@ -295,7 +290,7 @@ class FakeTensorTest(TestCase):
|
||||
)
|
||||
|
||||
mode = FakeTensorMode()
|
||||
for i, context in enumerate([contextlib.nullcontext, lambda: enable_torch_dispatch_mode(mode)]):
|
||||
for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
|
||||
with context():
|
||||
inps = (
|
||||
torch.randn([92, 8, 2048]).cuda(),
|
||||
@ -336,16 +331,14 @@ class FakeTensorTest(TestCase):
|
||||
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
|
||||
mod_copied = copy.deepcopy(m)
|
||||
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with mode:
|
||||
input = torch.rand(20, 16, 50, 100, dtype=torch.half, device="cuda").to(memory_format=torch.channels_last)
|
||||
out = mod_copied(input)
|
||||
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
|
||||
self.checkType(out, "cuda", [20, 33, 24, 49])
|
||||
|
||||
def test_data_dependent_operator(self):
|
||||
with enable_torch_dispatch_mode(
|
||||
FakeTensorMode(allow_fallback_kernels=False)
|
||||
):
|
||||
with FakeTensorMode(allow_fallback_kernels=False):
|
||||
x = torch.rand([10, 10])
|
||||
|
||||
self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))
|
||||
@ -355,7 +348,8 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@skipIfCrossRef
|
||||
def test_deepcopy(self):
|
||||
mode = FakeTensorMode()
|
||||
with FakeTensorMode() as mode:
|
||||
pass
|
||||
mod = torch.nn.BatchNorm2d(10)
|
||||
with torch._subclasses.fake_tensor.FakeCopyMode(mode):
|
||||
mod_copied = copy.deepcopy(mod)
|
||||
@ -386,7 +380,7 @@ class FakeTensorTest(TestCase):
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_new(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
a = torch.rand([16, 1])
|
||||
self.checkType(a.new(10, 10), "cpu", [10, 10])
|
||||
self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
|
||||
@ -527,8 +521,7 @@ class FakeTensorConverterTest(TestCase):
|
||||
self.assertEqual(len(converter.meta_converter.tensor_memo), 0)
|
||||
|
||||
def test_no_active_mode(self):
|
||||
mode = FakeTensorMode()
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with FakeTensorMode() as mode:
|
||||
x = torch.empty(2, 2, device="cpu")
|
||||
y = torch.empty(2, 2, device="cpu")
|
||||
|
||||
@ -538,9 +531,9 @@ class FakeTensorConverterTest(TestCase):
|
||||
self.assertEqual(out.device.type, "cpu")
|
||||
|
||||
def test_separate_mode_error(self):
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
x = torch.empty(2, 2, device="cpu")
|
||||
with enable_torch_dispatch_mode(FakeTensorMode()):
|
||||
with FakeTensorMode():
|
||||
y = torch.empty(2, 2, device="cpu")
|
||||
self.assertRaises(Exception, lambda: x, y)
|
||||
|
||||
|
@ -63,7 +63,6 @@ from torch._subclasses.fake_tensor import (
|
||||
)
|
||||
from torch._subclasses.fake_utils import outputs_alias_inputs
|
||||
|
||||
from torch.utils._python_dispatch import enable_torch_dispatch_mode
|
||||
import torch._prims as prims
|
||||
from torch._prims.context import TorchRefsMode
|
||||
|
||||
@ -173,8 +172,7 @@ class TestCommon(TestCase):
|
||||
@onlyNativeDeviceTypes
|
||||
@ops(python_ref_db)
|
||||
def test_python_ref_meta(self, device, dtype, op):
|
||||
mode = FakeTensorMode()
|
||||
with mode:
|
||||
with FakeTensorMode() as mode:
|
||||
pass
|
||||
|
||||
def _to_tensormeta(x):
|
||||
@ -189,7 +187,7 @@ class TestCommon(TestCase):
|
||||
|
||||
meta_sample = sample.transform(_to_tensormeta)
|
||||
try:
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with mode:
|
||||
meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
|
||||
except torch._subclasses.fake_tensor.UnsupportedFakeTensorException:
|
||||
continue
|
||||
@ -1864,7 +1862,7 @@ class TestFakeTensor(TestCase):
|
||||
continue
|
||||
|
||||
with context():
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with mode:
|
||||
res_fake = op(input, *args, **kwargs)
|
||||
|
||||
|
||||
|
@ -16,9 +16,11 @@ from torch.overrides import (
|
||||
get_overridable_functions,
|
||||
get_testing_overrides,
|
||||
is_tensor_method_or_property,
|
||||
TorchFunctionMode
|
||||
TorchFunctionMode,
|
||||
_get_current_function_mode,
|
||||
_get_current_function_mode_stack,
|
||||
)
|
||||
from torch.utils._mode_utils import find_outermost_mode, all_same_mode, all_same_mode_scope
|
||||
from torch.utils._mode_utils import all_same_mode
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
Tensor = torch.Tensor
|
||||
@ -1150,11 +1152,6 @@ class TestTorchFunctionMode(TestCase):
|
||||
self.assertEqual(torch._sparse_csr_tensor_unsafe(1, 1, 1, (1, 1)), -1)
|
||||
self.assertEqual(torch.as_tensor([1]), -1)
|
||||
|
||||
def test_enable_torch_function_mode_with_tensor_subclass(self):
|
||||
x = torch.randn(1)
|
||||
with torch.overrides.enable_torch_function_mode(SubTensor):
|
||||
self.assertEqual(torch.mm(x, x), -1)
|
||||
|
||||
def test_modes_handle_first(self):
|
||||
class A(TorchFunctionMode):
|
||||
def __torch_function__(self, *args, **kwargs):
|
||||
@ -1178,54 +1175,9 @@ class TestTorchFunctionMode(TestCase):
|
||||
self.assertEqual(torch.mm(x, x), -1)
|
||||
self.assertEqual(bar(x), 1)
|
||||
self.assertRaisesRegex(
|
||||
TypeError, r'SubTensor.+MyMode',
|
||||
TypeError, r'SubTensor.+TorchFunctionStackMode',
|
||||
lambda: self.assertEqual(torch.max(x, x)))
|
||||
|
||||
def test_enable_torch_function_mode_trivial(self):
|
||||
class A(TorchFunctionMode):
|
||||
def __torch_function__(self, *args, **kwargs):
|
||||
return -40
|
||||
a = A()
|
||||
with torch.overrides.enable_torch_function_mode(a):
|
||||
with torch.overrides.enable_torch_function_mode(a):
|
||||
self.assertEqual(bar(None), -40)
|
||||
|
||||
def test_enable_torch_function_mode_replace(self):
|
||||
class A(TorchFunctionMode):
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
def __torch_function__(self, *args, **kwargs):
|
||||
return self.val
|
||||
a1 = A(-40)
|
||||
a2 = A(-41)
|
||||
with torch.overrides.enable_torch_function_mode(a1):
|
||||
with torch.overrides.enable_torch_function_mode(a2, replace=a1):
|
||||
self.assertEqual(bar(None), -41)
|
||||
|
||||
def test_enable_torch_function_mode_ignore_preexisting(self):
|
||||
class A(TorchFunctionMode):
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
||||
def __torch_function__(self, *args, **kwargs):
|
||||
return self.val
|
||||
a1 = A(-40)
|
||||
a2 = A(-41)
|
||||
with torch.overrides.enable_torch_function_mode(a1):
|
||||
with torch.overrides.enable_torch_function_mode(a2, ignore_preexisting=True):
|
||||
self.assertEqual(bar(None), -41)
|
||||
|
||||
def test_ctor_no_inner(self):
|
||||
class A(TorchFunctionMode):
|
||||
def __torch_function__(self, *args, **kwargs):
|
||||
return torch.zeros([])
|
||||
|
||||
with torch.overrides.enable_torch_function_mode(A()):
|
||||
x = torch.randn((3, 4))
|
||||
|
||||
self.assertEqual(x, torch.zeros([]))
|
||||
|
||||
def test_with_mode(self):
|
||||
class ErrorA(RuntimeError):
|
||||
pass
|
||||
@ -1270,15 +1222,24 @@ class TestTorchFunctionMode(TestCase):
|
||||
|
||||
self.assertEqual(out, ["layer2", "layer1"])
|
||||
|
||||
def test_error_using_same_mode(self):
|
||||
class A(TorchFunctionMode):
|
||||
pass
|
||||
def test_nested_same_mode(self):
|
||||
out = []
|
||||
|
||||
x = A()
|
||||
with x:
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
|
||||
with x:
|
||||
pass
|
||||
class A(TorchFunctionMode):
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
|
||||
def __torch_function__(self, func, _, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
out.append(self.msg)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
with A("layer1") as a:
|
||||
with a:
|
||||
torch.empty([])
|
||||
|
||||
self.assertEqual(out, ["layer1", "layer1"])
|
||||
|
||||
def test_error_using_class_method_on_mode(self):
|
||||
class A(TorchFunctionMode):
|
||||
@ -1291,86 +1252,43 @@ class TestTorchFunctionMode(TestCase):
|
||||
with A():
|
||||
x + x
|
||||
|
||||
def test_error_with_ancestor(self):
|
||||
def test_restacking_with_ancestor(self):
|
||||
class A(TorchFunctionMode):
|
||||
pass
|
||||
|
||||
with A() as x:
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
|
||||
with x:
|
||||
pass
|
||||
|
||||
def test_restore_errors(self):
|
||||
class A(TorchFunctionMode):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "does not have any ancestors. Use the standard version instead"):
|
||||
with A().restore():
|
||||
pass
|
||||
|
||||
x = A()
|
||||
with A():
|
||||
with x:
|
||||
with A() as x:
|
||||
pass
|
||||
|
||||
with A(): # a different mode instance than the one above
|
||||
with self.assertRaisesRegex(RuntimeError, "the current mode is not its ancestor"):
|
||||
with x.restore():
|
||||
pass
|
||||
|
||||
|
||||
def test_restore_ancestor_mode(self):
|
||||
class A(TorchFunctionMode):
|
||||
pass
|
||||
|
||||
x = A()
|
||||
y = A()
|
||||
with x:
|
||||
with y:
|
||||
pass
|
||||
pass
|
||||
|
||||
z = A()
|
||||
with y.restore():
|
||||
with z:
|
||||
pass
|
||||
|
||||
with x.restore():
|
||||
with z.restore():
|
||||
pass
|
||||
|
||||
def test_find_outermost_mode(self):
|
||||
def test_get_cur_mode(self):
|
||||
class A(TorchFunctionMode):
|
||||
pass
|
||||
|
||||
self.assertIsNone(find_outermost_mode([None, None]))
|
||||
|
||||
x = A()
|
||||
y = A()
|
||||
with x:
|
||||
with y:
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
pass
|
||||
|
||||
self.assertEqual(find_outermost_mode([x, y]), y)
|
||||
with A() as mode1:
|
||||
self.assertEqual(_get_current_function_mode(), mode1)
|
||||
|
||||
z = A()
|
||||
with y.restore():
|
||||
with z:
|
||||
with mode1:
|
||||
with A() as mode2:
|
||||
self.assertEqual(_get_current_function_mode(), mode2)
|
||||
|
||||
|
||||
def test_get_mode_stack(self):
|
||||
class A(TorchFunctionMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
pass
|
||||
|
||||
self.assertEqual(find_outermost_mode([z, x]), z)
|
||||
i = A()
|
||||
self.assertEqual(_get_current_function_mode_stack(), [])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "doesn't have ancestors set so the ordering with other modes"):
|
||||
find_outermost_mode([i, x, y, z])
|
||||
with A() as mode1:
|
||||
self.assertEqual(_get_current_function_mode_stack(), [mode1])
|
||||
|
||||
k = A()
|
||||
with k:
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "don't come from the same scope"):
|
||||
find_outermost_mode([k, x, y, z])
|
||||
with mode1:
|
||||
with A() as mode2:
|
||||
self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])
|
||||
|
||||
def test_all_same_mode(self):
|
||||
class A(TorchFunctionMode):
|
||||
@ -1382,32 +1300,6 @@ class TestTorchFunctionMode(TestCase):
|
||||
self.assertFalse(all_same_mode([x, None]))
|
||||
self.assertFalse(all_same_mode([x, y]))
|
||||
|
||||
def test_all_same_mode_scope(self):
|
||||
class A(TorchFunctionMode):
|
||||
pass
|
||||
|
||||
x = A()
|
||||
y = A()
|
||||
z = A()
|
||||
with x:
|
||||
with y:
|
||||
pass
|
||||
|
||||
with x.restore():
|
||||
with z:
|
||||
pass
|
||||
|
||||
i = A()
|
||||
|
||||
self.assertTrue(all_same_mode_scope([x, y], y))
|
||||
self.assertTrue(all_same_mode_scope([x, z], z))
|
||||
self.assertFalse(all_same_mode_scope([x, y, z], y))
|
||||
self.assertFalse(all_same_mode_scope([x, y, z], z))
|
||||
self.assertFalse(all_same_mode_scope([x, y, i], y))
|
||||
|
||||
no_ancestor = A()
|
||||
self.assertFalse(all_same_mode_scope([x, y, z], no_ancestor))
|
||||
|
||||
def test_reentrant_mode_idiom(self):
|
||||
log = []
|
||||
|
||||
@ -1417,7 +1309,7 @@ class TestTorchFunctionMode(TestCase):
|
||||
kwargs = {}
|
||||
log.append(func)
|
||||
if func is torch.sub:
|
||||
with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
|
||||
with self:
|
||||
input, other = args
|
||||
assert not kwargs
|
||||
return torch.add(input, other, alpha=-1)
|
||||
|
@ -742,7 +742,7 @@ class TestDecomp(TestCase):
|
||||
op = torch._decomp.decomposition_table.get(torch.ops.aten.leaky_relu_backward.default)
|
||||
|
||||
def fn0(*arg):
|
||||
return _is_func_unsupported_nvfuser(mode, op, arg, {})
|
||||
return _is_func_unsupported_nvfuser(TorchRefsNvfuserCapabilityMode(), op, arg, {})
|
||||
|
||||
def fn1(x):
|
||||
x = x * 2
|
||||
@ -750,8 +750,8 @@ class TestDecomp(TestCase):
|
||||
x = x * 2
|
||||
return x
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode() as mode:
|
||||
self.assertFalse(fn0(x, y, 0.3, False))
|
||||
self.assertFalse(fn0(x, y, 0.3, False))
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
|
||||
# Autocast context has C++ level ATen calls that are hidden from
|
||||
# TorchRefsNvfuserCapabilityMode that works only on Python level.
|
||||
|
@ -7,11 +7,11 @@ from torch.library import Library
|
||||
from torch.cuda.jiterator import _create_jit_fn
|
||||
import unittest
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, IS_WINDOWS
|
||||
from torch.utils._mode_utils import no_dispatch, find_outermost_mode, all_same_mode, all_same_mode_scope
|
||||
from torch.utils._mode_utils import no_dispatch, all_same_mode
|
||||
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
|
||||
log_input, capture_logs, capture_logs_with_logging_tensor_mode
|
||||
from torch.utils._pytree import tree_map, tree_map_only
|
||||
from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
|
||||
from torch.utils._python_dispatch import TorchDispatchMode, _get_current_dispatch_mode, _get_current_dispatch_mode_stack
|
||||
|
||||
import logging
|
||||
|
||||
@ -730,34 +730,22 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
|
||||
res = x.index_put_(idxs, v)
|
||||
self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])
|
||||
|
||||
def test_enable_torch_dispatch_mode_error(self) -> None:
|
||||
z = LoggingTensor(torch.empty([]))
|
||||
with self.assertRaisesRegex(ValueError, "expected to get TorchDispatchMode, Tensor-like class, or None"):
|
||||
with enable_torch_dispatch_mode(z):
|
||||
pass
|
||||
|
||||
def test_enable_torch_dispatch_mode_basic(self) -> None:
|
||||
def test_torch_dispatch_mode_basic(self) -> None:
|
||||
with capture_logs(is_mode=True) as logs:
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode()):
|
||||
with LoggingTensorMode():
|
||||
torch.empty([])
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
|
||||
|
||||
def test_enable_torch_dispatch_mode_unrelated_tensors(self) -> None:
|
||||
def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
with capture_logs(is_mode=True) as logs:
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode()):
|
||||
with LoggingTensorMode():
|
||||
x + y
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$2 = torch._ops.aten.add.Tensor($0, $1)""")
|
||||
|
||||
def test_nested_push_regular(self):
|
||||
with LoggingTensorMode.push() as mode:
|
||||
# This previously errored
|
||||
with LoggingTensorMode():
|
||||
pass
|
||||
|
||||
def test_nested_push_logging_tensor_mode(self):
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
@ -799,7 +787,7 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
|
||||
self.assertEqual(logs1, logs2)
|
||||
|
||||
def test_enable_torch_dispatch_mode_subclass_priority(self) -> None:
|
||||
def test_torch_dispatch_mode_subclass_priority(self) -> None:
|
||||
class ErrorA(RuntimeError):
|
||||
pass
|
||||
|
||||
@ -813,7 +801,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
raise ErrorA
|
||||
with AMode():
|
||||
raise ErrorA
|
||||
|
||||
class B(A):
|
||||
@staticmethod
|
||||
@ -822,6 +811,15 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
with BMode():
|
||||
func(*args, **kwargs)
|
||||
|
||||
class AMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
raise ErrorA
|
||||
|
||||
class BMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
raise ErrorB
|
||||
|
||||
a = A(torch.empty(1))
|
||||
@ -834,38 +832,41 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
# B has precedence over A due to the subclass relationship yet
|
||||
# modes take precedence over arguments
|
||||
with self.assertRaises(ErrorA):
|
||||
with enable_torch_dispatch_mode(A):
|
||||
with AMode():
|
||||
b + b
|
||||
with self.assertRaises(ErrorB):
|
||||
with enable_torch_dispatch_mode(B):
|
||||
with BMode():
|
||||
a + a
|
||||
with self.assertRaises(ErrorB):
|
||||
with enable_torch_dispatch_mode(B):
|
||||
with BMode():
|
||||
a + b
|
||||
|
||||
def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None:
|
||||
def test_mode_with_make_subclass(self):
|
||||
class SubTensor(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||
|
||||
class BasicMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
x = torch.randn(3)
|
||||
with BasicMode():
|
||||
y = SubTensor(x)
|
||||
self.assertIsInstance(y, SubTensor)
|
||||
|
||||
def test_torch_dispatch_mode_respects_no_dispatch(self) -> None:
|
||||
with capture_logs(is_mode=True) as logs1:
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode()):
|
||||
with LoggingTensorMode():
|
||||
torch.ones([2, 3])
|
||||
with no_dispatch():
|
||||
torch.ones([2, 3])
|
||||
with capture_logs(is_mode=True) as logs2:
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode()):
|
||||
with LoggingTensorMode():
|
||||
torch.ones([2, 3])
|
||||
self.assertEqual(logs1, logs2)
|
||||
|
||||
def test_enable_torch_dispatch_mode_instance(self) -> None:
|
||||
class TestMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return func(*args, **kwargs)
|
||||
|
||||
x = TestMode()
|
||||
y = torch.tensor([2.])
|
||||
with enable_torch_dispatch_mode(x):
|
||||
y + y
|
||||
|
||||
def test_shallow_copy_and_detach(self) -> None:
|
||||
seen = set()
|
||||
test_case = self
|
||||
@ -884,103 +885,25 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
loss = (x * x).sum()
|
||||
loss.backward()
|
||||
|
||||
def test_nested_enable_torch_dispatch_mode(self) -> None:
|
||||
class A(LoggingTensorMode):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "there is already an active mode"):
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode()):
|
||||
with enable_torch_dispatch_mode(A()):
|
||||
pass
|
||||
|
||||
# For nesting to be a noop, they need to be the same instance
|
||||
with self.assertRaisesRegex(ValueError, "there is already an active mode"):
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode()):
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode()):
|
||||
pass
|
||||
|
||||
def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
|
||||
# "nested" enable_torch_dispatch_modes are allowed if they're the same mode (same instance).
|
||||
# It's the equivalent of a noop, so it will only write once to the log
|
||||
x = torch.tensor([3.])
|
||||
mode = LoggingTensorMode()
|
||||
with capture_logs(is_mode=True) as logs:
|
||||
log_input("x", x)
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
with enable_torch_dispatch_mode(mode):
|
||||
x + x
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = torch._ops.aten.add.Tensor($0, $0)''')
|
||||
|
||||
def test_enable_torch_dispatch_mode_ignore_preexisting(self):
|
||||
class A(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
raise AssertionError
|
||||
|
||||
x = torch.tensor([3.])
|
||||
with capture_logs(is_mode=True) as logs:
|
||||
with enable_torch_dispatch_mode(A()):
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode(), ignore_preexisting=True):
|
||||
x + x
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
|
||||
def test_enable_torch_dispatch_mode_replace(self):
|
||||
class A(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
raise AssertionError
|
||||
|
||||
x = torch.tensor([3.])
|
||||
outer_mode = A()
|
||||
with capture_logs(is_mode=True) as logs:
|
||||
with enable_torch_dispatch_mode(outer_mode):
|
||||
with enable_torch_dispatch_mode(LoggingTensorMode(), replace=outer_mode):
|
||||
x + x
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
|
||||
def test_exception_handling(self):
|
||||
class A(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
class AMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
if func.__name__ == 'randn.default':
|
||||
raise RuntimeError()
|
||||
return cls(torch.zeros(()))
|
||||
return A(torch.zeros(()))
|
||||
|
||||
with enable_torch_dispatch_mode(A):
|
||||
with AMode():
|
||||
try:
|
||||
torch.randn(())
|
||||
except RuntimeError:
|
||||
pass
|
||||
self.assertTrue(isinstance(torch.zeros(()), A))
|
||||
|
||||
def test_ctor_no_inner(self):
|
||||
class A(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
return torch.zeros([])
|
||||
|
||||
with enable_torch_dispatch_mode(A()):
|
||||
x = torch.randn((3, 4))
|
||||
|
||||
self.assertEqual(x, torch.zeros([]))
|
||||
|
||||
def test_with_mode(self):
|
||||
class ErrorA(RuntimeError):
|
||||
pass
|
||||
|
||||
class A(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
raise ErrorA()
|
||||
|
||||
with self.assertRaises(ErrorA):
|
||||
with A():
|
||||
torch.empty([])
|
||||
|
||||
def test_with_mode_created_separately(self):
|
||||
class ErrorA(RuntimeError):
|
||||
pass
|
||||
@ -1021,10 +944,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
modes = (arg.mode for arg in args + tuple(kwargs.values()) if isinstance(arg, ModeTensor))
|
||||
outermost = find_outermost_mode(modes)
|
||||
with outermost.restore():
|
||||
return func(*args, **kwargs)
|
||||
raise NotImplementedError("Shouldn't be here")
|
||||
|
||||
class Mode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
@ -1111,70 +1031,17 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
with PoliteMode():
|
||||
a.abs()
|
||||
|
||||
def test_disable_mode(self):
|
||||
class FailEverythingMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
raise RuntimeError("arf")
|
||||
def test_nesting_same_mode(self):
|
||||
# If the pushed mode is the same instance as the current mode, we allow pushing an already active mode.
|
||||
|
||||
with FailEverythingMode() as m:
|
||||
self.assertRaises(RuntimeError, lambda: torch.ones([2, 3]))
|
||||
with enable_torch_dispatch_mode(None, replace=m):
|
||||
torch.ones([2, 3])
|
||||
with capture_logs(is_mode=True) as logs:
|
||||
with LoggingTensorMode() as reenabled:
|
||||
with reenabled:
|
||||
torch.empty([])
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
|
||||
|
||||
def test_make_wrapper_subclass_with_modes(self):
|
||||
class ModeTensor(torch.Tensor):
|
||||
def __new__(cls, elem, mode):
|
||||
r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
|
||||
r.elem = elem
|
||||
r.mode = mode
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
modes = (arg.mode for arg in args + tuple(kwargs.values()) if isinstance(arg, ModeTensor))
|
||||
outermost = find_outermost_mode(modes)
|
||||
with outermost.restore():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
class Mode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
def unwrap(e):
|
||||
if isinstance(e, ModeTensor):
|
||||
return e.elem
|
||||
else:
|
||||
return e
|
||||
|
||||
def wrap(t):
|
||||
if isinstance(t, torch.Tensor):
|
||||
return ModeTensor(t, self)
|
||||
else:
|
||||
return t
|
||||
|
||||
return wrap(func(*tuple(unwrap(a) for a in args), **kwargs))
|
||||
|
||||
x = torch.tensor(4.)
|
||||
with Mode():
|
||||
y = x + x
|
||||
z = y + y
|
||||
self.assertIsInstance(y, ModeTensor)
|
||||
self.assertIsInstance(z, ModeTensor)
|
||||
|
||||
with Mode():
|
||||
with Mode():
|
||||
y = x + x
|
||||
z = y + y
|
||||
self.assertIsInstance(y, ModeTensor)
|
||||
self.assertIsInstance(z, ModeTensor)
|
||||
|
||||
def test_error_using_same_mode(self):
|
||||
class A(TorchDispatchMode):
|
||||
pass
|
||||
|
||||
x = A()
|
||||
with x:
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
|
||||
with x:
|
||||
pass
|
||||
|
||||
def test_error_using_class_method_on_mode(self):
|
||||
class A(TorchDispatchMode):
|
||||
@ -1187,74 +1054,33 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
with A():
|
||||
x + x
|
||||
|
||||
def test_error_with_ancestor(self):
|
||||
x = LoggingTensorMode()
|
||||
with x:
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
|
||||
with x:
|
||||
def test_get_cur_mode(self):
|
||||
class A(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
pass
|
||||
|
||||
def test_restore_errors(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "does not have any ancestors. Use the standard version instead"):
|
||||
with LoggingTensorMode().restore():
|
||||
self.assertEqual(_get_current_dispatch_mode(), None)
|
||||
|
||||
with A() as mode1:
|
||||
self.assertEqual(_get_current_dispatch_mode(), mode1)
|
||||
|
||||
with mode1:
|
||||
with A() as mode2:
|
||||
self.assertEqual(_get_current_dispatch_mode(), mode2)
|
||||
|
||||
def test_get_mode_stack(self):
|
||||
class A(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
pass
|
||||
|
||||
x = LoggingTensorMode()
|
||||
with LoggingTensorMode():
|
||||
with x:
|
||||
pass
|
||||
self.assertEqual(_get_current_dispatch_mode_stack(), [])
|
||||
|
||||
with LoggingTensorMode(): # a different mode instance than the one above
|
||||
with self.assertRaisesRegex(RuntimeError, "the current mode is not its ancestor"):
|
||||
with x.restore():
|
||||
pass
|
||||
with A() as mode1:
|
||||
self.assertEqual(_get_current_dispatch_mode_stack(), [mode1])
|
||||
|
||||
def test_restore_ancestor_mode(self):
|
||||
x = LoggingTensorMode()
|
||||
y = LoggingTensorMode()
|
||||
with x:
|
||||
with y:
|
||||
pass
|
||||
|
||||
z = LoggingTensorMode()
|
||||
with y.restore():
|
||||
with z:
|
||||
pass
|
||||
|
||||
with x.restore():
|
||||
with z.restore():
|
||||
pass
|
||||
|
||||
def test_find_outermost_mode(self):
|
||||
self.assertIsNone(find_outermost_mode([None, None]))
|
||||
|
||||
x = LoggingTensorMode()
|
||||
y = LoggingTensorMode()
|
||||
with x:
|
||||
with y:
|
||||
pass
|
||||
|
||||
self.assertEqual(find_outermost_mode([x, y]), y)
|
||||
|
||||
z = LoggingTensorMode()
|
||||
with y.restore():
|
||||
with z:
|
||||
pass
|
||||
|
||||
self.assertEqual(find_outermost_mode([z, x]), z)
|
||||
i = LoggingTensorMode()
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "doesn't have ancestors set so the ordering with other modes"):
|
||||
find_outermost_mode([i, x, y, z])
|
||||
|
||||
k = LoggingTensorMode()
|
||||
with k:
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "don't come from the same scope"):
|
||||
find_outermost_mode([k, x, y, z])
|
||||
with mode1:
|
||||
with A() as mode2:
|
||||
self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2])
|
||||
|
||||
def test_all_same_mode(self):
|
||||
x = LoggingTensorMode()
|
||||
@ -1263,29 +1089,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
self.assertFalse(all_same_mode([x, None]))
|
||||
self.assertFalse(all_same_mode([x, y]))
|
||||
|
||||
def test_all_same_mode_scope(self):
|
||||
x = LoggingTensorMode()
|
||||
y = LoggingTensorMode()
|
||||
z = LoggingTensorMode()
|
||||
with x:
|
||||
with y:
|
||||
pass
|
||||
|
||||
with x.restore():
|
||||
with z:
|
||||
pass
|
||||
|
||||
i = LoggingTensorMode()
|
||||
|
||||
self.assertTrue(all_same_mode_scope([x, y], y))
|
||||
self.assertTrue(all_same_mode_scope([x, z], z))
|
||||
self.assertFalse(all_same_mode_scope([x, y, z], y))
|
||||
self.assertFalse(all_same_mode_scope([x, y, z], z))
|
||||
self.assertFalse(all_same_mode_scope([x, y, i], y))
|
||||
|
||||
no_ancestor = LoggingTensorMode()
|
||||
self.assertFalse(all_same_mode_scope([x, y, z], no_ancestor))
|
||||
|
||||
def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
|
||||
x = LoggingTensor(torch.tensor([2.0, 3.0]))
|
||||
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
|
||||
@ -1295,7 +1098,7 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertEqual(x, None)
|
||||
|
||||
def test_enable_torch_dispatch_mode_subclass_autograd_device_check(self) -> None:
|
||||
def test_subclass_autograd_device_check(self) -> None:
|
||||
class NonWrapperSubclass(torch.Tensor):
|
||||
elem: torch.Tensor
|
||||
|
||||
|
@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.fx.operator_schemas import normalize_function
|
||||
from torch.testing._internal.schema_check_mode import SchemaCheckMode
|
||||
from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
|
||||
@ -72,16 +72,14 @@ class IncorrectAliasTensor(torch.Tensor):
|
||||
class TestSchemaCheck(JitTestCase):
|
||||
# Tests that SchemaCheckMode records operator order with grad
|
||||
def test_schema_check_mode_operator_order(self):
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
x.relu().sin()
|
||||
self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)
|
||||
|
||||
# Tests that SchemaCheckMode records operator order without grad
|
||||
def test_schema_check_mode_operator_order_without_grad(self):
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
x = torch.rand((3, 3), requires_grad=False)
|
||||
x.relu().sin()
|
||||
self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)
|
||||
@ -91,8 +89,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
# NB: previously requires_grad=True, but this induces a detach for
|
||||
# saved variable
|
||||
x = torch.rand((3, 3))
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
actual = x.relu().sin()
|
||||
self.assertEqual([], schema_check.mutated)
|
||||
self.assertEqual([], schema_check.aliasing)
|
||||
@ -100,8 +97,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
# Tests that SchemaCheckMode records mutations and aliases with mutation expected
|
||||
def test_schema_check_mode_mutated_aliasing_mutation(self):
|
||||
actual = torch.rand((3, 3), requires_grad=False)
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
actual.sinh_()
|
||||
self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
|
||||
self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)
|
||||
@ -109,8 +105,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
# Tests that SchemaCheckMode records mutations and aliases with resize_
|
||||
def test_schema_check_mode_mutated_aliasing_resize_(self):
|
||||
actual = torch.rand((3, 3), requires_grad=False)
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
actual.resize_(9)
|
||||
self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
|
||||
self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)
|
||||
@ -119,8 +114,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
|
||||
actual = torch.rand((3, 3))
|
||||
y = actual
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
actual.add_(y)
|
||||
self.assertEqual(
|
||||
[
|
||||
@ -140,8 +134,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
# Tests that SchemaCheckMode records mutations and alias with as_strided
|
||||
def test_schema_check_mode_mutated_aliasing_as_strided(self):
|
||||
x = torch.rand((3, 6, 4))
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
x.as_strided_([3, 6, 4], [9, 1, 1])
|
||||
self.assertEqual(
|
||||
[
|
||||
@ -161,8 +154,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
x = torch.arange(9.)
|
||||
m_actual = torch.arange(9.)
|
||||
e_actual = torch.zeros([9], dtype=torch.int32)
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
torch.frexp(x, out=(m_actual, e_actual))
|
||||
self.assertEqual(
|
||||
[
|
||||
@ -183,8 +175,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
|
||||
x = torch.rand((3, 3))
|
||||
actual = torch.zeros(3)
|
||||
schema_check = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(schema_check):
|
||||
with SchemaCheckMode() as schema_check:
|
||||
torch.aminmax(x, dim=0, out=[actual, actual])
|
||||
self.assertEqual(
|
||||
[
|
||||
@ -207,7 +198,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
def test_schema_check_mode_functionality(self):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
expected = x.relu().sin()
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual = x.relu().sin()
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -215,7 +206,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
def test_schema_check_mode_functionality_default_replaced(self):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
expected = x.add(x, alpha=2)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual = x.add(x, alpha=2)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -225,7 +216,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
b = torch.rand((3, 3))
|
||||
c = torch.rand((3, 3))
|
||||
expected = torch.linalg.multi_dot([a, b, c])
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual = torch.linalg.multi_dot([a, b, c])
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -233,7 +224,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
def test_schema_check_mode_functionality_wildcard_after(self):
|
||||
x = torch.rand((3, 3))
|
||||
expected = x.chunk(6)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual = x.chunk(6)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -242,7 +233,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
x = torch.rand((3, 5))
|
||||
w = torch.rand((4))
|
||||
expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -251,7 +242,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
expected = torch.rand((3, 3), requires_grad=False)
|
||||
actual = torch.clone(expected)
|
||||
expected.sinh_()
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual.sinh_()
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -262,7 +253,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
actual = torch.clone(expected)
|
||||
y = actual
|
||||
expected.add_(x)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual.add_(y)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -272,7 +263,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
m_expected, e_expected = torch.frexp(x)
|
||||
m_actual = torch.arange(9.)
|
||||
e_actual = torch.zeros([9], dtype=torch.int32)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
torch.frexp(x, out=(m_actual, e_actual))
|
||||
self.assertEqual(m_expected, m_actual)
|
||||
self.assertEqual(e_expected, e_actual)
|
||||
@ -281,13 +272,13 @@ class TestSchemaCheck(JitTestCase):
|
||||
def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
|
||||
x = torch.rand((3, 3))
|
||||
actual = torch.zeros(3)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
torch.aminmax(x, dim=0, out=[actual, actual])
|
||||
self.assertEqual(torch.amax(x, dim=0), actual)
|
||||
|
||||
# Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input
|
||||
def test_schema_check_mode_functionality_device_input(self):
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
x = torch.rand((3, 3), device="cpu", dtype=torch.double)
|
||||
y = x + x
|
||||
self.assertEqual(x + x, y)
|
||||
@ -297,7 +288,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
|
||||
expected = batch(x)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual = batch(x)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -311,7 +302,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
expected.relu_()
|
||||
expected = batch(expected)
|
||||
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual.sinh_()
|
||||
actual.tanh_()
|
||||
actual.relu_()
|
||||
@ -321,7 +312,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
# Tests that SchemaCheckMode wraps Torch.tensor with empty list input
|
||||
def test_schema_check_mode_empty_list_input(self):
|
||||
expected = torch.atleast_1d([])
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
actual = torch.atleast_1d([])
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@ -330,7 +321,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
|
||||
x = torch.rand((3, 3))
|
||||
y = torch.rand((3, 3))
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))
|
||||
|
||||
# # Tests that an exception is raised for a mismatching mutation over multiple ops
|
||||
@ -338,7 +329,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
|
||||
x = torch.rand((3, 3))
|
||||
y = torch.rand((3, 3))
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))
|
||||
|
||||
# Tests that an exception is raised for a mismatching alias
|
||||
@ -346,7 +337,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
y = torch.rand((3, 3))
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)
|
||||
|
||||
# Tests that an exception is raised for a mismatching alias over multiple ops
|
||||
@ -354,7 +345,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
y = torch.zeros((3, 3), requires_grad=True)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)
|
||||
|
||||
# Tests that an exception is raised for a centered mismatching alias over multiple ops
|
||||
@ -362,15 +353,14 @@ class TestSchemaCheck(JitTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
y = torch.zeros((3, 3), requires_grad=True)
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()
|
||||
|
||||
# Tests that an exception is raised for a centered mismatching alias over multiple ops
|
||||
def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
|
||||
x = torch.rand((3, 3))
|
||||
s = SchemaCheckMode()
|
||||
with enable_torch_dispatch_mode(s):
|
||||
with SchemaCheckMode() as s:
|
||||
IncorrectAliasTensor(x).aminmax(dim=0)
|
||||
|
||||
# Tests that is_alias_of returns as expected
|
||||
@ -439,8 +429,7 @@ class TestSchemaCheck(JitTestCase):
|
||||
|
||||
return func(*args, **kwargs)
|
||||
x = torch.rand((3, 3))
|
||||
schemaInfoCheck = SchemaInfoBindTestMode(self)
|
||||
with enable_torch_dispatch_mode(schemaInfoCheck):
|
||||
with SchemaInfoBindTestMode(self) as schemaInfoCheck:
|
||||
x.add(x)
|
||||
|
||||
|
||||
@ -452,7 +441,7 @@ class TestSchemaCheckModeOpInfo(JitTestCase):
|
||||
if (dtype == torch.complex32):
|
||||
return
|
||||
for sample in op.sample_inputs(device, dtype, requires_grad=False):
|
||||
with enable_torch_dispatch_mode(SchemaCheckMode()):
|
||||
with SchemaCheckMode():
|
||||
op(sample.input, *sample.args, **sample.kwargs)
|
||||
|
||||
instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))
|
||||
|
@ -918,11 +918,18 @@ def __is_forward_AD_enabled() -> _bool: ...
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_default_hooks() -> None: ...

def _is_torch_function_mode_enabled()-> _bool: ...
def _set_torch_function_mode(cls: Any) -> None: ...
def _get_torch_function_mode() -> Any: ...
def _push_on_torch_function_stack(cls: Any) -> None: ...
def _pop_torch_function_stack() -> Any: ...
def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...

def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _get_torch_dispatch_mode() -> Any: ...
def _push_on_torch_dispatch_stack(cls: Any) -> None: ...
def _pop_torch_dispatch_stack() -> Any: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...

class _InferenceMode(object):
    def __init__(self, mode: _bool) -> None: ...
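These stubs are the whole C++ surface the Python-side mode stack needs. For orientation, the dispatch-side helpers used later in this patch can be sketched on top of them; this is an illustrative reconstruction, not a copy of the real torch/utils/_python_dispatch.py body, which may differ slightly:

```python
# Illustrative sketch only: how _get_current_dispatch_mode and
# _get_current_dispatch_mode_stack are assumed to wrap the new bindings.
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at

def _get_current_dispatch_mode():
    # top of the Python-side stack, or None when no mode is active
    stack_len = _len_torch_dispatch_stack()
    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None

def _get_current_dispatch_mode_stack():
    # bottom-to-top copy of the whole stack
    return [_get_dispatch_stack_at(i) for i in range(_len_torch_dispatch_stack())]
```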
@ -12,6 +12,7 @@ import torch.jit
|
||||
from torch import _utils_internal
|
||||
|
||||
# Query `hasattr` only once.
|
||||
|
||||
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
|
||||
|
||||
|
||||
@ -144,9 +145,11 @@ class PyOperator(PyOperatorABC):
|
||||
return inner
|
||||
|
||||
def dispatch(self, dispatch_key, *args, **kwargs):
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
|
||||
if dispatch_key == torch._C.DispatchKey.Python:
|
||||
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
|
||||
curr_mode = type(torch._C._get_torch_dispatch_mode())
|
||||
curr_mode = type(_get_current_dispatch_mode())
|
||||
assert (
|
||||
curr_mode is not None
|
||||
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
|
||||
@ -314,9 +317,11 @@ class OpOverload(PyOperatorABC):
|
||||
return key
|
||||
|
||||
def handler(*args, **kwargs):
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
|
||||
# TODO: We also need to handle tensor subclasses here
|
||||
# TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
|
||||
curr_mode = type(torch._C._get_torch_dispatch_mode())
|
||||
curr_mode = type(_get_current_dispatch_mode())
|
||||
assert (
|
||||
curr_mode is not None
|
||||
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
|
||||
|
@ -185,7 +185,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
||||
if self.should_fallback_fn(self, func, args, kwargs):
|
||||
return orig_func(*args, **kwargs)
|
||||
# torch calls inside func should be interpreted as refs calls
|
||||
with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
if self.strict:
|
||||
raise RuntimeError(
|
||||
@ -202,9 +202,7 @@ def _is_node_supported_nvfuser(node):
|
||||
|
||||
|
||||
def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
|
||||
with torch.overrides.enable_torch_function_mode(
|
||||
torch_function_mode, replace=torch_function_mode.inner
|
||||
):
|
||||
with torch_function_mode:
|
||||
gm = get_isolated_graphmodule(func, args, kwargs)
|
||||
|
||||
supported_ops = NvfuserPrimOperatorSupport()
|
||||
|
@ -16,7 +16,7 @@ from torch.fx.operator_schemas import normalize_function
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.overrides import TorchFunctionMode
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
from torch.utils._pytree import PyTree, tree_flatten, tree_map
|
||||
|
||||
@ -564,7 +564,8 @@ class FakeTensor(torch.Tensor):
            else:
                assert fake_mode is arg.fake_mode, "Mixing modes NYI"

        with enable_torch_dispatch_mode(fake_mode):
        assert fake_mode is not None
        with fake_mode:  # type: ignore[attr-defined]
            return func(*args, **kwargs)

    @staticmethod
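For FakeTensor callers the change boils down to entering the mode directly; a minimal illustrative sketch of the post-PR pattern (not part of the diff):

```python
import torch
from torch._subclasses import FakeTensorMode

# A plain `with` replaces enable_torch_dispatch_mode(fake_mode)
with FakeTensorMode() as fake_mode:
    t = torch.empty(2, 3)  # factory call is intercepted; t is a FakeTensor
```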
@ -623,7 +624,7 @@ class FakeTensor(torch.Tensor):
|
||||
|
||||
|
||||
# We keep one instantiation of `fake_tensor_converter` active
|
||||
# for the duration of `with torch_enable_mode(FakeTensorMode)`.
|
||||
# for the duration of `with FakeTensorMode()`.
|
||||
# This allows accurate storage aliasing across invocation of
|
||||
# different operators. While this will keep all freshly allocated
|
||||
# tensors alive during `FakeTensorMode`, there will no be no
|
||||
@ -779,7 +780,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
# We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1`
|
||||
return None
|
||||
|
||||
with self.restore():
|
||||
with self:
|
||||
if func in meta_table:
|
||||
r = meta_table[func](*args, **kwargs)
|
||||
return r
|
||||
@ -801,7 +802,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
and len(flat_arg_fake_tensors) != 0
|
||||
and hasattr(func, "prim_meta_impl")
|
||||
):
|
||||
with self.restore():
|
||||
with self:
|
||||
return func.prim_meta_impl(*args, **kwargs)
|
||||
|
||||
if has_symbolic_sizes:
|
||||
|
@ -585,34 +585,6 @@ static PyObject* python_exit_dual_level(
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* set_torch_dispatch_mode(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (arg == Py_None) {
|
||||
c10::impl::TorchDispatchModeTLS::set_state(nullptr);
|
||||
} else {
|
||||
Py_INCREF(arg);
|
||||
c10::impl::TorchDispatchModeTLS::set_state(
|
||||
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* get_torch_dispatch_mode(
|
||||
PyObject* _unused,
|
||||
PyObject* _unused2) {
|
||||
HANDLE_TH_ERRORS
|
||||
const auto& mode = c10::impl::TorchDispatchModeTLS::get_state();
|
||||
if (!mode) {
|
||||
Py_RETURN_NONE;
|
||||
} else {
|
||||
auto* r = mode->ptr(getPyInterpreter());
|
||||
Py_INCREF(r);
|
||||
return r;
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* set_torch_function_mode(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (arg == Py_None) {
|
||||
@ -623,24 +595,136 @@ static PyObject* set_torch_function_mode(PyObject* _unused, PyObject* arg) {
|
||||
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
END_HANDLE_TH_ERRORS;
|
||||
}
|
||||
|
||||
static PyObject* get_torch_function_mode(
|
||||
static PyObject* is_torch_function_mode_enabled(
|
||||
PyObject* _unused,
|
||||
PyObject* _unused2) {
|
||||
HANDLE_TH_ERRORS
|
||||
const auto& mode = at::impl::PythonTorchFunctionTLS::get_mode();
|
||||
if (!mode) {
|
||||
Py_RETURN_NONE;
|
||||
if (at::impl::function_mode_enabled()) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
auto* r = mode->ptr(getPyInterpreter());
|
||||
Py_INCREF(r);
|
||||
return r;
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* push_on_torch_function_stack(
|
||||
PyObject* _unused,
|
||||
PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (arg != Py_None) {
|
||||
Py_INCREF(arg);
|
||||
at::impl::PythonTorchFunctionTLS::push_onto_stack(
|
||||
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* pop_torch_function_stack(
|
||||
PyObject* _unused,
|
||||
PyObject* _unused2) {
|
||||
HANDLE_TH_ERRORS
|
||||
const auto& mode = at::impl::PythonTorchFunctionTLS::pop_stack();
|
||||
auto* r = mode->ptr(getPyInterpreter());
|
||||
Py_INCREF(r);
|
||||
return r;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* get_function_stack_at(
|
||||
PyObject* _unused,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({"get_stack_at(int64_t level)"});
|
||||
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto _r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
auto idx = _r.toInt64(0);
|
||||
const auto& mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
|
||||
auto* r = mode->ptr(getPyInterpreter());
|
||||
Py_INCREF(r);
|
||||
return r;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* len_torch_function_stack(
|
||||
PyObject* _unused,
|
||||
PyObject* _unused2) {
|
||||
HANDLE_TH_ERRORS
|
||||
const auto len = at::impl::PythonTorchFunctionTLS::stack_len();
|
||||
return utils::wrap(static_cast<int64_t>(len));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* set_torch_dispatch_mode(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (arg == Py_None) {
|
||||
c10::impl::TorchDispatchModeTLS::set_mode(nullptr);
|
||||
} else {
|
||||
Py_INCREF(arg);
|
||||
c10::impl::TorchDispatchModeTLS::set_mode(
|
||||
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS;
|
||||
}
|
||||
|
||||
static PyObject* push_on_torch_dispatch_stack(
|
||||
PyObject* _unused,
|
||||
PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (arg != Py_None) {
|
||||
Py_INCREF(arg);
|
||||
c10::impl::TorchDispatchModeTLS::push_onto_stack(
|
||||
std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* pop_torch_dispatch_stack(
|
||||
PyObject* _unused,
|
||||
PyObject* _unused2) {
|
||||
HANDLE_TH_ERRORS
|
||||
const auto& mode = c10::impl::TorchDispatchModeTLS::pop_stack();
|
||||
auto* r = mode->ptr(getPyInterpreter());
|
||||
Py_INCREF(r);
|
||||
return r;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* get_dispatch_stack_at(
|
||||
PyObject* _unused,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({"get_stack_at(int64_t level)"});
|
||||
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto _r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
auto idx = _r.toInt64(0);
|
||||
const auto& mode = c10::impl::TorchDispatchModeTLS::get_stack_at(idx);
|
||||
auto* r = mode->ptr(getPyInterpreter());
|
||||
Py_INCREF(r);
|
||||
return r;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* len_torch_dispatch_stack(
|
||||
PyObject* _unused,
|
||||
PyObject* _unused2) {
|
||||
HANDLE_TH_ERRORS
|
||||
const auto len = c10::impl::TorchDispatchModeTLS::stack_len();
|
||||
return utils::wrap(static_cast<int64_t>(len));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// autograd methods on torch._C
|
||||
static PyMethodDef methods[] = { // NOLINT
|
||||
{"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
|
||||
@ -685,10 +769,44 @@ static PyMethodDef methods[] = { // NOLINT
|
||||
castPyCFunctionWithKeywords(python_exit_dual_level),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr},
|
||||
{"_get_torch_dispatch_mode", get_torch_dispatch_mode, METH_NOARGS, nullptr},
|
||||
{"_is_torch_function_mode_enabled",
|
||||
is_torch_function_mode_enabled,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_set_torch_function_mode", set_torch_function_mode, METH_O, nullptr},
|
||||
{"_get_torch_function_mode", get_torch_function_mode, METH_NOARGS, nullptr},
|
||||
{"_push_on_torch_function_stack",
|
||||
push_on_torch_function_stack,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_pop_torch_function_stack",
|
||||
pop_torch_function_stack,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_get_function_stack_at",
|
||||
castPyCFunctionWithKeywords(get_function_stack_at),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"_len_torch_function_stack",
|
||||
len_torch_function_stack,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_set_torch_dispatch_mode", set_torch_dispatch_mode, METH_O, nullptr},
|
||||
{"_push_on_torch_dispatch_stack",
|
||||
push_on_torch_dispatch_stack,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_pop_torch_dispatch_stack",
|
||||
pop_torch_dispatch_stack,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_get_dispatch_stack_at",
|
||||
castPyCFunctionWithKeywords(get_dispatch_stack_at),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"_len_torch_dispatch_stack",
|
||||
len_torch_dispatch_stack,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{nullptr, nullptr, 0, nullptr}};
|
||||
|
||||
PyMethodDef* python_functions() {
|
||||
|
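With these entries registered, the raw stacks are also reachable from torch._C for ad-hoc debugging; a small hedged example (not part of the diff):

```python
import torch

# Inspect the dispatch-mode stack directly through the new private bindings
n = torch._C._len_torch_dispatch_stack()
stack = [torch._C._get_dispatch_stack_at(i) for i in range(n)]
print(n, stack)
```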
@ -293,7 +293,7 @@ auto handle_torch_function_no_python_arg_parser(
      torch_function_name == TorchFunctionName::TorchFunction;
  auto get_mode = [&]() {
    return is_torch_function ? at::impl::PythonTorchFunctionTLS::get_mode()
                             : c10::impl::TorchDispatchModeTLS::get_state();
                             : c10::impl::TorchDispatchModeTLS::get_mode();
  };

  const auto& maybe_mode = get_mode();
@ -19,7 +19,7 @@ struct StashTorchFunctionModeGuard {
  }

 private:
  std::shared_ptr<c10::SafePyObject> old_mode_ = nullptr;
  std::shared_ptr<c10::SafePyObject> old_mode_;
};

} // namespace overrides
@ -8,16 +8,15 @@ namespace torch_dispatch_mode {
struct StashTorchDispatchModeGuard {
 public:
  StashTorchDispatchModeGuard() {
    saved_ = c10::impl::TorchDispatchModeTLS::get_state();
    c10::impl::TorchDispatchModeTLS::set_state(nullptr);
    c10::impl::TorchDispatchModeTLS::swap_mode(saved_mode_);
  }

  ~StashTorchDispatchModeGuard() {
    c10::impl::TorchDispatchModeTLS::set_state(saved_);
    c10::impl::TorchDispatchModeTLS::set_mode(std::move(saved_mode_));
  }

 private:
  std::shared_ptr<at::SafePyObject> saved_;
  std::shared_ptr<at::SafePyObject> saved_mode_;
};

} // namespace torch_dispatch_mode
@ -19,7 +19,7 @@ from dataclasses import dataclass
|
||||
import weakref
|
||||
import operator
|
||||
|
||||
from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
|
||||
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
|
||||
from torch._subclasses import FakeTensor
|
||||
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt, PySymFloat
|
||||
from torch.fx import Proxy
|
||||
@ -167,9 +167,9 @@ def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
def maybe_disable_fake_tensor_mode():
    # TODO: figure out if this API generally makes sense and bake it into the
    # library
    mb_fake_mode = torch._C._get_torch_dispatch_mode()
    mb_fake_mode = _get_current_dispatch_mode()
    if isinstance(mb_fake_mode, FakeTensorMode):
        return enable_torch_dispatch_mode(mb_fake_mode.inner, replace=mb_fake_mode)
        return _pop_mode_temporarily()
    else:
        return nullcontext()
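Here `_pop_mode_temporarily` (imported above from torch.utils._python_dispatch) is the dispatch-side twin of the torch-function helper that appears later in this diff. A simplified sketch of the assumed behaviour, omitting the bookkeeping that installs and clears the single C++-level mode:

```python
import contextlib
from torch._C import _pop_torch_dispatch_stack, _push_on_torch_dispatch_stack

@contextlib.contextmanager
def _pop_mode_temporarily():
    # Remove the innermost dispatch mode for the duration of the block,
    # then push it back; FakeTensorMode uses this to "step outside" itself.
    old = _pop_torch_dispatch_stack()
    try:
        yield old
    finally:
        _push_on_torch_dispatch_stack(old)
```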
@ -432,7 +432,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||
@contextmanager
|
||||
def restore(self):
|
||||
with self.sym_mode.enable(True):
|
||||
with super().restore():
|
||||
with self:
|
||||
yield
|
||||
|
||||
def inner_torch_dispatch(self, func, types, args=(), kwargs=None):
|
||||
@ -637,12 +637,7 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
|
||||
|
||||
|
||||
def get_torch_dispatch_modes():
|
||||
modes = [torch._C._get_torch_dispatch_mode()]
|
||||
if modes[-1] is None:
|
||||
return list()
|
||||
while modes[-1].inner is not None:
|
||||
modes.append(modes[-1].inner)
|
||||
return modes
|
||||
return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -32,9 +32,9 @@ import contextlib
|
||||
import torch
|
||||
from torch._C import (
|
||||
_has_torch_function, _has_torch_function_unary,
|
||||
_has_torch_function_variadic, _add_docstr, _set_torch_function_mode, _get_torch_function_mode)
|
||||
|
||||
from torch.utils._mode_utils import _enable_mode, _ModeInfo, _restore_mode
|
||||
_has_torch_function_variadic, _add_docstr,
|
||||
_push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
|
||||
_set_torch_function_mode, _is_torch_function_mode_enabled)
|
||||
|
||||
__all__ = [
|
||||
"get_ignored_functions",
|
||||
@ -1510,8 +1510,8 @@ def handle_torch_function(
|
||||
types = tuple(map(type, overloaded_args))
|
||||
|
||||
# Check for __torch_function__ mode.
|
||||
mode = _get_torch_function_mode()
|
||||
if mode is not None:
|
||||
if _is_torch_function_mode_enabled():
|
||||
mode = _get_current_function_mode()
|
||||
# NB: unlike on tensors, modes are instances
|
||||
with _no_torch_function_mode():
|
||||
result = mode.__torch_function__(public_api, types, args, kwargs)
|
||||
@ -1541,8 +1541,8 @@ def handle_torch_function(
|
||||
"no implementation found for '{}' on types that implement "
|
||||
'__torch_function__: {}'
|
||||
).format(func_name, [type(arg) for arg in overloaded_args])
|
||||
if mode is not None:
|
||||
msg += f" nor in mode {mode}"
|
||||
if _is_torch_function_mode_enabled():
|
||||
msg += f" nor in mode {_get_current_function_mode()}"
|
||||
raise TypeError(msg)
|
||||
|
||||
has_torch_function = _add_docstr(
|
||||
@ -1766,53 +1766,7 @@ def is_tensor_like(inp):
|
||||
"""
|
||||
return type(inp) is torch.Tensor or hasattr(type(inp), "__torch_function__")
|
||||
|
||||
|
||||
def _wrap_torch_function(f):
|
||||
@functools.wraps(f)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if isinstance(f, classmethod):
|
||||
raise RuntimeError("TorchFunctionMode's torch_function function " +
|
||||
"should be a normal method not a class method")
|
||||
inner = getattr(self, "inner", None)
|
||||
|
||||
with enable_torch_function_mode(inner):
|
||||
return f(self, *args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
|
||||
# Implementation note: I had a choice about how much of mode stacks
|
||||
# to implement in Python versus in C++. At time of writing, I did not care
|
||||
# too much about implementation efficiency; however, I do care about making it
|
||||
# hard for users to implement modes in the wrong way. In the end, it turned
|
||||
# out to be possible to implement mode stacks entirely from userland, with the
|
||||
# C++ API providing only _get_torch_function_mode() and
|
||||
# _set_torch_function_mode(), so I opted to provide some unsafe C++ bindings and
|
||||
# have the bulk of the logic for managing the stack in Python, which helped
|
||||
# simplify the C++ API surface. It would also have been valid to build in the
|
||||
# notion of mode stack directly into C++ but in this design it's substantially
|
||||
# more difficult to interact with TorchFunctionModeMeta.
|
||||
class TorchFunctionModeMeta(type):
|
||||
"""
|
||||
Metaclass for :class:`TorchFunctionMode`; it does two things:
|
||||
|
||||
* Adds an implicit ``inner`` kwarg to ``__init__``, to
|
||||
allow the modes to be chained together to form a stack.
|
||||
|
||||
* Reenables the inner mode, so that by default PyTorch API calls
|
||||
will compositionally proceed to the next mode on the stack.
|
||||
|
||||
The default behavior for the second bullet is important, as it is easy to
|
||||
accidentally write ``__torch_function__`` implementations that are not
|
||||
compositional, and the wrapping here makes the obvious code do the
|
||||
right thing (aka, this is why there is a metaclass).
|
||||
"""
|
||||
def __new__(metacls, name, bases, dct):
|
||||
if '__torch_function__' in dct:
|
||||
dct['__torch_function__'] = _wrap_torch_function(dct['__torch_function__'])
|
||||
return super().__new__(metacls, name, bases, dct)
|
||||
|
||||
|
||||
class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
|
||||
class TorchFunctionMode:
|
||||
"""
|
||||
A ``TorchFunctionMode`` allows you to override the meaning of all
|
||||
``__torch_function__`` overrideable functions within a dynamic scope,
|
||||
@ -1852,24 +1806,11 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __enter__(self):
|
||||
old = _get_torch_function_mode()
|
||||
if hasattr(self, "inner"):
|
||||
raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version or use restore")
|
||||
else:
|
||||
self.inner = old
|
||||
if old is None:
|
||||
self.ancestors = set()
|
||||
else:
|
||||
self.ancestors = self.inner.ancestors.union({self.inner})
|
||||
_set_torch_function_mode(self)
|
||||
_push_mode(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
_set_torch_function_mode(self.inner)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def restore(self):
|
||||
return _restore_mode(self, mode_info=_TorchFunctionModeInfo())
|
||||
_pop_mode()
|
||||
|
||||
@classmethod
|
||||
def push(cls, *args, **kwargs):
|
||||
@ -1877,6 +1818,50 @@ class TorchFunctionMode(metaclass=TorchFunctionModeMeta):
|
||||
instance = cls(*args, **kwargs)
|
||||
return instance


def _get_current_function_mode():
    stack_len = _len_torch_function_stack()
    return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None


def _get_current_function_mode_stack():
    stack_len = _len_torch_function_stack()
    return [_get_function_stack_at(i) for i in range(stack_len)]

def _push_mode(mode):
    if _len_torch_function_stack() == 0:
        _set_torch_function_mode(_TorchFunctionStackMode())
    _push_on_torch_function_stack(mode)


def _pop_mode():
    old = _pop_torch_function_stack()
    if _len_torch_function_stack() == 0:
        _set_torch_function_mode(None)
    return old


@contextlib.contextmanager
def _pop_mode_temporarily():
    old = _pop_mode()
    try:
        yield old
    finally:
        _push_mode(old)

# a helper "mode" used by the torch_function push helper method. This is the only mode that will ever
# be active at the C++ level and it will run the current mode
class _TorchFunctionStackMode:
    def __torch_function__(self, func, types, args=(), kwargs=None):
        with _pop_mode_temporarily() as old:
            if _len_torch_function_stack() > 0:
                _set_torch_function_mode(self)
            # we can't check the type of __torch_function__ here but this is sufficient for checking it's a classmethod
            if old.__torch_function__.__self__ is type(old):
                raise RuntimeError("TorchFunctionMode's torch_function function " +
                                   "should be a normal method not a class method")
            return old.__torch_function__(func, types, args, kwargs)

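This is the piece that enables the simplified UX described in the summary: only `_TorchFunctionStackMode` is ever installed at the C++ level, and it forwards to whatever sits on top of the Python-side stack, popping that mode while its handler runs. A small illustrative example of the resulting user code (LoggingMode is made up for the example):

```python
import torch
from torch.overrides import TorchFunctionMode

class LoggingMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        print(f"intercepted {func.__name__}")
        # the mode is already popped here, so this call does not recurse into us
        return func(*args, **(kwargs or {}))

def f(mode):
    with mode:          # just a context manager; no .restore() needed anymore
        return torch.ones(2) + 1

f(LoggingMode())
```
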
class BaseTorchFunctionMode(TorchFunctionMode):
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
@ -1889,57 +1874,14 @@ class BaseTorchFunctionMode(TorchFunctionMode):
|
||||
# library code though, e.g., in handle_torch_function
|
||||
@contextlib.contextmanager
|
||||
def _no_torch_function_mode() -> Iterator[None]:
|
||||
old = _get_torch_function_mode()
|
||||
_set_torch_function_mode(None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_set_torch_function_mode(old)
|
||||
if _len_torch_function_stack() > 0:
|
||||
_set_torch_function_mode(_TorchFunctionStackMode())
|
||||
|
||||
|
||||
class _TorchFunctionModeInfo(_ModeInfo):
|
||||
def __init__(self):
|
||||
super().__init__(mode_name="torch_function", mode_class=TorchFunctionMode)
|
||||
|
||||
def get_mode(self):
|
||||
return _get_torch_function_mode()
|
||||
|
||||
def set_mode(self, mode):
|
||||
return _set_torch_function_mode(mode)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def enable_torch_function_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
|
||||
"""
|
||||
Context manager that sets the current :class:`TorchFunctionMode`; see the
|
||||
class for more information on what modes are. This function is
|
||||
non-compositional; if there is already an existing mode, it will raise an
|
||||
error; prefer using ``with MyMode():`` if your ``__torch_function__``
|
||||
implementation can defer to an inner mode.
|
||||
|
||||
This function is safe to use inside a ``__torch_function__`` mode handler,
|
||||
as the mode is guaranteed to be disabled in this context. You can use
|
||||
this context manager to reinstate the mode so that calls to overridable
|
||||
APIs recursively call back into your mode handler (this can easily cause
|
||||
infinite loops, so use with care!)
|
||||
|
||||
Args:
|
||||
mode (:class:`TorchFunctionMode`, Tensor-like class or None): the
|
||||
mode to set as current mode. If you pass a Tensor-like class,
|
||||
it will be treated as a non-compositional mode with no state,
|
||||
which is convenient if you have an existing tensor subclass
|
||||
that you'd like to apply globally in a quick and dirty way.
|
||||
Passing None will disable the current mode.
|
||||
replace (:class:`TorchFunctionMode` or Tensor-like class): the
|
||||
mode to replace. You can use this argument to change the mode in
|
||||
a situation where you know what the current mode is (and you are
|
||||
intentionally overwriting it.) If you don't know what the current
|
||||
mode is, use ``ignore_preexisting`` instead.
|
||||
ignore_preexisting (bool): if True, ignore any preexisting mode
|
||||
and overwrite it with the passed mode.
|
||||
"""
|
||||
return _enable_mode(mode, _TorchFunctionModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting)
|
||||
|
||||
class enable_reentrant_dispatch():
|
||||
def __enter__(self):
|
||||
self._raii_guard = torch._C._RestorePythonTLSSnapshot()
|
||||
|
@ -1,10 +1,11 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import itertools
|
||||
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
||||
from functools import partial
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
from torch.utils._python_dispatch import enable_torch_dispatch_mode
|
||||
from torch.utils._mode_utils import no_dispatch, all_same_mode
|
||||
import torch.autograd.forward_ad as fwAD
|
||||
from typing import Callable
|
||||
import re
|
||||
@ -99,7 +100,7 @@ def is_inplace(func):
|
||||
return name[-1] == '_'
|
||||
|
||||
|
||||
def generate_cct(autograd_view_consistency=True):
|
||||
def generate_cct_and_mode(autograd_view_consistency=True):
|
||||
# This function returns a new class CompositeCompliantTensor
|
||||
# The two arguments control the behaviour described below.
|
||||
|
||||
@ -116,7 +117,7 @@ def generate_cct(autograd_view_consistency=True):
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, elem, *args, **kwargs):
|
||||
def __new__(cls, elem, mode, *args, **kwargs):
|
||||
assert type(elem) is not cls, \
|
||||
"Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported"
|
||||
|
||||
@ -148,6 +149,8 @@ def generate_cct(autograd_view_consistency=True):
|
||||
# Ref: https://github.com/albanD/subclass_zoo/issues/21
|
||||
torch._C._set_conj(r, r.elem.is_conj())
|
||||
torch._C._set_neg(r, r.elem.is_neg())
|
||||
|
||||
r.mode = mode
|
||||
return r
|
||||
|
||||
def __repr__(self):
|
||||
@ -155,11 +158,20 @@ def generate_cct(autograd_view_consistency=True):
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
all_args = tree_flatten(args)[0] + tree_flatten(kwargs)[0]
|
||||
modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
|
||||
if not all_same_mode(modes):
|
||||
raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
|
||||
with modes[0]:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
class CompositeCompliantTensorMode(TorchDispatchMode):
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
def unwrap(e):
|
||||
return e.elem if isinstance(e, CompositeCompliantTensor) else e
|
||||
|
||||
def wrap(e):
|
||||
return CompositeCompliantTensor(e) if isinstance(e, torch.Tensor) else e
|
||||
return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
if func == torch.ops.aten._local_scalar_dense.default:
|
||||
raise RuntimeError(
|
||||
@ -224,13 +236,13 @@ def generate_cct(autograd_view_consistency=True):
|
||||
# For each CompositeCompliantTensor t, we check that t and t.elem
|
||||
# have consistent metadata. If they don't have consistent metadata,
|
||||
# that means the operator did something fishy.
|
||||
check = partial(check_metadata_consistency, CCT=cls)
|
||||
check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
|
||||
tree_map(check, args)
|
||||
tree_map(check, kwargs)
|
||||
tree_map(check, rs)
|
||||
return rs
|
||||
|
||||
return CompositeCompliantTensor
|
||||
return CompositeCompliantTensor, CompositeCompliantTensorMode()
|
||||
|
||||
def is_tensorlist(lst):
|
||||
if not isinstance(lst, list) and not isinstance(lst, tuple):
|
||||
@ -251,12 +263,12 @@ def maybe_map(fn, should_map, arg):
|
||||
return fn(arg) if should_map else arg
|
||||
|
||||
|
||||
def wrap(arg, CCT):
|
||||
# CCT: CompositeCompliantTensor class which is generated using generate_cct
|
||||
def wrap(arg, CCT, cct_mode):
|
||||
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return CCT(arg)
|
||||
return CCT(arg, cct_mode)
|
||||
if is_tensorlist(arg):
|
||||
return [CCT(a) for a in arg]
|
||||
return [CCT(a, cct_mode) for a in arg]
|
||||
raise RuntimeError("wrap assumes that the input can be wrapped")
|
||||
|
||||
|
||||
@ -270,14 +282,14 @@ def wrap(arg, CCT):
|
||||
# [A, 1, B]
|
||||
# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops
|
||||
# don't accept that many input Tensors.
|
||||
def generate_subclass_choices(flat_args, CCT):
|
||||
# CCT: CompositeCompliantTensor class which is generated using generate_cct
|
||||
def generate_subclass_choices(flat_args, CCT, cct_mode):
|
||||
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
|
||||
is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args]
|
||||
subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes]
|
||||
|
||||
for which_args_are_wrapped in itertools.product(*subclass_options):
|
||||
|
||||
result = [maybe_map(partial(wrap, CCT=CCT), should_wrap_arg, arg)
|
||||
result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg)
|
||||
for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)]
|
||||
yield result, which_args_are_wrapped
|
||||
|
||||
@ -285,11 +297,11 @@ def generate_subclass_choices(flat_args, CCT):
|
||||
# For an operation f(*args, **kwargs), each Tensor argument may either be
|
||||
# a regular Tensor or a Tensor Subclass. This iterator iterates through
|
||||
# all of those options.
|
||||
def generate_subclass_choices_args_kwargs(args, kwargs, CCT):
|
||||
# CCT: CompositeCompliantTensor class which is generated using generate_cct
|
||||
def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
|
||||
# CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
|
||||
flat_kwargs, spec = tree_flatten(kwargs)
|
||||
flat_args_kwargs = list(args) + list(flat_kwargs)
|
||||
for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT):
|
||||
for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode):
|
||||
new_args = choice[:len(args)]
|
||||
new_kwargs = tree_unflatten(choice[len(args):], spec)
|
||||
which_args_are_wrapped = debug_metadata[:len(args)]
|
||||
@ -320,9 +332,9 @@ def raise_composite_compliance_error(err, additional_info=''):
|
||||
# If some composite operation does any non-compliant behavior,
|
||||
# CompositeCompliantTensor will raise an error.
|
||||
def check_all_permutations(op, args, kwargs, assert_equal_fn):
|
||||
CCT = generate_cct()
|
||||
CCT, cct_mode = generate_cct_and_mode()
|
||||
expected = op(*args, **kwargs)
|
||||
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
|
||||
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
|
||||
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
|
||||
|
||||
try:
|
||||
@ -365,17 +377,17 @@ def check_all_permutations(op, args, kwargs, assert_equal_fn):
|
||||
# Composite does any non-compliant behavior,
|
||||
# CompositeCompliantTensor will raise an error.
|
||||
def check_with_mode(op, args, kwargs, assert_equal_fn):
|
||||
CCT = generate_cct()
|
||||
CCT, cct_mode = generate_cct_and_mode()
|
||||
|
||||
def wrap(e):
|
||||
return CCT(e) if isinstance(e, torch.Tensor) else e
|
||||
return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e
|
||||
|
||||
expected = op(*args, **kwargs)
|
||||
|
||||
args = tree_map(wrap, args)
|
||||
kwargs = tree_map(wrap, kwargs)
|
||||
try:
|
||||
with enable_torch_dispatch_mode(CCT):
|
||||
with cct_mode:
|
||||
actual = op(*args, **kwargs)
|
||||
# see NOTE: [What errors are Composite Compliance trying to catch?]
|
||||
except RuntimeError as err:
|
||||
@ -429,11 +441,11 @@ def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradch
|
||||
def check_backward_formula(op: Callable, args, kwargs,
|
||||
output_process_fn_grad=None,
|
||||
gradcheck_wrapper=None, assert_equal_fn=None):
|
||||
CCT = generate_cct()
|
||||
CCT, cct_mode = generate_cct_and_mode()
|
||||
|
||||
expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)
|
||||
|
||||
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
|
||||
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
|
||||
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
|
||||
leaf_tensors = gather_leaf_tensors(new_args, new_kwargs)
|
||||
assert len(leaf_tensors) > 0
|
||||
@ -460,7 +472,7 @@ def check_backward_formula(op: Callable, args, kwargs,
|
||||
# NB: ones, not ones_like, so we get a regular Tensor here
|
||||
grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
|
||||
for r in flat_diff_results]
|
||||
for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT):
|
||||
for flat_new_grads, which_grad_is_batched in generate_subclass_choices(grads, CCT, cct_mode):
|
||||
try:
|
||||
actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads,
|
||||
allow_unused=True, retain_graph=True)
|
||||
@ -486,7 +498,7 @@ def check_backward_formula(op: Callable, args, kwargs,
|
||||
# this means we can apply check_forward_ad_formula to things that aren't OpInfos
|
||||
# while debugging.
|
||||
def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None):
|
||||
CCT = generate_cct(autograd_view_consistency=False)
|
||||
CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False)
|
||||
|
||||
def maybe_tangent(t):
|
||||
assert type(t) is not CCT
|
||||
@ -529,11 +541,11 @@ def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None,
|
||||
expected_tangents = tree_map(lambda x: x.tangent, expected)
|
||||
|
||||
# Permutations of arg and kwargs in CCT.
|
||||
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
|
||||
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
|
||||
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
|
||||
|
||||
# Permutations tangent arg and tangent kwargs in CCT.
|
||||
for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT):
|
||||
for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode):
|
||||
new_tang_args, new_tang_kwargs, \
|
||||
which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice
|
||||
|
||||
|
@ -1,125 +1,13 @@
|
||||
import torch
|
||||
from typing import Iterator, TypeVar
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar
|
||||
from contextlib import contextmanager
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# This file has all the logic to dedupe logic between torch dispatch and
|
||||
# torch function modes
|
||||
#
|
||||
# Specifically, it has the helper functions for enable_ and push_X_mode and the
|
||||
# ModeInfo class, which is extended by each where they are different
|
||||
|
||||
|
||||
# in order to dedupe the logic between TorchDispatchMode and TorchFunctionMode, this
|
||||
# is a container to hold all the differences between the modes. Then functions like
|
||||
# _enable_mode are able to use this container to call functions or get correctly
|
||||
# formatted names
|
||||
@dataclass
|
||||
class _ModeInfo:
|
||||
mode_name: str
|
||||
mode_class: type # the class related to the mode that's allowed to be passed in
|
||||
|
||||
def mode_class_name(self):
|
||||
return self.mode_class.__name__
|
||||
|
||||
def get_mode(self):
|
||||
"""gets the current mode for this type of mode"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def set_mode(self, mode):
|
||||
"""
|
||||
set mode to for this type of mode. Note that no checks are done on this, it's the unsafe
|
||||
version where checks are assumed to have been already done by the helper function
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# shared version of enable_torch_function/enable_torch_dispatch_mode in order to deduplicate the code.
|
||||
# The differences between the modes are captured by `mode_info` and then queried when they're
|
||||
# needed during the function's invocation
|
||||
def _enable_mode(mode: T, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[T]:
|
||||
if not (
|
||||
mode is None or
|
||||
isinstance(mode, mode_info.mode_class) or
|
||||
(isinstance(mode, type) and not issubclass(mode, mode_info.mode_class))
|
||||
):
|
||||
raise ValueError(f'expected to get {mode_info.mode_class_name()}, Tensor-like class, '
|
||||
f'or None as an argument got {type(mode)} instead')
|
||||
old = mode_info.get_mode()
|
||||
if old is mode:
|
||||
yield mode # type: ignore[misc]
|
||||
return
|
||||
if old is not None and not ignore_preexisting and old is not replace:
|
||||
if isinstance(mode, mode_info.mode_class):
|
||||
help_text = 'Use `with Mode():` instead.'
|
||||
else:
|
||||
help_text = (
|
||||
'If you intended to completely override the preexisting mode, '
|
||||
'pass ignore_preexisting=True. This can result in unexpected '
|
||||
'behavior; please consider rewriting your mode to be a subclass '
|
||||
f'of {mode_info.mode_class_name()} to make it compositional!'
|
||||
)
|
||||
raise ValueError(
|
||||
f'Attempted to enable_{mode_info.mode_name}_mode, but there is already an '
|
||||
f'active mode {old}. {help_text}'
|
||||
)
|
||||
# NB: we don't require TorchFunctionMode/TorchDispatchMode since this is intended to also
|
||||
# let you directly pass a Tensor subclass type to "mode-ify" it.
|
||||
if mode is not None:
|
||||
required_fn = "__" + mode_info.mode_name + "__"
|
||||
if not hasattr(mode, required_fn):
|
||||
raise ValueError(
|
||||
f'The argument passed to enable_{mode_info.mode_name}_mode must implement {required_fn}'
|
||||
)
|
||||
mode_info.set_mode(mode)
|
||||
try:
|
||||
yield mode # type: ignore[misc]
|
||||
finally:
|
||||
mode_info.set_mode(old)
|
||||
|
||||
|
||||
def _restore_mode(mode, mode_info: _ModeInfo):
|
||||
if not hasattr(mode, "ancestors"):
|
||||
raise RuntimeError(f"{mode} does not have any ancestors. Use the standard version instead of restore")
|
||||
old = mode_info.get_mode()
|
||||
if old is not None and old not in mode.ancestors:
|
||||
raise RuntimeError(f"{mode} is not valid in the current state because the current mode is not its ancestor")
|
||||
mode_info.set_mode(mode)
|
||||
try:
|
||||
yield mode
|
||||
finally:
|
||||
mode_info.set_mode(old)
|
||||
|
||||
|
||||
# To help with non-lexical scoping, it will error if all the modes are from different scopes or haven't been used
|
||||
def find_outermost_mode(modes):
|
||||
outermost = None
|
||||
for mode in modes:
|
||||
if mode is not None:
|
||||
if not hasattr(mode, "ancestors"):
|
||||
raise RuntimeError(f"{mode}, doesn't have ancestors set so the ordering with other modes is unclear")
|
||||
if outermost is None:
|
||||
outermost = mode
|
||||
elif mode not in outermost.ancestors and outermost not in mode.ancestors:
|
||||
raise RuntimeError(f"modes {mode} and {outermost} are not compatible because they "
|
||||
"don't come from the same scope")
|
||||
elif outermost in mode.ancestors:
|
||||
outermost = mode
|
||||
return outermost
|
||||
|
||||
|
||||
# returns if all are the same mode
def all_same_mode(modes):
    return all(tuple(mode == modes[0] for mode in modes))
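`all_same_mode` is the part of _mode_utils.py that survives the cleanup (it is still imported by the composite-compliance helpers above); a quick illustrative check of its behaviour, using the LoggingTensorMode test helper as a stand-in:

```python
# Illustrative only; LoggingTensorMode is the test helper used elsewhere in this PR
from torch.testing._internal.logging_tensor import LoggingTensorMode
from torch.utils._mode_utils import all_same_mode

m, other = LoggingTensorMode(), LoggingTensorMode()
assert all_same_mode([m, m])          # identical instances
assert not all_same_mode([m, None])   # mixed with None
assert not all_same_mode([m, other])  # two distinct instances
```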
|
||||
# returns if all modes are from the current scope, ``cur_mode``
|
||||
def all_same_mode_scope(modes, cur_mode):
|
||||
if not hasattr(cur_mode, "ancestors"):
|
||||
return False
|
||||
return all(tuple(mode == cur_mode or mode in cur_mode.ancestors for mode in modes))
|
||||
|
||||
@contextmanager
|
||||
def no_dispatch():
|
||||
guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
@ -1,115 +1,17 @@
import contextlib
from typing import Iterator, Set
import functools

import warnings
from torch.utils._mode_utils import _enable_mode, _ModeInfo, _restore_mode
from torch._C import _get_torch_dispatch_mode, _set_torch_dispatch_mode
from dataclasses import dataclass


@dataclass
class TorchDispatchModeInfo(_ModeInfo):
    def __init__(self):
        super().__init__(mode_name="torch_dispatch", mode_class=TorchDispatchMode)

    def get_mode(self):
        return _get_torch_dispatch_mode()

    def set_mode(self, mode):
        return _set_torch_dispatch_mode(mode)
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
    _pop_torch_dispatch_stack, _push_on_torch_dispatch_stack, _set_torch_dispatch_mode


# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for torch._C._DisableTorchDispatch that
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
@contextlib.contextmanager
def enable_torch_dispatch_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
    """
    Context manager that causes all pytorch operators to dispatch to the passed-in
    type's __torch_dispatch__ function, including operations that accept no tensors
    but return a tensor.

    This function is non-compositional; if there is already an existing mode,
    it will raise an error

    This function is safe to use inside a ``__torch_dispatch__`` mode handler,
    as the mode is guaranteed to be disabled in this context. You can use
    this context manager to reinstate the mode so that calls to overridable
    APIs recursively call back into your mode handler (this can easily cause
    infinite loops, so use with care!)

    enable_torch_dispatch_mode is affected by _DisableTorchDispatch.

    Args:
        mode (:class:`TorchDispatchMode`, Tensor-like class, or None): the
            mode to set as current mode. If you pass a Tensor-like class,
            it will be treated as a non-compositional mode with no state,
            which is convenient if you have an existing tensor subclass
            that you'd like to apply globally in a quick and dirty way.
            Passing None will disable the current mode.
        replace (:class:`TorchDispatchMode` or Tensor-like class): the
            mode to replace. You can use this argument to change the mode in
            a situation where you know what the current mode is (and you are
            intentionally overwriting it.) If you don't know what the current
            mode is, use ``ignore_preexisting`` instead.
        ignore_preexisting (bool): if True, ignore any preexisting mode
            and overwrite it with the passed mode.
    """

    return _enable_mode(mode, mode_info=TorchDispatchModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting)
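
# Hedged usage sketch of the context manager documented above; _LoggingModeExample
# and the calls below are illustrative and not part of this file.
class _LoggingModeExample(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"called: {func}")
        return func(*args, **(kwargs or {}))

def _enable_mode_usage_sketch():
    import torch
    with enable_torch_dispatch_mode(_LoggingModeExample()):
        torch.empty(())  # routed through _LoggingModeExample.__torch_dispatch__
    # Overwriting an already-active mode must be opted into explicitly:
    with enable_torch_dispatch_mode(_LoggingModeExample(), ignore_preexisting=True):
        torch.ones(2)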


def _wrap_torch_dispatch(f):
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        if isinstance(f, classmethod):
            raise RuntimeError("TorchDispatchMode's torch_dispatch function " +
                               "should be a normal method not a class method")
        inner = getattr(self, "inner", None)

        with enable_torch_dispatch_mode(inner):
            return f(self, *args, **kwargs)
    return wrapped


# Implementation note, since this is based on TorchFunctionMode, this had the
# same dilemma: I had a choice about how much of mode stacks
# to implement in Python versus in C++. At time of writing, I did not care
# too much about implementation efficiency; however, I do care about making it
# hard for users to implement modes in the wrong way. In the end, it turned
# out to be possible to implement mode stacks entirely from userland, with the
# C++ API providing only _get_torch_dispatch_mode() and
# _set_torch_dispatch_mode(), so I opted to provide some unsafe C++ bindings and
# have the bulk of the logic for managing the stack in Python, which helped
# simplify the C++ API surface. It would also have been valid to build in the
# notion of mode stack directly into C++ but in this design it's substantially
# more difficult to interact with TorchDispatchModeMeta.

class TorchDispatchModeMeta(type):
    """
    Metaclass for :class:`TorchDispatchMode`; it does two things:

    * Adds an implicit ``inner`` kwarg to ``__init__``, to
      allow the modes to be chained together to form a stack.

    * Reenables the inner mode, so that by default PyTorch API calls
      will compositionally proceed to the next mode on the stack.

    The default behavior for the second bullet is important, as it is easy to
    accidentally write ``_wrap_torch_dispatch`` implementations that are not
    compositional, and the wrapping here makes the obvious code do the
    right thing (aka, this is why there is a metaclass).
    """
    def __new__(metacls, name, bases, dct):
        if '__torch_dispatch__' in dct:
            dct['__torch_dispatch__'] = _wrap_torch_dispatch(dct['__torch_dispatch__'])
        return super().__new__(metacls, name, bases, dct)
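
# Hedged sketch of what the wrapping above buys in this pre-stack design: while a
# subclass handler runs, self.inner has been re-enabled, so redispatching the op
# falls through to the next mode (or to the regular kernel) instead of recursing
# into the same handler. _CountingModeExample is illustrative only.
class _CountingModeExample(TorchDispatchMode):
    calls = 0
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        type(self).calls += 1
        # Handled by self.inner (or by no mode at all), not by this instance again.
        return func(*args, **(kwargs or {}))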


class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
@ -139,32 +41,15 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
        ``__torch_dispatch__(self)`` to make PyTorch
        API self-referential (beware of infinite loops, in this case!)
    """
    # Force metaclass to generate constructor at the base of the hierarchy
    def __init__(self):
        self.ancestors: Set[TorchDispatchMode]

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError()

    def __enter__(self):
        old = _get_torch_dispatch_mode()
        if hasattr(self, "inner"):
            raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version or use restore")
        else:
            self.inner = old
            if old is None:
                self.ancestors = set()
            else:
                self.ancestors = self.inner.ancestors.union({self.inner})
        _set_torch_dispatch_mode(self)
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _set_torch_dispatch_mode(self.inner)

    @contextlib.contextmanager
    def restore(self):
        return _restore_mode(self, mode_info=TorchDispatchModeInfo())
        _pop_mode()

    @classmethod
    def push(cls, *args, **kwargs):
@ -172,6 +57,48 @@ class TorchDispatchMode(metaclass=TorchDispatchModeMeta):
        instance = cls(*args, **kwargs)
        return instance

def _get_current_dispatch_mode():
    stack_len = _len_torch_dispatch_stack()
    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None


def _get_current_dispatch_mode_stack():
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]
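
# Hedged sketch of how these query helpers behave with the new stack;
# _NoopModeExample is illustrative only (it just redispatches).
class _NoopModeExample(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

def _stack_query_sketch():
    assert _get_current_dispatch_mode() is None
    with _NoopModeExample() as outer, _NoopModeExample() as inner:
        assert _get_current_dispatch_mode() is inner              # top of the stack
        assert _get_current_dispatch_mode_stack() == [outer, inner]
    assert _get_current_dispatch_mode_stack() == []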

def _push_mode(mode):
    if _len_torch_dispatch_stack() == 0:
        _set_torch_dispatch_mode(_TorchDispatchStackMode())
    _push_on_torch_dispatch_stack(mode)


def _pop_mode():
    old = _pop_torch_dispatch_stack()
    if _len_torch_dispatch_stack() == 0:
        _set_torch_dispatch_mode(None)
    return old


@contextlib.contextmanager
def _pop_mode_temporarily():
    old = _pop_mode()
    try:
        yield old
    finally:
        _push_mode(old)

# a helper "mode" used by the torch dispatch push helper method. This is the only mode that will ever
# be active at the C++ level and it will run the current mode
class _TorchDispatchStackMode:
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        with _pop_mode_temporarily() as old:
            if _len_torch_dispatch_stack() > 0:
                _set_torch_dispatch_mode(self)
            # we can't check the type of __torch_dispatch__ here but this is sufficient for checking it's a classmethod
            if old.__torch_dispatch__.__self__ is type(old):
                raise RuntimeError(f"{type(old)}'s torch_dispatch function " +
                                   "should be a normal method not a class method")
            return old.__torch_dispatch__(func, types, args, kwargs)
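
# Hedged sketch of the post-PR flow this helper enables: only _TorchDispatchStackMode
# is ever installed at the C++ level; it pops the top Python-side mode, re-arms itself
# if more modes remain, and forwards the call. _PrintModeExample is illustrative only.
class _PrintModeExample(TorchDispatchMode):
    def __init__(self, tag):
        self.tag = tag

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(self.tag, func)
        # This redispatch is seen by the next mode down the stack (if any),
        # because the current mode is popped for the duration of its handler.
        return func(*args, **(kwargs or {}))

def _nested_modes_sketch():
    import torch
    with _PrintModeExample("outer"):
        with _PrintModeExample("inner"):
            torch.zeros(3)  # printed by "inner" first, then by "outer"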

class BaseTorchDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):