Add mutated_args field to custom_op (#123129)

If provided, we:
- autogenerate an ADInplaceOrView implementation
- assume that no mutated inputs are returned as outputs; existing
  runtime aliasing checks already enforce this (see the sketch below).
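
A minimal usage sketch (hypothetical op name `mylib::sin_`, modeled on the new tests):

    import numpy as np
    import torch
    from torch import Tensor

    # mutated_args={"x"} declares that the op writes to x in place; the
    # autogenerated ADInplaceOrView kernel then bumps x's version counter
    # so autograd can detect the mutation.
    @torch.library.custom_op("mylib::sin_", mutated_args={"x"}, device_types="cpu")
    def sin_(x: Tensor) -> None:
        x_np = x.numpy()
        np.sin(x_np, out=x_np)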

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123129
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109, #123110
Author: rzou
Date: 2024-04-05 11:58:23 -07:00
Committer: PyTorch MergeBot
Parent: 9e8d2b6de2
Commit: 81e7a7c955
7 changed files with 151 additions and 24 deletions

@@ -2151,6 +2151,61 @@ class TestCustomOpAPI(TestCase):
        self.assertEqual(z, x + y)
        self.assertTrue(cpu_called)

    def test_mutated_error(self):
        with self.assertRaisesRegex(
            ValueError, r".*{'y'} in mutated_args were not found"
        ):
            # 'y' is not an argument of the op, so registration must fail.
            @torch.library.custom_op(
                "_torch_testing::numpy_sin_inplace",
                mutated_args={"y"},
                device_types="cpu",
            )
            def numpy_sin_inplace(x: Tensor) -> None:
                x_np = x.numpy()
                np.sin(x_np, out=x_np)

    def test_mutated(self):
        @torch.library.custom_op(
            "_torch_testing::numpy_sin_inplace", mutated_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> None:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        version = x._version
        expected = x.sin()
        numpy_sin_inplace(x)
        self.assertEqual(x, expected)
        # The in-place write must bump x's version counter.
        self.assertGreater(x._version, version)

        @torch.library.custom_op("_torch_testing::f", mutated_args={"y", "z", "w"})
        def f(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            return

        x = torch.randn(3)
        y = torch.randn(3)
        z = [torch.randn(3), torch.randn(3)]
        w = [torch.randn(3), None, torch.randn(3)]
        initial_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        f(x, y, z, w)
        new_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        # x is not in mutated_args, so its version must be unchanged.
        self.assertEqual(initial_versions[0], new_versions[0])

        # Every tensor in the mutated arguments must have a bumped version.
        initial_versions, _ = pytree.tree_flatten(initial_versions[1:])
        new_versions, _ = pytree.tree_flatten(new_versions[1:])
        for prev, after in zip(initial_versions, new_versions):
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_fake(self):
        @torch.library.custom_op("_torch_testing::add", mutated_args=())
@@ -2350,6 +2405,18 @@ Please use `add.register_fake` to add an fake impl.""",
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            f(x)

        @torch.library.custom_op(
            "_torch_testing::f", mutated_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> Tensor:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)
            return x

        x = torch.randn(3)
        # Returning a mutated input as an output is disallowed.
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            numpy_sin_inplace(x)
class MiniOpTestOther(CustomOpTestCaseBase):
    test_ns = "mini_op_test"