Respect TorchDispatchMode for shallow_copy_and_detach (#83372)

While trying to delete proxy tensor, I noticed that modes were missing
some tensor creations; shallow_copy_and_detach bypassing
TorchDispatchMode was the cause.

Hypothetically, all PyInterpreter calls could get this treatment.
But I think it only matters for detach; the rest do not return
Tensors and most modes will not be interested in them.
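
For illustration, here is a minimal sketch of the observable effect (not
part of this patch; LoggingMode is a hypothetical helper, and the mode
API follows what the tests below use): saving a requires_grad tensor for
backward goes through shallow_copy_and_detach, which now dispatches an
aten::detach into the active mode.

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode

    class LoggingMode(TorchDispatchMode):
        # Hypothetical mode that records the name of every op it intercepts.
        def __init__(self):
            super().__init__()
            self.ops = []

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.ops.append(func._schema.name)  # e.g. "aten::relu"
            return func(*args, **(kwargs or {}))

    mode = LoggingMode()
    with enable_torch_dispatch_mode(mode):
        x = torch.rand((3, 3), requires_grad=True)
        x.relu().sin()  # relu saves its output for backward; detaching it hits the mode

    # With this patch: ["aten::rand", "aten::relu", "aten::detach", "aten::sin"]
    # Without requires_grad there is nothing to save, so no aten::detach appears.
    print(mode.ops)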

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83372
Approved by: https://github.com/zou3519
Author: Edward Z. Yang
Date: 2022-08-15 20:03:12 -07:00
Committed by: PyTorch MergeBot
parent 1665715cb0
commit a3907ca92d
4 changed files with 42 additions and 13 deletions

@@ -76,7 +76,7 @@ class TestSchemaCheck(JitTestCase):
         with enable_torch_dispatch_mode(schema_check):
             x = torch.rand((3, 3), requires_grad=True)
             x.relu().sin()
-        self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)
+        self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)
 
     # Tests that SchemaCheckMode records operator order without grad
     def test_schema_check_mode_operator_order_without_grad(self):
@@ -88,7 +88,9 @@ class TestSchemaCheck(JitTestCase):
     # Tests that SchemaCheckMode records mutations and aliases with none expected
     def test_schema_check_mode_mutated_aliasing_none(self):
-        x = torch.rand((3, 3), requires_grad=True)
+        # NB: previously requires_grad=True, but this induces a detach for
+        # saved variable
+        x = torch.rand((3, 3))
         schema_check = SchemaCheckMode()
         with enable_torch_dispatch_mode(schema_check):
             actual = x.relu().sin()