get rid of push_torch_{dispatch, function}_mode (#78215)
Currently we have two ways of doing the same thing for torch dispatch and torch function modes: `with push_torch_dispatch_mode(X)` (or `with X.push(...)`) is equivalent to `with X()`. This removes the first API; since it is older and private, we don't need to go through a deprecation cycle.

There is some risk that this lands in a race with a PR that still uses the old API, but in general most call sites already use the `with X()` API or `enable_torch_dispatch_mode(X())`, which is not being removed.

EDIT: kept the `with X.push(...)` API, since there were ~3 land races with it over the past day or so, but it now emits a warning asking users to use the other API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78215
Approved by: https://github.com/ezyang
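For migration, a minimal before/after sketch (the `PrintingMode` class below is illustrative, not taken from this PR, and assumes a post-PR build):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class PrintingMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # Record the op, then redispatch to the regular implementation.
            print(func)
            return func(*args, **(kwargs or {}))

    x = torch.randn(3)

    # Removed:        with push_torch_dispatch_mode(PrintingMode): ...
    # Kept but warns: with PrintingMode.push(): ...
    # Preferred:
    with PrintingMode():
        torch.add(x, x)  # prints the dispatched op, e.g. aten.add.Tensor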
@@ -11,10 +11,9 @@ from torch.utils._mode_utils import no_dispatch, find_outermost_mode, all_same_m
 from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
     log_input, capture_logs, capture_logs_with_logging_tensor_mode
 from torch.utils._pytree import tree_map
-from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_dispatch_mode, TorchDispatchMode
+from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
 
 import logging
-from functools import partial
 
 
 class TestPythonRegistration(TestCase):
@@ -745,7 +744,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
 
     def test_enable_torch_dispatch_mode_basic(self) -> None:
         with capture_logs(is_mode=True) as logs:
-            with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
+            with enable_torch_dispatch_mode(LoggingTensorMode()):
                 torch.empty([])
         self.assertExpectedInline('\n'.join(logs), ("$0 = torch._ops.aten.empty.SymInt([], dtype=torch.float32," +
                                                     " device=device(type='cpu'), pin_memory=False)"))
@@ -754,7 +753,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
         x = torch.randn([])
         y = torch.randn([])
         with capture_logs(is_mode=True) as logs:
-            with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
+            with enable_torch_dispatch_mode(LoggingTensorMode()):
                 x + y
         self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)""")
@@ -769,8 +768,8 @@ $2 = torch._ops.aten.add.Tensor($0, $1)""")
         x = torch.randn([])
         y = torch.randn([])
         with capture_logs(is_mode=True) as logs:
-            with push_torch_dispatch_mode(LoggingTensorMode):
-                with push_torch_dispatch_mode(LoggingTensorMode):
+            with LoggingTensorMode():
+                with LoggingTensorMode():
                     torch.empty([])
                     x + y
 
@@ -852,12 +851,12 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
 
     def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None:
         with capture_logs(is_mode=True) as logs1:
-            with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
+            with enable_torch_dispatch_mode(LoggingTensorMode()):
                 torch.ones([2, 3])
                 with no_dispatch():
                     torch.ones([2, 3])
         with capture_logs(is_mode=True) as logs2:
-            with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
+            with enable_torch_dispatch_mode(LoggingTensorMode()):
                 torch.ones([2, 3])
         self.assertEqual(logs1, logs2)
 
@@ -878,21 +877,21 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
             pass
 
         with self.assertRaisesRegex(ValueError, "there is already an active mode"):
-            with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
-                with enable_torch_dispatch_mode(A(inner=None)):
+            with enable_torch_dispatch_mode(LoggingTensorMode()):
+                with enable_torch_dispatch_mode(A()):
                     pass
 
         # For nesting to be a noop, they need to be the same instance
        with self.assertRaisesRegex(ValueError, "there is already an active mode"):
-            with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
-                with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
+            with enable_torch_dispatch_mode(LoggingTensorMode()):
+                with enable_torch_dispatch_mode(LoggingTensorMode()):
                     pass
 
     def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
         # "nested" enable_torch_dispatch_modes are allowed if they're the same mode (same instance).
         # It's the equivalent of a noop, so it will only write once to the log
         x = torch.tensor([3.])
-        mode = LoggingTensorMode(inner=None)
+        mode = LoggingTensorMode()
         with capture_logs(is_mode=True) as logs:
             log_input("x", x)
             with enable_torch_dispatch_mode(mode):
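As the comments in this hunk spell out, nesting is only a no-op when the very same mode instance is re-entered; a distinct instance raises. A standalone sketch of the allowed case (illustrative, not part of this diff):

    import torch
    from torch.utils._python_dispatch import enable_torch_dispatch_mode
    from torch.testing._internal.logging_tensor import LoggingTensorMode

    mode = LoggingTensorMode()
    with enable_torch_dispatch_mode(mode):
        with enable_torch_dispatch_mode(mode):  # same instance: a no-op
            torch.ones([2, 3])                  # logged once, not twice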
@@ -909,8 +908,8 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''')
 
         x = torch.tensor([3.])
         with capture_logs(is_mode=True) as logs:
-            with enable_torch_dispatch_mode(A(inner=None)):
-                with enable_torch_dispatch_mode(LoggingTensorMode(inner=None), ignore_preexisting=True):
+            with enable_torch_dispatch_mode(A()):
+                with enable_torch_dispatch_mode(LoggingTensorMode(), ignore_preexisting=True):
                     x + x
         self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, $0)""")
@@ -921,10 +920,10 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
                raise AssertionError
 
         x = torch.tensor([3.])
-        outer_mode = A(inner=None)
+        outer_mode = A()
         with capture_logs(is_mode=True) as logs:
             with enable_torch_dispatch_mode(outer_mode):
-                with enable_torch_dispatch_mode(LoggingTensorMode(inner=None), replace=outer_mode):
+                with enable_torch_dispatch_mode(LoggingTensorMode(), replace=outer_mode):
                     x + x
         self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, $0)""")
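The two keyword arguments exercised in the last two hunks are escape hatches for the "there is already an active mode" error: `ignore_preexisting=True` ignores whatever mode is already active, while `replace=<mode>` swaps out one specific active mode. A condensed sketch (illustrative):

    import torch
    from torch.utils._python_dispatch import enable_torch_dispatch_mode
    from torch.testing._internal.logging_tensor import LoggingTensorMode

    outer = LoggingTensorMode()
    with enable_torch_dispatch_mode(outer):
        # Without replace= (or ignore_preexisting=True) this would raise
        # "there is already an active mode".
        with enable_torch_dispatch_mode(LoggingTensorMode(), replace=outer):
            torch.ones([2, 3])  # handled by the replacement mode only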
@@ -948,58 +947,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
            pass
         self.assertTrue(isinstance(torch.zeros(()), A))
 
-    def test_push_torch_dispatch_mode(self) -> None:
-        class ErrorA(RuntimeError):
-            def __init__(self, msg=None):
-                return super().__init__(msg)
-
-        class A(TorchDispatchMode):
-            def __init__(self, msg=None):
-                self.msg = msg
-
-            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-                raise ErrorA(self.msg)
-
-        x = torch.randn(3)
-        with self.assertRaises(ErrorA):
-            with push_torch_dispatch_mode(A):
-                torch.add(x, x)
-
-        with self.assertRaisesRegex(ErrorA, r"partial constructor"):
-            with push_torch_dispatch_mode(partial(A, "partial constructor")):
-                x + x
-
-    def test_torch_dispatch_mode_stack(self) -> None:
-        logs = []
-
-        class Logger(TorchDispatchMode):
-            def __init__(self, name):
-                self.name = name
-
-            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-                if kwargs is None:
-                    kwargs = {}
-                logs.append(self.name)
-                return func(*args, **kwargs)
-
-        x = torch.randn(1)
-        with Logger.push("A"):
-            with Logger.push("B"):
-                x + x
-        self.assertEqual(logs, ["B", "A"])
-
-    def test_push_mode_instance_errors(self):
-        class A(TorchDispatchMode):
-            pass
-        with self.assertRaisesRegex(ValueError, 'instance of TorchDispatchMode'):
-            with push_torch_dispatch_mode(A()):
-                pass
-
-    def test_push_mode_returns_unrelated(self):
-        with self.assertRaisesRegex(ValueError, 'return a TorchDispatchMode'):
-            with push_torch_dispatch_mode(lambda *, inner: None):
-                pass
-
     def test_ctor_no_inner(self):
         class A(TorchDispatchMode):
             def __torch_dispatch__(self, func, types, args=(), kwargs=None):
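The removed stack test translates directly to the surviving context-manager API; a sketch of the equivalent (illustrative, reusing the deleted test's `Logger` class):

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    logs = []

    class Logger(TorchDispatchMode):
        def __init__(self, name):
            self.name = name

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            logs.append(self.name)
            return func(*args, **(kwargs or {}))

    x = torch.randn(1)
    with Logger("A"):
        with Logger("B"):  # innermost mode sees each op first
            x + x
    assert logs == ["B", "A"]  # same order the removed test asserted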