From 9834358e0f5ac36d12b6e3855ec8fb4b96dc3604 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 27 Apr 2023 14:50:59 +0000 Subject: [PATCH] Get SchemaCheckMode to error on ops that return inputs directly. Expose as a dynamo backend, eager_debug (#99744) Talked to @zou3519 and @ezyang on what the right UX is: tentatively, adding a new dynamo backend is cheap and simple, so it seems worth doing. And longer term, we agreed (?) that it's worth seeing if we can get custom ops sanity asserts to run more automatically, instead of needing a separate backend. Side comment: that actually seems tough: the mode detects secret mutations by cloning every input to every op, running the op, and checking that the data matches between the real input and the cloned input. So I doubt we'll be able to make that behavior always-on? It would need some config at least. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99744 Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519 --- test/test_schema_check.py | 55 ++++++++++++++++++++++++++ torch/_dynamo/backends/debugging.py | 14 +++++++ torch/_subclasses/schema_check_mode.py | 13 ++++++ 3 files changed, 82 insertions(+) diff --git a/test/test_schema_check.py b/test/test_schema_check.py index c6c50a387368..355879638171 100644 --- a/test/test_schema_check.py +++ b/test/test_schema_check.py @@ -16,6 +16,31 @@ from torch.testing._internal.common_device_type import ops, OpDTypes, instantiat pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +def secretly_aliasing(x): + return x.view(-1) + +def secretly_mutating(x): + x.mul_(2) + return x * 3 + +def output_is_input(x): + return x + +custom_lib = torch.library.Library("bad_schemas", "DEF") +custom_lib.define("secretly_aliasing(Tensor x) -> Tensor") +custom_lib.define("secretly_mutating(Tensor x) -> Tensor") +custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)") + +custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU") +custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing) +custom_lib_cpu.impl("secretly_mutating", secretly_mutating) +custom_lib_cpu.impl("output_is_input", output_is_input) + +custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta") +custom_lib_meta.impl("secretly_aliasing", secretly_aliasing) +custom_lib_meta.impl("secretly_mutating", secretly_mutating) +custom_lib_meta.impl("output_is_input", output_is_input) + # This TorchDispatchTensor Subclass is used to simulate an incorrect schema # which is then used to test that SchemaCheckMode behaves as expected @@ -365,6 +390,36 @@ class TestSchemaCheck(JitTestCase): with SchemaCheckMode() as s: IncorrectAliasTensor(x).aminmax(dim=0) + # When this file was written, python op registration didn't exist. + # It's probably worth re-writing the entire file to use it, + # but instead I just added extra tests. + def test_alias_check_fail_custom_ops_secretly_aliasing(self): + def f(x): + return torch.ops.bad_schemas.secretly_aliasing(x) + + x = torch.rand((3, 3)) + with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"): + with SchemaCheckMode() as s: + out = f(x) + + def test_alias_check_fail_custom_ops_secretly_mutating(self): + def f(x): + return torch.ops.bad_schemas.secretly_mutating(x) + + x = torch.rand((3, 3)) + with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"): + with SchemaCheckMode() as s: + out = f(x) + + def test_alias_check_fail_custom_ops_output_is_input(self): + def f(x): + return torch.ops.bad_schemas.output_is_input(x) + + x = torch.rand((3, 3)) + with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"): + with SchemaCheckMode() as s: + out = f(x) + # Tests that is_alias_of returns as expected def test_is_alias_of_basic(self): x = torch.rand((3, 3), requires_grad=True) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 7b5a291b0dad..f2457b3c6d77 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -18,6 +18,20 @@ def eager(gm, fake_tensor_inputs): return gm +@register_backend +def eager_debug(gm, fake_tensor_inputs): + from torch._subclasses.schema_check_mode import SchemaCheckMode + + # We could add more debugging bits here. + # Right now, this backend can be used to check for and error on + # custom dispatcher ops that have incorrect schemas. + def inner(*args): + with SchemaCheckMode(): + return torch.fx.Interpreter(gm).run(*args) + + return inner + + @register_backend(name="ts") def torchscript(gm, fake_tensor_inputs): return torch.jit.script(gm) diff --git a/torch/_subclasses/schema_check_mode.py b/torch/_subclasses/schema_check_mode.py index 1535a43e1763..8545ac63cedf 100644 --- a/torch/_subclasses/schema_check_mode.py +++ b/torch/_subclasses/schema_check_mode.py @@ -145,6 +145,19 @@ class SchemaCheckMode(TorchDispatchMode): self.aliasing.append( Aliasing(func._schema.name, name, f"output_{j}") ) + if after is tuple_out[j] and isinstance(after, torch.Tensor): + # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs. + if not schema_info.is_mutable( + SchemaArgument(SchemaArgType.input, i) + ) and func not in [ + torch.ops.aten.lift.default, + torch.ops.aten.lift_fresh.default, + ]: + raise RuntimeError( + f"""\ +Dispatcher operators below autograd are not allowed to directly return inputs. +However, we found that `outputs[{str(j)}] is {name}""" + ) if any( has_mutated(a, b, c) for a, b, c in zip(