mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Get SchemaCheckMode to error on ops that return inputs directly. Expose as a dynamo backend, eager_debug (#99744)
Talked to @zou3519 and @ezyang on what the right UX is: tentatively, adding a new dynamo backend is cheap and simple, so it seems worth doing. And longer term, we agreed (?) that it's worth seeing if we can get custom ops sanity asserts to run more automatically, instead of needing a separate backend. Side comment: that actually seems tough: the mode detects secret mutations by cloning every input to every op, running the op, and checking that the data matches between the real input and the cloned input. So I doubt we'll be able to make that behavior always-on? It would need some config at least. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99744 Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
1f2d00e537
commit
9834358e0f
@ -16,6 +16,31 @@ from torch.testing._internal.common_device_type import ops, OpDTypes, instantiat
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
def secretly_aliasing(x):
|
||||
return x.view(-1)
|
||||
|
||||
def secretly_mutating(x):
|
||||
x.mul_(2)
|
||||
return x * 3
|
||||
|
||||
def output_is_input(x):
|
||||
return x
|
||||
|
||||
custom_lib = torch.library.Library("bad_schemas", "DEF")
|
||||
custom_lib.define("secretly_aliasing(Tensor x) -> Tensor")
|
||||
custom_lib.define("secretly_mutating(Tensor x) -> Tensor")
|
||||
custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)")
|
||||
|
||||
custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")
|
||||
custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing)
|
||||
custom_lib_cpu.impl("secretly_mutating", secretly_mutating)
|
||||
custom_lib_cpu.impl("output_is_input", output_is_input)
|
||||
|
||||
custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")
|
||||
custom_lib_meta.impl("secretly_aliasing", secretly_aliasing)
|
||||
custom_lib_meta.impl("secretly_mutating", secretly_mutating)
|
||||
custom_lib_meta.impl("output_is_input", output_is_input)
|
||||
|
||||
# This TorchDispatchTensor Subclass is used to simulate an incorrect schema
|
||||
# which is then used to test that SchemaCheckMode behaves as expected
|
||||
|
||||
@ -365,6 +390,36 @@ class TestSchemaCheck(JitTestCase):
|
||||
with SchemaCheckMode() as s:
|
||||
IncorrectAliasTensor(x).aminmax(dim=0)
|
||||
|
||||
# When this file was written, python op registration didn't exist.
|
||||
# It's probably worth re-writing the entire file to use it,
|
||||
# but instead I just added extra tests.
|
||||
def test_alias_check_fail_custom_ops_secretly_aliasing(self):
|
||||
def f(x):
|
||||
return torch.ops.bad_schemas.secretly_aliasing(x)
|
||||
|
||||
x = torch.rand((3, 3))
|
||||
with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
|
||||
with SchemaCheckMode() as s:
|
||||
out = f(x)
|
||||
|
||||
def test_alias_check_fail_custom_ops_secretly_mutating(self):
|
||||
def f(x):
|
||||
return torch.ops.bad_schemas.secretly_mutating(x)
|
||||
|
||||
x = torch.rand((3, 3))
|
||||
with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
|
||||
with SchemaCheckMode() as s:
|
||||
out = f(x)
|
||||
|
||||
def test_alias_check_fail_custom_ops_output_is_input(self):
|
||||
def f(x):
|
||||
return torch.ops.bad_schemas.output_is_input(x)
|
||||
|
||||
x = torch.rand((3, 3))
|
||||
with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
|
||||
with SchemaCheckMode() as s:
|
||||
out = f(x)
|
||||
|
||||
# Tests that is_alias_of returns as expected
|
||||
def test_is_alias_of_basic(self):
|
||||
x = torch.rand((3, 3), requires_grad=True)
|
||||
|
@ -18,6 +18,20 @@ def eager(gm, fake_tensor_inputs):
|
||||
return gm
|
||||
|
||||
|
||||
@register_backend
|
||||
def eager_debug(gm, fake_tensor_inputs):
|
||||
from torch._subclasses.schema_check_mode import SchemaCheckMode
|
||||
|
||||
# We could add more debugging bits here.
|
||||
# Right now, this backend can be used to check for and error on
|
||||
# custom dispatcher ops that have incorrect schemas.
|
||||
def inner(*args):
|
||||
with SchemaCheckMode():
|
||||
return torch.fx.Interpreter(gm).run(*args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@register_backend(name="ts")
|
||||
def torchscript(gm, fake_tensor_inputs):
|
||||
return torch.jit.script(gm)
|
||||
|
@ -145,6 +145,19 @@ class SchemaCheckMode(TorchDispatchMode):
|
||||
self.aliasing.append(
|
||||
Aliasing(func._schema.name, name, f"output_{j}")
|
||||
)
|
||||
if after is tuple_out[j] and isinstance(after, torch.Tensor):
|
||||
# Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
|
||||
if not schema_info.is_mutable(
|
||||
SchemaArgument(SchemaArgType.input, i)
|
||||
) and func not in [
|
||||
torch.ops.aten.lift.default,
|
||||
torch.ops.aten.lift_fresh.default,
|
||||
]:
|
||||
raise RuntimeError(
|
||||
f"""\
|
||||
Dispatcher operators below autograd are not allowed to directly return inputs.
|
||||
However, we found that `outputs[{str(j)}] is {name}"""
|
||||
)
|
||||
if any(
|
||||
has_mutated(a, b, c)
|
||||
for a, b, c in zip(
|
||||
|
Reference in New Issue
Block a user