Get SchemaCheckMode to error on ops that return inputs directly. Expose as a dynamo backend, eager_debug (#99744)

Talked to @zou3519 and @ezyang about what the right UX is: tentatively, adding a new dynamo backend is cheap and simple, so it seems worth doing. Longer term, we agreed (?) that it's worth seeing whether we can get custom-op sanity asserts to run more automatically, instead of needing a separate backend.
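
For reference, a rough sketch of what exercising the new backend could look like from user code. This assumes a custom op with a bad schema, like the `bad_schemas.output_is_input` op registered in the test file below, has already been defined in the process:

    import torch

    # Sketch only: running a compiled region under the eager_debug backend wraps
    # execution in SchemaCheckMode, so an op whose schema lies about aliasing,
    # mutation, or returning its input should raise a RuntimeError here.
    @torch.compile(backend="eager_debug")
    def f(x):
        return torch.ops.bad_schemas.output_is_input(x)

    f(torch.rand(3, 3))  # expected to raise: ops may not directly return their inputs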

Side comment: that actually seems tough. The mode detects secret mutations by cloning every input to every op, running the op, and checking that the data still matches between the real input and the cloned input. So I doubt we'll be able to make that behavior always-on? It would need some config at least.
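
To illustrate the idea, here is a simplified sketch of the clone-and-compare check (not the actual SchemaCheckMode implementation):

    import torch

    def looks_like_a_secret_mutation(op, *args):
        # Simplified sketch: clone every tensor input, run the op, then compare
        # each real input against its clone. A mismatch means the op mutated an
        # input, which SchemaCheckMode then checks against the op's schema.
        clones = [a.clone() if isinstance(a, torch.Tensor) else a for a in args]
        op(*args)
        return any(
            isinstance(arg, torch.Tensor) and not torch.equal(clone, arg)
            for clone, arg in zip(clones, args)
        )

    looks_like_a_secret_mutation(torch.Tensor.mul_, torch.ones(2), 2)  # True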

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99744
Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
Author: Brian Hirsh
Date: 2023-04-27 14:50:59 +00:00
Committed by: PyTorch MergeBot
Commit: 9834358e0f (parent: 1f2d00e537)
3 changed files with 82 additions and 0 deletions


@@ -16,6 +16,31 @@ from torch.testing._internal.common_device_type import ops, OpDTypes, instantiat
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
def secretly_aliasing(x):
    return x.view(-1)

def secretly_mutating(x):
    x.mul_(2)
    return x * 3

def output_is_input(x):
    return x

custom_lib = torch.library.Library("bad_schemas", "DEF")
custom_lib.define("secretly_aliasing(Tensor x) -> Tensor")
custom_lib.define("secretly_mutating(Tensor x) -> Tensor")
custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)")

custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")
custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing)
custom_lib_cpu.impl("secretly_mutating", secretly_mutating)
custom_lib_cpu.impl("output_is_input", output_is_input)

custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")
custom_lib_meta.impl("secretly_aliasing", secretly_aliasing)
custom_lib_meta.impl("secretly_mutating", secretly_mutating)
custom_lib_meta.impl("output_is_input", output_is_input)
# This TorchDispatchTensor Subclass is used to simulate an incorrect schema
# which is then used to test that SchemaCheckMode behaves as expected
@@ -365,6 +390,36 @@ class TestSchemaCheck(JitTestCase):
            with SchemaCheckMode() as s:
                IncorrectAliasTensor(x).aminmax(dim=0)

    # When this file was written, python op registration didn't exist.
    # It's probably worth re-writing the entire file to use it,
    # but instead I just added extra tests.
    def test_alias_check_fail_custom_ops_secretly_aliasing(self):
        def f(x):
            return torch.ops.bad_schemas.secretly_aliasing(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
            with SchemaCheckMode() as s:
                out = f(x)

    def test_alias_check_fail_custom_ops_secretly_mutating(self):
        def f(x):
            return torch.ops.bad_schemas.secretly_mutating(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
            with SchemaCheckMode() as s:
                out = f(x)

    def test_alias_check_fail_custom_ops_output_is_input(self):
        def f(x):
            return torch.ops.bad_schemas.output_is_input(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
            with SchemaCheckMode() as s:
                out = f(x)
    # Tests that is_alias_of returns as expected
    def test_is_alias_of_basic(self):
        x = torch.rand((3, 3), requires_grad=True)


@@ -18,6 +18,20 @@ def eager(gm, fake_tensor_inputs):
    return gm

@register_backend
def eager_debug(gm, fake_tensor_inputs):
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    # We could add more debugging bits here.
    # Right now, this backend can be used to check for and error on
    # custom dispatcher ops that have incorrect schemas.
    def inner(*args):
        with SchemaCheckMode():
            return torch.fx.Interpreter(gm).run(*args)

    return inner

@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    return torch.jit.script(gm)


@@ -145,6 +145,19 @@ class SchemaCheckMode(TorchDispatchMode):
                        self.aliasing.append(
                            Aliasing(func._schema.name, name, f"output_{j}")
                        )
                if after is tuple_out[j] and isinstance(after, torch.Tensor):
                    # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
                    if not schema_info.is_mutable(
                        SchemaArgument(SchemaArgType.input, i)
                    ) and func not in [
                        torch.ops.aten.lift.default,
                        torch.ops.aten.lift_fresh.default,
                    ]:
                        raise RuntimeError(
                            f"""\
Dispatcher operators below autograd are not allowed to directly return inputs.
However, we found that `outputs[{str(j)}] is {name}"""
                        )
            if any(
                has_mutated(a, b, c)
                for a, b, c in zip(