get rid of push_torch_{dispatch, function}_mode (#78215)

Currently we have 2 ways of doing the same thing for torch dispatch and function modes:
`with push_torch_dispatch_mode(X)` or `with X.push(...)`
is now the equivalent of doing
`with X()`

This removes the first API (which is older and private so we don't need to go through a deprecation cycle)

There is some risk here that this might land race with a PR that uses the old API but in general it seems like most are using the `with X()` API or `enable_torch_dispatch_mode(X())` which isn't getting removed.

EDIT: left the `with X.push(...)` API since there were ~3 land races with that over the past day or so. But made it give a warning and ask users to use the other API
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78215
Approved by: https://github.com/ezyang
This commit is contained in:
samdow
2022-07-22 11:10:52 -04:00
committed by PyTorch MergeBot
parent 3bd08e3410
commit 2ac24675cc
13 changed files with 60 additions and 206 deletions

View File

@ -11,10 +11,9 @@ from torch.utils._mode_utils import no_dispatch, find_outermost_mode, all_same_m
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
log_input, capture_logs, capture_logs_with_logging_tensor_mode
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import enable_torch_dispatch_mode, push_torch_dispatch_mode, TorchDispatchMode
from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
import logging
from functools import partial
class TestPythonRegistration(TestCase):
@ -745,7 +744,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
def test_enable_torch_dispatch_mode_basic(self) -> None:
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
with enable_torch_dispatch_mode(LoggingTensorMode()):
torch.empty([])
self.assertExpectedInline('\n'.join(logs), ("$0 = torch._ops.aten.empty.SymInt([], dtype=torch.float32," +
" device=device(type='cpu'), pin_memory=False)"))
@ -754,7 +753,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
x = torch.randn([])
y = torch.randn([])
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
with enable_torch_dispatch_mode(LoggingTensorMode()):
x + y
self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)""")
@ -769,8 +768,8 @@ $2 = torch._ops.aten.add.Tensor($0, $1)""")
x = torch.randn([])
y = torch.randn([])
with capture_logs(is_mode=True) as logs:
with push_torch_dispatch_mode(LoggingTensorMode):
with push_torch_dispatch_mode(LoggingTensorMode):
with LoggingTensorMode():
with LoggingTensorMode():
torch.empty([])
x + y
@ -852,12 +851,12 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None:
with capture_logs(is_mode=True) as logs1:
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
with enable_torch_dispatch_mode(LoggingTensorMode()):
torch.ones([2, 3])
with no_dispatch():
torch.ones([2, 3])
with capture_logs(is_mode=True) as logs2:
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
with enable_torch_dispatch_mode(LoggingTensorMode()):
torch.ones([2, 3])
self.assertEqual(logs1, logs2)
@ -878,21 +877,21 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
pass
with self.assertRaisesRegex(ValueError, "there is already an active mode"):
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
with enable_torch_dispatch_mode(A(inner=None)):
with enable_torch_dispatch_mode(LoggingTensorMode()):
with enable_torch_dispatch_mode(A()):
pass
# For nesting to be a noop, they need to be the same instance
with self.assertRaisesRegex(ValueError, "there is already an active mode"):
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None)):
with enable_torch_dispatch_mode(LoggingTensorMode()):
with enable_torch_dispatch_mode(LoggingTensorMode()):
pass
def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
# "nested" enable_torch_dispatch_modes are allowed if they're the same mode (same instance).
# It's the equivalent of a noop, so it will only write once to the log
x = torch.tensor([3.])
mode = LoggingTensorMode(inner=None)
mode = LoggingTensorMode()
with capture_logs(is_mode=True) as logs:
log_input("x", x)
with enable_torch_dispatch_mode(mode):
@ -909,8 +908,8 @@ $1 = torch._ops.aten.add.Tensor($0, $0)''')
x = torch.tensor([3.])
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(A(inner=None)):
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None), ignore_preexisting=True):
with enable_torch_dispatch_mode(A()):
with enable_torch_dispatch_mode(LoggingTensorMode(), ignore_preexisting=True):
x + x
self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, $0)""")
@ -921,10 +920,10 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
raise AssertionError
x = torch.tensor([3.])
outer_mode = A(inner=None)
outer_mode = A()
with capture_logs(is_mode=True) as logs:
with enable_torch_dispatch_mode(outer_mode):
with enable_torch_dispatch_mode(LoggingTensorMode(inner=None), replace=outer_mode):
with enable_torch_dispatch_mode(LoggingTensorMode(), replace=outer_mode):
x + x
self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, $0)""")
@ -948,58 +947,6 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
pass
self.assertTrue(isinstance(torch.zeros(()), A))
def test_push_torch_dispatch_mode(self) -> None:
class ErrorA(RuntimeError):
def __init__(self, msg=None):
return super().__init__(msg)
class A(TorchDispatchMode):
def __init__(self, msg=None):
self.msg = msg
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
raise ErrorA(self.msg)
x = torch.randn(3)
with self.assertRaises(ErrorA):
with push_torch_dispatch_mode(A):
torch.add(x, x)
with self.assertRaisesRegex(ErrorA, r"partial constructor"):
with push_torch_dispatch_mode(partial(A, "partial constructor")):
x + x
def test_torch_dispatch_mode_stack(self) -> None:
logs = []
class Logger(TorchDispatchMode):
def __init__(self, name):
self.name = name
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
logs.append(self.name)
return func(*args, **kwargs)
x = torch.randn(1)
with Logger.push("A"):
with Logger.push("B"):
x + x
self.assertEqual(logs, ["B", "A"])
def test_push_mode_instance_errors(self):
class A(TorchDispatchMode):
pass
with self.assertRaisesRegex(ValueError, 'instance of TorchDispatchMode'):
with push_torch_dispatch_mode(A()):
pass
def test_push_mode_returns_unrelated(self):
with self.assertRaisesRegex(ValueError, 'return a TorchDispatchMode'):
with push_torch_dispatch_mode(lambda *, inner: None):
pass
def test_ctor_no_inner(self):
class A(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):