[Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, the mode stack now has "one true mode", which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch simply takes the top mode in the stack, re-enables itself (if we aren't at the end of the mode stack), and runs that top mode's torch_{dispatch|function}. This maintains the invariant that, in the middle of a mode's torch dispatch, the mode itself is not active. It changes the function the user has to call to see what the current mode is (it no longer queries the C++ state; it's Python only), but it also lets the user see the entire mode stack easily.

Removes `enable_torch_dispatch_mode` and `.restore()`, since neither makes sense in this new setup.

### Background

Why do we want this? A pretty common pattern that kept coming up was that users had to write something like

```python
## PRE-PR UX
def f(mode):
    with mode.restore():  # user needs to understand this restore thing?
        ...

with Mode() as m:
    pass
f(m)
```

Many users were getting errors from forgetting to call `.restore`, or from forgetting the (tbh weird) "mode instantiation" step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write

```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
    with mode:
        ...

f(Mode())
```

**Technical Details**

With the old mode stack, we basically had a linked list, so a mode could only be used once and had a fixed parent. In this new design, the mode stack is just a Python list that we push to and pop from. Only one mode is ever active at the C++ level, and it runs the next mode in the Python list. The modes no longer carry state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
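To make the new contract concrete, here is a minimal sketch of the post-PR usage. `LoggingMode` is a made-up example subclass, not something added by this PR; only `TorchDispatchMode` itself and the plain context-manager entry/exit are what the PR provides.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    """Hypothetical example mode: records every op it intercepts."""
    def __init__(self):
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))           # e.g. "aten.relu.default"
        return func(*args, **(kwargs or {}))

# Post-PR: no empty-bodied "instantiation" block and no .restore();
# a mode is entered like any ordinary context manager.
def f(mode):
    with mode:
        torch.rand(3).relu()

outer = LoggingMode()
with outer:             # pushes `outer` onto the Python-level mode stack
    f(LoggingMode())    # nested instance is pushed on top, then popped on exit
print(outer.ops)        # outer saw the ops too, since it sits lower on the stack
```

And a rough conceptual model of the "one true mode" dispatch described above, written as plain Python. This is an illustration of the push/pop semantics only, not the actual C++/Python implementation in this PR.

```python
# Conceptual sketch only -- not the real implementation.
_mode_stack = []          # the "mode stack" is just a Python list

class Mode:
    def __enter__(self):
        _mode_stack.append(self)            # entering a mode pushes it
        return self

    def __exit__(self, *exc):
        assert _mode_stack.pop() is self    # exiting pops it

    def handler(self, op, *args):
        raise NotImplementedError

def dispatch(op, *args):
    if not _mode_stack:
        return op(*args)                    # no modes active: run the op directly
    top = _mode_stack.pop()                 # the running mode is never active
    try:
        return top.handler(op, *args)       # handler may call dispatch() again,
    finally:                                # which then sees the next mode down
        _mode_stack.append(top)
```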
```diff
@@ -8,7 +8,7 @@ from torch.utils._pytree import tree_map
 from torch.testing._internal.common_utils import run_tests
 from torch.fx.operator_schemas import normalize_function
 from torch.testing._internal.schema_check_mode import SchemaCheckMode
-from torch.utils._python_dispatch import enable_torch_dispatch_mode, TorchDispatchMode
+from torch.utils._python_dispatch import TorchDispatchMode
 from torch.testing._internal.common_methods_invocations import op_db
 from torch.testing._internal.jit_utils import JitTestCase
 from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests
@@ -72,16 +72,14 @@ class IncorrectAliasTensor(torch.Tensor):
 class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records operator order with grad
     def test_schema_check_mode_operator_order(self):
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             x = torch.rand((3, 3), requires_grad=True)
             x.relu().sin()
         self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)
 
     # Tests that SchemaCheckMode records operator order without grad
     def test_schema_check_mode_operator_order_without_grad(self):
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             x = torch.rand((3, 3), requires_grad=False)
             x.relu().sin()
         self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)
@@ -91,8 +89,7 @@ class TestSchemaCheck(JitTestCase):
         # NB: previously requires_grad=True, but this induces a detach for
         # saved variable
         x = torch.rand((3, 3))
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             actual = x.relu().sin()
         self.assertEqual([], schema_check.mutated)
         self.assertEqual([], schema_check.aliasing)
@@ -100,8 +97,7 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records mutations and aliases with mutation expected
     def test_schema_check_mode_mutated_aliasing_mutation(self):
         actual = torch.rand((3, 3), requires_grad=False)
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             actual.sinh_()
         self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
         self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)
@@ -109,8 +105,7 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records mutations and aliases with resize_
     def test_schema_check_mode_mutated_aliasing_resize_(self):
         actual = torch.rand((3, 3), requires_grad=False)
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             actual.resize_(9)
         self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
         self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)
@@ -119,8 +114,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
         actual = torch.rand((3, 3))
         y = actual
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             actual.add_(y)
         self.assertEqual(
             [
@@ -140,8 +134,7 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records mutations and alias with as_strided
     def test_schema_check_mode_mutated_aliasing_as_strided(self):
         x = torch.rand((3, 6, 4))
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             x.as_strided_([3, 6, 4], [9, 1, 1])
         self.assertEqual(
             [
@@ -161,8 +154,7 @@ class TestSchemaCheck(JitTestCase):
         x = torch.arange(9.)
         m_actual = torch.arange(9.)
         e_actual = torch.zeros([9], dtype=torch.int32)
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             torch.frexp(x, out=(m_actual, e_actual))
         self.assertEqual(
             [
@@ -183,8 +175,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
         x = torch.rand((3, 3))
         actual = torch.zeros(3)
-        schema_check = SchemaCheckMode()
-        with enable_torch_dispatch_mode(schema_check):
+        with SchemaCheckMode() as schema_check:
             torch.aminmax(x, dim=0, out=[actual, actual])
         self.assertEqual(
             [
@@ -207,7 +198,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_functionality(self):
         x = torch.rand((3, 3), requires_grad=True)
         expected = x.relu().sin()
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = x.relu().sin()
         self.assertEqual(expected, actual)
 
@@ -215,7 +206,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_functionality_default_replaced(self):
         x = torch.rand((3, 3), requires_grad=True)
         expected = x.add(x, alpha=2)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = x.add(x, alpha=2)
         self.assertEqual(expected, actual)
 
@@ -225,7 +216,7 @@ class TestSchemaCheck(JitTestCase):
         b = torch.rand((3, 3))
         c = torch.rand((3, 3))
         expected = torch.linalg.multi_dot([a, b, c])
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = torch.linalg.multi_dot([a, b, c])
         self.assertEqual(expected, actual)
 
@@ -233,7 +224,7 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_functionality_wildcard_after(self):
         x = torch.rand((3, 3))
         expected = x.chunk(6)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = x.chunk(6)
         self.assertEqual(expected, actual)
 
@@ -242,7 +233,7 @@ class TestSchemaCheck(JitTestCase):
         x = torch.rand((3, 5))
         w = torch.rand((4))
         expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
         self.assertEqual(expected, actual)
 
@@ -251,7 +242,7 @@ class TestSchemaCheck(JitTestCase):
         expected = torch.rand((3, 3), requires_grad=False)
         actual = torch.clone(expected)
         expected.sinh_()
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual.sinh_()
         self.assertEqual(expected, actual)
 
@@ -262,7 +253,7 @@ class TestSchemaCheck(JitTestCase):
         actual = torch.clone(expected)
         y = actual
         expected.add_(x)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual.add_(y)
         self.assertEqual(expected, actual)
 
@@ -272,7 +263,7 @@ class TestSchemaCheck(JitTestCase):
         m_expected, e_expected = torch.frexp(x)
         m_actual = torch.arange(9.)
         e_actual = torch.zeros([9], dtype=torch.int32)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             torch.frexp(x, out=(m_actual, e_actual))
         self.assertEqual(m_expected, m_actual)
         self.assertEqual(e_expected, e_actual)
@@ -281,13 +272,13 @@ class TestSchemaCheck(JitTestCase):
     def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
         x = torch.rand((3, 3))
         actual = torch.zeros(3)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
            torch.aminmax(x, dim=0, out=[actual, actual])
         self.assertEqual(torch.amax(x, dim=0), actual)
 
     # Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input
     def test_schema_check_mode_functionality_device_input(self):
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             x = torch.rand((3, 3), device="cpu", dtype=torch.double)
             y = x + x
             self.assertEqual(x + x, y)
@@ -297,7 +288,7 @@ class TestSchemaCheck(JitTestCase):
         x = torch.rand((3, 3), requires_grad=True)
         batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
         expected = batch(x)
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = batch(x)
         self.assertEqual(expected, actual)
 
@@ -311,7 +302,7 @@ class TestSchemaCheck(JitTestCase):
         expected.relu_()
         expected = batch(expected)
 
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual.sinh_()
             actual.tanh_()
             actual.relu_()
@@ -321,7 +312,7 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode wraps Torch.tensor with empty list input
     def test_schema_check_mode_empty_list_input(self):
         expected = torch.atleast_1d([])
-        with enable_torch_dispatch_mode(SchemaCheckMode()):
+        with SchemaCheckMode():
             actual = torch.atleast_1d([])
         self.assertEqual(expected, actual)
 
@@ -330,7 +321,7 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
             x = torch.rand((3, 3))
             y = torch.rand((3, 3))
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))
 
     # # Tests that an exception is raised for a mismatching mutation over multiple ops
@@ -338,7 +329,7 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
             x = torch.rand((3, 3))
             y = torch.rand((3, 3))
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))
 
     # Tests that an exception is raised for a mismatching alias
@@ -346,7 +337,7 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
             x = torch.rand((3, 3), requires_grad=True)
             y = torch.rand((3, 3))
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)
 
     # Tests that an exception is raised for a mismatching alias over multiple ops
@@ -354,7 +345,7 @@ class TestSchemaCheck(JitTestCase):
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
             x = torch.rand((3, 3), requires_grad=True)
             y = torch.zeros((3, 3), requires_grad=True)
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)
 
     # Tests that an exception is raised for a centered mismatching alias over multiple ops
@@ -362,15 +353,14 @@ class TestSchemaCheck(JitTestCase):
         with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
             x = torch.rand((3, 3), requires_grad=True)
             y = torch.zeros((3, 3), requires_grad=True)
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()
 
     # Tests that an exception is raised for a centered mismatching alias over multiple ops
     def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
         with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
             x = torch.rand((3, 3))
-            s = SchemaCheckMode()
-            with enable_torch_dispatch_mode(s):
+            with SchemaCheckMode() as s:
                 IncorrectAliasTensor(x).aminmax(dim=0)
 
     # Tests that is_alias_of returns as expected
@@ -439,8 +429,7 @@ class TestSchemaCheck(JitTestCase):
 
                 return func(*args, **kwargs)
         x = torch.rand((3, 3))
-        schemaInfoCheck = SchemaInfoBindTestMode(self)
-        with enable_torch_dispatch_mode(schemaInfoCheck):
+        with SchemaInfoBindTestMode(self) as schemaInfoCheck:
             x.add(x)
 
 
@@ -452,7 +441,7 @@ class TestSchemaCheckModeOpInfo(JitTestCase):
         if (dtype == torch.complex32):
             return
         for sample in op.sample_inputs(device, dtype, requires_grad=False):
-            with enable_torch_dispatch_mode(SchemaCheckMode()):
+            with SchemaCheckMode():
                 op(sample.input, *sample.args, **sample.kwargs)
 
 instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))
```