add nested mode to python mode

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75965

Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
This commit is contained in:
samdow
2022-05-03 13:13:03 -04:00
committed by PyTorch MergeBot
parent 68fa6d8fec
commit 6779366f27
12 changed files with 544 additions and 209 deletions

View File

@ -7,9 +7,10 @@ from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
log_input, capture_logs, no_dispatch
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_dispatch_mode, TorchDispatchMode
import logging
from functools import partial
class TestPythonDispatch(TestCase):
def test_basic(self) -> None:
@ -448,11 +449,8 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])
def test_enable_torch_dispatch_mode_error(self) -> None:
with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
with enable_torch_dispatch_mode(torch.Tensor):
pass
z = LoggingTensor(torch.empty([]))
with self.assertRaisesRegex(ValueError, "must be the type"):
with self.assertRaisesRegex(ValueError, "expected to get TorchDispatchMode, Tensor-like class, or None"):
with enable_torch_dispatch_mode(z):
pass
@ -497,9 +495,12 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
b = B(torch.empty(1))
with self.assertRaises(ErrorA):
a + a
# B has precedence over A due to the subclass relationship
with self.assertRaises(ErrorB):
a + b
# B has precedence over A due to the subclass relationship yet
# modes take precedence over arguments
with self.assertRaises(ErrorA):
with enable_torch_dispatch_mode(A):
b + b
with self.assertRaises(ErrorB):
@ -517,12 +518,149 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
expected = torch.ones([2, 3])
self.assertEqual(z.elem, expected)
def test_enable_torch_dispatch_mode_instance(self) -> None:
class TestMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
return func(*args, **kwargs)
x = TestMode(inner=None)
y = torch.tensor([2.])
with enable_torch_dispatch_mode(x):
y + y
def test_nested_enable_torch_dispatch_mode(self) -> None:
with self.assertRaisesRegex(RuntimeError, "has already been set"):
class A(LoggingTensorMode):
pass
with self.assertRaisesRegex(ValueError, "there is already an active mode"):
with enable_torch_dispatch_mode(LoggingTensorMode):
with enable_torch_dispatch_mode(LoggingTensorMode):
with enable_torch_dispatch_mode(A):
pass
def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
# "nested" enable_torch_dispatch_modes are allowed if they're the same mode. It's the equivalent of
# a noop, so it will only write once to the log
with capture_logs() as logs:
x = LoggingTensor(torch.tensor([3.]))
log_input("x", x)
with enable_torch_dispatch_mode(LoggingTensor):
with enable_torch_dispatch_mode(LoggingTensor):
x + x
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.add.Tensor($0, $0)''')
def test_enable_torch_dispatch_mode_ignore_preexisting(self):
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return cls(torch.zeros(()))
class B(A):
pass
with enable_torch_dispatch_mode(A):
with enable_torch_dispatch_mode(B, ignore_preexisting=True):
self.assertTrue(isinstance(torch.zeros(()), B))
def test_enable_torch_dispatch_mode_replace(self):
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return cls(torch.zeros(()))
class B(A):
pass
with enable_torch_dispatch_mode(A):
with enable_torch_dispatch_mode(B, replace=A):
self.assertTrue(isinstance(torch.zeros(()), B))
def test_exception_handling(self):
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func.__name__ == 'randn.default':
raise RuntimeError()
return cls(torch.zeros(()))
with enable_torch_dispatch_mode(A):
try:
torch.randn(())
except RuntimeError:
pass
self.assertTrue(isinstance(torch.zeros(()), A))
def test_push_torch_dispatch_mode(self) -> None:
class ErrorA(RuntimeError):
def __init__(self, msg=None):
return super().__init__(msg)
class A(TorchDispatchMode):
def __init__(self, msg=None):
self.msg = msg
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA(self.msg)
x = torch.randn(3)
with self.assertRaises(ErrorA):
with push_torch_dispatch_mode(A):
torch.add(x, x)
with self.assertRaisesRegex(ErrorA, r"partial constructor"):
with push_torch_dispatch_mode(partial(A, "partial constructor")):
x + x
def test_torch_dispatch_mode_stack(self) -> None:
logs = []
class Logger(TorchDispatchMode):
def __init__(self, name):
self.name = name
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
logs.append(self.name)
return func(*args, **kwargs)
x = torch.randn(1)
with push_torch_dispatch_mode(partial(Logger, "A")):
with push_torch_dispatch_mode(partial(Logger, "B")):
x + x
self.assertEqual(logs, ["B", "A"])
def test_push_mode_instance_errors(self):
class A(TorchDispatchMode):
pass
with self.assertRaisesRegex(ValueError, 'instance of TorchDispatchMode'):
with push_torch_dispatch_mode(A(inner=None)):
pass
def test_push_mode_returns_unrelated(self):
with self.assertRaisesRegex(ValueError, 'return a TorchDispatchMode'):
with push_torch_dispatch_mode(lambda *, inner: None):
pass
def test_missing_inner_mode_ctor(self):
self.assertRaisesRegex(TypeError, 'push_torch_dispatch_mode', lambda: TorchDispatchMode())
def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
x = LoggingTensor(torch.tensor([2.0, 3.0]))
with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
@ -554,10 +692,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
def wrap(e):
return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e
# no_dispatch is only needed if you use enable_torch_dispatch_mode.
# It prevents infinite recursion.
with no_dispatch():
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
logging.getLogger("NonWrapperSubclass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
return rs
@ -591,10 +726,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
def wrap(e):
return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e
# no_dispatch is only needed if you use enable_torch_dispatch_mode.
# It prevents infinite recursion.
with no_dispatch():
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
if func.overloadpacket.__name__ == "add":
return None
else: