add nested mode to python mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75965
Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
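For orientation, here is a minimal sketch of the two entry points the new tests below exercise. It is inferred from the tests in this diff, not taken from the PR itself; PrintingMode is a hypothetical example mode, and the inner=None constructor argument mirrors test_enable_torch_dispatch_mode_instance:

    import torch
    from functools import partial
    from torch.utils._python_dispatch import (
        TorchDispatchMode, enable_torch_dispatch_mode, push_torch_dispatch_mode)

    class PrintingMode(TorchDispatchMode):  # hypothetical example mode
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"called {func.__module__}.{func.__name__}")
            return func(*args, **kwargs)

    # enable_torch_dispatch_mode now accepts a mode *instance* as well as a
    # Tensor-like class (see test_enable_torch_dispatch_mode_instance):
    with enable_torch_dispatch_mode(PrintingMode(inner=None)):
        torch.ones(2) + torch.ones(2)

    # push_torch_dispatch_mode takes a constructor (a class or a partial) and
    # nests modes on a stack; the innermost mode runs first, as asserted in
    # test_torch_dispatch_mode_stack below:
    with push_torch_dispatch_mode(PrintingMode):
        torch.randn(3)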
@@ -7,9 +7,10 @@ from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
     log_input, capture_logs, no_dispatch
 from torch.utils._pytree import tree_map
-from torch.utils._python_dispatch import enable_torch_dispatch_mode
+from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_dispatch_mode, TorchDispatchMode
 
 import logging
+from functools import partial
 
 class TestPythonDispatch(TestCase):
     def test_basic(self) -> None:
@@ -448,11 +449,8 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
         self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])
 
     def test_enable_torch_dispatch_mode_error(self) -> None:
-        with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
-            with enable_torch_dispatch_mode(torch.Tensor):
-                pass
         z = LoggingTensor(torch.empty([]))
-        with self.assertRaisesRegex(ValueError, "must be the type"):
+        with self.assertRaisesRegex(ValueError, "expected to get TorchDispatchMode, Tensor-like class, or None"):
             with enable_torch_dispatch_mode(z):
                 pass
 
@@ -497,9 +495,12 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
         b = B(torch.empty(1))
         with self.assertRaises(ErrorA):
             a + a
+        # B has precedence over A due to the subclass relationship
         with self.assertRaises(ErrorB):
             a + b
+        # B has precedence over A due to the subclass relationship yet
+        # modes take precedence over arguments
         with self.assertRaises(ErrorA):
             with enable_torch_dispatch_mode(A):
                 b + b
         with self.assertRaises(ErrorB):
@@ -517,12 +518,149 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
         expected = torch.ones([2, 3])
         self.assertEqual(z.elem, expected)
 
+    def test_enable_torch_dispatch_mode_instance(self) -> None:
+        class TestMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                return func(*args, **kwargs)
+
+        x = TestMode(inner=None)
+        y = torch.tensor([2.])
+        with enable_torch_dispatch_mode(x):
+            y + y
+
     def test_nested_enable_torch_dispatch_mode(self) -> None:
-        with self.assertRaisesRegex(RuntimeError, "has already been set"):
+        class A(LoggingTensorMode):
+            pass
+
+        with self.assertRaisesRegex(ValueError, "there is already an active mode"):
             with enable_torch_dispatch_mode(LoggingTensorMode):
-                with enable_torch_dispatch_mode(LoggingTensorMode):
+                with enable_torch_dispatch_mode(A):
                     pass
+
+    def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
+        # "nested" enable_torch_dispatch_modes are allowed if they're the same mode. It's the equivalent of
+        # a noop, so it will only write once to the log
+        with capture_logs() as logs:
+            x = LoggingTensor(torch.tensor([3.]))
+            log_input("x", x)
+            with enable_torch_dispatch_mode(LoggingTensor):
+                with enable_torch_dispatch_mode(LoggingTensor):
+                    x + x
+
+        self.assertExpectedInline('\n'.join(logs), '''\
+$0 = input('x')
+$1 = torch._ops.aten.add.Tensor($0, $0)''')
+
+    def test_enable_torch_dispatch_mode_ignore_preexisting(self):
+        class A(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem):
+                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                return cls(torch.zeros(()))
+
+        class B(A):
+            pass
+
+        with enable_torch_dispatch_mode(A):
+            with enable_torch_dispatch_mode(B, ignore_preexisting=True):
+                self.assertTrue(isinstance(torch.zeros(()), B))
+
+    def test_enable_torch_dispatch_mode_replace(self):
+        class A(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem):
+                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                return cls(torch.zeros(()))
+
+        class B(A):
+            pass
+
+        with enable_torch_dispatch_mode(A):
+            with enable_torch_dispatch_mode(B, replace=A):
+                self.assertTrue(isinstance(torch.zeros(()), B))
+
+    def test_exception_handling(self):
+        class A(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem):
+                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                if func.__name__ == 'randn.default':
+                    raise RuntimeError()
+                return cls(torch.zeros(()))
+
+        with enable_torch_dispatch_mode(A):
+            try:
+                torch.randn(())
+            except RuntimeError:
+                pass
+            self.assertTrue(isinstance(torch.zeros(()), A))
+
+    def test_push_torch_dispatch_mode(self) -> None:
+        class ErrorA(RuntimeError):
+            def __init__(self, msg=None):
+                return super().__init__(msg)
+
+        class A(TorchDispatchMode):
+            def __init__(self, msg=None):
+                self.msg = msg
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                raise ErrorA(self.msg)
+
+        x = torch.randn(3)
+        with self.assertRaises(ErrorA):
+            with push_torch_dispatch_mode(A):
+                torch.add(x, x)
+
+        with self.assertRaisesRegex(ErrorA, r"partial constructor"):
+            with push_torch_dispatch_mode(partial(A, "partial constructor")):
+                x + x
+
+    def test_torch_dispatch_mode_stack(self) -> None:
+        logs = []
+
+        class Logger(TorchDispatchMode):
+            def __init__(self, name):
+                self.name = name
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                logs.append(self.name)
+                return func(*args, **kwargs)
+
+        x = torch.randn(1)
+        with push_torch_dispatch_mode(partial(Logger, "A")):
+            with push_torch_dispatch_mode(partial(Logger, "B")):
+                x + x
+        self.assertEqual(logs, ["B", "A"])
+
+    def test_push_mode_instance_errors(self):
+        class A(TorchDispatchMode):
+            pass
+        with self.assertRaisesRegex(ValueError, 'instance of TorchDispatchMode'):
+            with push_torch_dispatch_mode(A(inner=None)):
+                pass
+
+    def test_push_mode_returns_unrelated(self):
+        with self.assertRaisesRegex(ValueError, 'return a TorchDispatchMode'):
+            with push_torch_dispatch_mode(lambda *, inner: None):
+                pass
+
+    def test_missing_inner_mode_ctor(self):
+        self.assertRaisesRegex(TypeError, 'push_torch_dispatch_mode', lambda: TorchDispatchMode())
+
     def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
         x = LoggingTensor(torch.tensor([2.0, 3.0]))
         with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
@@ -554,10 +692,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
         def wrap(e):
             return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e
 
-        # no_dispatch is only needed if you use enable_torch_dispatch_mode.
-        # It prevents infinite recursion.
-        with no_dispatch():
-            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
         logging.getLogger("NonWrapperSubclass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
         return rs
 
@@ -591,10 +726,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
         def wrap(e):
             return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e
 
-        # no_dispatch is only needed if you use enable_torch_dispatch_mode.
-        # It prevents infinite recursion.
-        with no_dispatch():
-            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
         if func.overloadpacket.__name__ == "add":
             return None
         else:
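The last two hunks drop the no_dispatch() guard around the re-dispatch. A rough sketch of the simplified pattern they converge on, using a hypothetical ExampleSubclass (none of these names are from the PR; the .elem unwrapping follows the style of the tests):

    import torch
    from torch.utils._pytree import tree_map

    class ExampleSubclass(torch.Tensor):  # hypothetical, for illustration only
        @staticmethod
        def __new__(cls, elem):
            r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
            r.elem = elem  # keep the plain tensor around for unwrapping
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}

            def unwrap(e):
                return e.elem if isinstance(e, cls) else e

            def wrap(e):
                return cls(e) if isinstance(e, torch.Tensor) else e

            # Re-dispatch on plain tensors; per this PR, the extra
            # `with no_dispatch():` guard is no longer needed here.
            return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))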