mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add mutated_args field to custom_op (#123129)
If provided, we: - autogenerate an ADInplaceOrView implementation - assume that no mutated inputs are returned as outputs. There are already aliasing runtime checks that check this. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/123129 Approved by: https://github.com/albanD ghstack dependencies: #123108, #123109, #123110
This commit is contained in:
@ -2151,6 +2151,61 @@ class TestCustomOpAPI(TestCase):
|
||||
self.assertEqual(z, x + y)
|
||||
self.assertTrue(cpu_called)
|
||||
|
||||
def test_mutated_error(self):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r".*{'y'} in mutated_args were not found"
|
||||
):
|
||||
|
||||
@torch.library.custom_op(
|
||||
"_torch_testing::numpy_sin_inplace",
|
||||
mutated_args={"y"},
|
||||
device_types="cpu",
|
||||
)
|
||||
def numpy_sin_inplace(x: Tensor) -> None:
|
||||
x_np = x.numpy()
|
||||
np.sin(x_np, out=x_np)
|
||||
|
||||
def test_mutated(self):
|
||||
@torch.library.custom_op(
|
||||
"_torch_testing::numpy_sin_inplace", mutated_args={"x"}, device_types="cpu"
|
||||
)
|
||||
def numpy_sin_inplace(x: Tensor) -> None:
|
||||
x_np = x.numpy()
|
||||
np.sin(x_np, out=x_np)
|
||||
|
||||
x = torch.randn(3)
|
||||
version = x._version
|
||||
expected = x.sin()
|
||||
numpy_sin_inplace(x)
|
||||
self.assertEqual(x, expected)
|
||||
self.assertGreater(x._version, version)
|
||||
|
||||
@torch.library.custom_op("_torch_testing::f", mutated_args={"y", "z", "w"})
|
||||
def f(
|
||||
x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
|
||||
) -> None:
|
||||
return
|
||||
|
||||
x = torch.randn(3)
|
||||
y = torch.randn(3)
|
||||
z = [torch.randn(3), torch.randn(3)]
|
||||
w = [torch.randn(3), None, torch.randn(3)]
|
||||
initial_versions = pytree.tree_map_only(
|
||||
torch.Tensor, lambda x: x._version, (x, y, z, w)
|
||||
)
|
||||
f(x, y, z, w)
|
||||
new_versions = pytree.tree_map_only(
|
||||
torch.Tensor, lambda x: x._version, (x, y, z, w)
|
||||
)
|
||||
|
||||
self.assertEqual(initial_versions[0], new_versions[0])
|
||||
initial_versions, _ = pytree.tree_flatten(initial_versions[1:])
|
||||
new_versions, _ = pytree.tree_flatten(new_versions[1:])
|
||||
for prev, after in zip(initial_versions, new_versions):
|
||||
if prev is None and after is None:
|
||||
continue
|
||||
self.assertGreater(after, prev)
|
||||
|
||||
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
|
||||
def test_fake(self):
|
||||
@torch.library.custom_op("_torch_testing::add", mutated_args=())
|
||||
@ -2350,6 +2405,18 @@ Please use `add.register_fake` to add an fake impl.""",
|
||||
with self.assertRaisesRegex(RuntimeError, "may not alias"):
|
||||
f(x)
|
||||
|
||||
@torch.library.custom_op(
|
||||
"_torch_testing::f", mutated_args={"x"}, device_types="cpu"
|
||||
)
|
||||
def numpy_sin_inplace(x: Tensor) -> Tensor:
|
||||
x_np = x.numpy()
|
||||
np.sin(x_np, out=x_np)
|
||||
return x
|
||||
|
||||
x = torch.randn(3)
|
||||
with self.assertRaisesRegex(RuntimeError, "may not alias"):
|
||||
numpy_sin_inplace(x)
|
||||
|
||||
|
||||
class MiniOpTestOther(CustomOpTestCaseBase):
|
||||
test_ns = "mini_op_test"
|
||||
|
Reference in New Issue
Block a user