[Modes] remove enable and rewrite mode stack (squashed) (#84774)

Based on @ezyang's suggestion, the mode stack now has "one true mode", which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch just takes the top mode in the stack, re-enables itself (if we aren't at the end of the mode stack), and runs that top mode's `torch_{dispatch|function}`.

This maintains the invariant that, in the middle of a mode's torch dispatch, the mode itself is not active. It changes the function the user has to call to see what the current mode is (it no longer queries the C++; it's Python-only), but it also lets the user easily see the entire mode stack.
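As a rough, pure-Python sketch of that dispatch loop (the names `_mode_stack`, `_proxy_active`, `dispatch`, and `PrintMode` are illustrative only, not the actual torch internals):

```python
# Illustrative only: a plain-Python analogy of the "one true mode" idea.
_mode_stack = []        # user modes; innermost (most recently entered) last
_proxy_active = True    # stands in for the single C++-level mode

def dispatch(func, *args, **kwargs):
    """Analogy of what the one true mode's handler does."""
    global _proxy_active
    if not _proxy_active or not _mode_stack:
        return func(*args, **kwargs)
    mode = _mode_stack.pop()              # take the top mode off the stack
    saved = _proxy_active
    _proxy_active = bool(_mode_stack)     # re-enable only if modes remain
    try:
        # while this runs, `mode` is off the stack, so it is not active
        # during its own handler
        return mode.torch_dispatch(func, *args, **kwargs)
    finally:
        _mode_stack.append(mode)          # restore the stack on the way out
        _proxy_active = saved

class PrintMode:
    def torch_dispatch(self, func, *args, **kwargs):
        print("handling", func.__name__)
        return func(*args, **kwargs)

_mode_stack.append(PrintMode())
dispatch(abs, -3)  # prints "handling abs" and returns 3
```

If a mode's handler itself issues more ops, those go through the same loop and see the next mode down (or no mode at all), which is why a mode is never active during its own handler.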

Removes `enable_torch_dispatch_mode` and `.restore()`, since neither makes sense in this new setup.

### Background
Why do we want this? A pretty common pattern that kept coming up was that users had to write something like:

```python
## PRE-PR UX
def f(mode):
  with mode.restore():  # user needs to understand this restore thing?
    ...

with Mode() as m:
  pass
f(m)
```

Many users were getting errors from forgetting to call `.restore` or from forgetting to add the (admittedly weird) "mode instantiation" step, where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write:
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
  with mode:
    ...
f(Mode())
```
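For concreteness, with this change a small self-contained mode can be written and passed around exactly like that (the `LoggingMode` below is a hypothetical example for illustration, not something added by this PR):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    # Toy mode that prints every op it intercepts; illustrative only.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"op: {func}")
        return func(*args, **(kwargs or {}))

def f(mode):
    with mode:  # plain context manager, no .restore() needed
        return torch.ones(2) + 1

out = f(LoggingMode())  # prints the aten ops hit along the way
```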

### Technical Details
With the old mode stack, we basically had a linked list, so a mode could only be used once and had a fixed parent. In this new design, the mode stack is just a Python list that we push to and pop from. There is only one mode that's ever active at the C++ level, and it runs the next mode in the Python list. The modes no longer carry any state.
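A hedged sketch of what that list-based design amounts to when entering and exiting a mode (again with illustrative names, not the real `torch.utils._python_dispatch` code):

```python
# Illustrative only: the stack is a plain list and the mode holds no
# stack-related state, so the same instance can be entered repeatedly.
_mode_stack = []

class ToyMode:
    def __enter__(self):
        _mode_stack.append(self)    # push onto the Python-side stack
        return self

    def __exit__(self, *exc):
        popped = _mode_stack.pop()  # pop on exit
        assert popped is self

m = ToyMode()
with m:
    pass
with m:  # fine: no fixed parent, no one-shot restriction, no .restore()
    pass
```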
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
Author: samdow (2022-09-26 16:42:07 -04:00), committed by PyTorch MergeBot
Commit: 18d8c548f4, parent: a0be0ca161
28 changed files with 666 additions and 999 deletions

@@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map
 from torch.testing._internal.common_utils import run_tests
 from torch.fx.operator_schemas import normalize_function
 from torch.testing._internal.schema_check_mode import SchemaCheckMode
-from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
+from torch.utils._python_dispatch import TorchDispatchMode
 from torch.testing._internal.common_methods_invocations import op_db
 from torch.testing._internal.jit_utils import JitTestCase
 from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
@@ -72,16 +72,14 @@ class IncorrectAliasTensor(torch.Tensor):
 class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records operator order with grad
     def test_schema_check_mode_operator_order(self):
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             x = torch.rand((3, 3), requires_grad=True)
             x.relu().sin()
         self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)

     # Tests that SchemaCheckMode records operator order without grad
     def test_schema_check_mode_operator_order_without_grad(self):
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             x = torch.rand((3, 3), requires_grad=False)
             x.relu().sin()
         self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)
@@ -91,8 +89,7 @@ class TestSchemaCheck(JitTestCase):
         # NB: previously requires_grad=True, but this induces a detach for
         # saved variable
         x = torch.rand((3, 3))
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             actual = x.relu().sin()
         self.assertEqual([], schema_check.mutated)
         self.assertEqual([], schema_check.aliasing)
@@ -100,8 +97,7 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records mutations and aliases with mutation expected
     def test_schema_check_mode_mutated_aliasing_mutation(self):
         actual = torch.rand((3, 3), requires_grad=False)
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             actual.sinh_()
         self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
         self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)
@@ -109,8 +105,7 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records mutations and aliases with resize_
     def test_schema_check_mode_mutated_aliasing_resize_(self):
         actual = torch.rand((3, 3), requires_grad=False)
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             actual.resize_(9)
         self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
         self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)
@@ -119,8 +114,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
         actual = torch.rand((3, 3))
         y = actual
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             actual.add_(y)
         self.assertEqual(
             [
@@ -140,8 +134,7 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records mutations and alias with as_strided
     def test_schema_check_mode_mutated_aliasing_as_strided(self):
         x = torch.rand((3, 6, 4))
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             x.as_strided_([3, 6, 4], [9, 1, 1])
         self.assertEqual(
             [
@@ -161,8 +154,7 @@ class TestSchemaCheck(JitTestCase):
         x = torch.arange(9.)
         m_actual = torch.arange(9.)
         e_actual = torch.zeros([9], dtype=torch.int32)
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             torch.frexp(x, out=(m_actual, e_actual))
         self.assertEqual(
             [
@@ -183,8 +175,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
         x = torch.rand((3, 3))
         actual = torch.zeros(3)
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             torch.aminmax(x, dim=0, out=[actual, actual])
         self.assertEqual(
             [
@@ -207,7 +198,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_functionality(self):
         x = torch.rand((3, 3), requires_grad=True)
         expected = x.relu().sin()
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = x.relu().sin()
         self.assertEqual(expected, actual)
@@ -215,7 +206,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_functionality_default_replaced(self):
         x = torch.rand((3, 3), requires_grad=True)
         expected = x.add(x, alpha=2)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = x.add(x, alpha=2)
         self.assertEqual(expected, actual)
@@ -225,7 +216,7 @@ class TestSchemaCheck(JitTestCase):
         b = torch.rand((3, 3))
         c = torch.rand((3, 3))
         expected = torch.linalg.multi_dot([a, b, c])
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = torch.linalg.multi_dot([a, b, c])
         self.assertEqual(expected, actual)
@@ -233,7 +224,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_functionality_wildcard_after(self):
         x = torch.rand((3, 3))
         expected = x.chunk(6)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = x.chunk(6)
         self.assertEqual(expected, actual)
@@ -242,7 +233,7 @@ class TestSchemaCheck(JitTestCase):
         x = torch.rand((3, 5))
         w = torch.rand((4))
         expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
         self.assertEqual(expected, actual)
@@ -251,7 +242,7 @@ class TestSchemaCheck(JitTestCase):
         expected = torch.rand((3, 3), requires_grad=False)
         actual = torch.clone(expected)
         expected.sinh_()
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual.sinh_()
         self.assertEqual(expected, actual)
@@ -262,7 +253,7 @@ class TestSchemaCheck(JitTestCase):
         actual = torch.clone(expected)
         y = actual
         expected.add_(x)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual.add_(y)
         self.assertEqual(expected, actual)
@@ -272,7 +263,7 @@ class TestSchemaCheck(JitTestCase):
         m_expected, e_expected = torch.frexp(x)
         m_actual = torch.arange(9.)
         e_actual = torch.zeros([9], dtype=torch.int32)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             torch.frexp(x, out=(m_actual, e_actual))
         self.assertEqual(m_expected, m_actual)
         self.assertEqual(e_expected, e_actual)
@@ -281,13 +272,13 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
         x = torch.rand((3, 3))
         actual = torch.zeros(3)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             torch.aminmax(x, dim=0, out=[actual, actual])
         self.assertEqual(torch.amax(x, dim=0), actual)

     # Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input
     def test_schema_check_mode_functionality_device_input(self):
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             x = torch.rand((3, 3), device="cpu", dtype=torch.double)
             y = x + x
         self.assertEqual(x + x, y)
@@ -297,7 +288,7 @@ class TestSchemaCheck(JitTestCase):
         x = torch.rand((3, 3), requires_grad=True)
         batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
         expected = batch(x)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = batch(x)
         self.assertEqual(expected, actual)
@@ -311,7 +302,7 @@ class TestSchemaCheck(JitTestCase):
         expected.relu_()
         expected = batch(expected)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual.sinh_()
             actual.tanh_()
             actual.relu_()
@@ -321,7 +312,7 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode wraps Torch.tensor with empty list input
     def test_schema_check_mode_empty_list_input(self):
         expected = torch.atleast_1d([])
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = torch.atleast_1d([])
         self.assertEqual(expected, actual)
@@ -330,7 +321,7 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
             x = torch.rand((3, 3))
             y = torch.rand((3, 3))
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))

     # # Tests that an exception is raised for a mismatching mutation over multiple ops
@@ -338,7 +329,7 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
             x = torch.rand((3, 3))
             y = torch.rand((3, 3))
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))

     # Tests that an exception is raised for a mismatching alias
@@ -346,7 +337,7 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
             x = torch.rand((3, 3), requires_grad=True)
             y = torch.rand((3, 3))
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)

     # Tests that an exception is raised for a mismatching alias over multiple ops
@@ -354,7 +345,7 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
             x = torch.rand((3, 3), requires_grad=True)
             y = torch.zeros((3, 3), requires_grad=True)
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)

     # Tests that an exception is raised for a centered mismatching alias over multiple ops
@@ -362,15 +353,14 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
             x = torch.rand((3, 3), requires_grad=True)
             y = torch.zeros((3, 3), requires_grad=True)
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()

     # Tests that an exception is raised for a centered mismatching alias over multiple ops
     def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
         with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
             x = torch.rand((3, 3))
-            s = SchemaCheckMode()
-            with enable_torch_dispatch_mode(s):
+            with SchemaCheckMode() as s:
                 IncorrectAliasTensor(x).aminmax(dim=0)

     # Tests that is_alias_of returns as expected
@@ -439,8 +429,7 @@ class TestSchemaCheck(JitTestCase):
                 return func(*args, **kwargs)
         x = torch.rand((3, 3))
-        schemaInfoCheck = SchemaInfoBindTestMode(self)
-        with enable_torch_dispatch_mode(schemaInfoCheck):
+        with SchemaInfoBindTestMode(self) as schemaInfoCheck:
             x.add(x)
@@ -452,7 +441,7 @@ class TestSchemaCheckModeOpInfo(JitTestCase):
         if (dtype == torch.complex32):
             return
         for sample in op.sample_inputs(device, dtype, requires_grad=False):
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 op(sample.input, *sample.args, **sample.kwargs)

 instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))