Add mutated_args field to custom_op (#123129)

If provided, we:
- autogenerate an ADInplaceOrView implementation
- assume that no mutated input is returned as an output; the existing
  runtime aliasing checks already enforce this (see the sketch below).
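
For illustration, a minimal usage sketch mirroring the docstring example and the new tests in the diff below (the `mylib::numpy_sin_inplace` name and the CPU-only registration are illustrative; assumes a build that includes this stack):

    import numpy as np
    import torch
    from torch import Tensor

    # Declaring mutated_args tells PyTorch which inputs the op writes to, so
    # the autogenerated ADInplaceOrView kernel can bump their tensor versions.
    @torch.library.custom_op(
        "mylib::numpy_sin_inplace", mutated_args={"x"}, device_types="cpu"
    )
    def numpy_sin_inplace(x: Tensor) -> None:
        x_np = x.numpy()
        np.sin(x_np, out=x_np)

    x = torch.randn(3)
    expected = x.sin()
    numpy_sin_inplace(x)
    assert torch.allclose(x, expected)
    # Returning a mutated input (e.g. "-> Tensor: ... return x") instead trips
    # the existing runtime aliasing check with a "may not alias" error.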

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123129
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109, #123110
Author: rzou
Date: 2024-04-05 11:58:23 -07:00
Committed by: PyTorch MergeBot
Parent: 9e8d2b6de2
Commit: 81e7a7c955

7 changed files with 151 additions and 24 deletions

View File

@@ -2151,6 +2151,61 @@ class TestCustomOpAPI(TestCase):
        self.assertEqual(z, x + y)
        self.assertTrue(cpu_called)

    def test_mutated_error(self):
        with self.assertRaisesRegex(
            ValueError, r".*{'y'} in mutated_args were not found"
        ):

            @torch.library.custom_op(
                "_torch_testing::numpy_sin_inplace",
                mutated_args={"y"},
                device_types="cpu",
            )
            def numpy_sin_inplace(x: Tensor) -> None:
                x_np = x.numpy()
                np.sin(x_np, out=x_np)

    def test_mutated(self):
        @torch.library.custom_op(
            "_torch_testing::numpy_sin_inplace", mutated_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> None:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)

        x = torch.randn(3)
        version = x._version
        expected = x.sin()
        numpy_sin_inplace(x)
        self.assertEqual(x, expected)
        self.assertGreater(x._version, version)

        @torch.library.custom_op("_torch_testing::f", mutated_args={"y", "z", "w"})
        def f(
            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            return

        x = torch.randn(3)
        y = torch.randn(3)
        z = [torch.randn(3), torch.randn(3)]
        w = [torch.randn(3), None, torch.randn(3)]
        initial_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        f(x, y, z, w)
        new_versions = pytree.tree_map_only(
            torch.Tensor, lambda x: x._version, (x, y, z, w)
        )
        self.assertEqual(initial_versions[0], new_versions[0])

        initial_versions, _ = pytree.tree_flatten(initial_versions[1:])
        new_versions, _ = pytree.tree_flatten(new_versions[1:])
        for prev, after in zip(initial_versions, new_versions):
            if prev is None and after is None:
                continue
            self.assertGreater(after, prev)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_fake(self):
        @torch.library.custom_op("_torch_testing::add", mutated_args=())
@@ -2350,6 +2405,18 @@ Please use `add.register_fake` to add an fake impl.""",
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            f(x)

        @torch.library.custom_op(
            "_torch_testing::f", mutated_args={"x"}, device_types="cpu"
        )
        def numpy_sin_inplace(x: Tensor) -> Tensor:
            x_np = x.numpy()
            np.sin(x_np, out=x_np)
            return x

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError, "may not alias"):
            numpy_sin_inplace(x)


class MiniOpTestOther(CustomOpTestCaseBase):
    test_ns = "mini_op_test"

View File

@@ -1535,6 +1535,11 @@ class _AutoDispatchBelowAutograd:
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _AutoDispatchBelowADInplaceOrView:
    def __init__(self): ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
def _dispatch_get_registrations_for_dispatch_key(
    dispatch_key: str = "",

View File

@@ -787,6 +787,7 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
        )

    params = []
    seen_args = set()
    for idx, (name, param) in enumerate(sig.parameters.items()):
        if not supported_param(param):
            error_fn("We do not support positional-only args, varargs, or varkwargs.")
@@ -811,7 +812,14 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
            if not schema_type.startswith("Tensor"):
                error_fn(f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated")
            schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
        seen_args.add(name)
        params.append(f"{schema_type} {name}")

    mutated_args_not_seen = set(mutated_args) - seen_args
    if len(mutated_args_not_seen) > 0:
        error_fn(f"{mutated_args_not_seen} in mutated_args were not found in "
                 f"the custom op's signature. "
                 f"mutated_args should contain the names of all args that the "
                 f"custom op mutates.")

    ret = parse_return(sig.return_annotation, error_fn)
    return f"({', '.join(params)}) -> {ret}"

View File

@@ -1,9 +1,20 @@
import inspect

from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from torch.utils._exposed_in import exposed_in

from .. import _C, _library, library, Tensor
from .. import _C, _library, autograd, library, Tensor

device_types_t = Optional[Union[str, Sequence[str]]]
@@ -14,7 +25,7 @@ def custom_op(
    name: str,
    /,
    *,
    mutated_args: Sequence[str],
    mutated_args: Iterable[str],
    device_types: device_types_t = None,
    qualname: Optional[str] = None,
) -> Callable:
@@ -34,7 +45,7 @@ def custom_op(
            e.g. "mylib::my_linear". The name is used as a stable identifier for
            if you wish to serialize the custom op, e.g., via torch.save/torch.export.
            To avoid name collisions, please use your project name as the namespace.
        mutated_args (Sequence[str]): The names of args that the function mutates.
        mutated_args (Iterable[str]): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined.
        device_types (None | str | Sequence[str]): The device type(s) the function
            is valid for. If no device type is provided, then the function
@@ -67,9 +78,19 @@ def custom_op(
        >>> x = torch.randn(3)
        >>> y = numpy_sin_cpu(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that mutates an input
        >>> @custom_op("mylib::numpy_sin_inplace", mutated_args={"x"}, device_types="cpu")
        >>> def numpy_sin_inplace(x: Tensor) -> None:
        >>>     x_np = x.numpy()
        >>>     np.sin(x_np, out=x_np)
        >>>
        >>> x = torch.randn(3)
        >>> expected = x.sin()
        >>> numpy_sin_inplace(x)
        >>> assert torch.allclose(x, expected)

    """
    assert len(mutated_args) == 0, "NYI"

    def inner(fn):
        import torch
@@ -366,6 +387,8 @@ class CustomOpDef:
        def fake_impl(*args, **kwargs):
            if self._abstract_fn is None:
                if _library.utils.can_generate_trivial_fake_impl(self._opoverload):
                    return None
                raise RuntimeError(
                    f"There was no fake impl registered for {self}. "
                    f"This is necessary for torch.compile/export/fx tracing to work. "
@@ -379,6 +402,26 @@ class CustomOpDef:
        autograd_impl = _library.autograd.make_autograd_impl(self)
        lib.impl(self._name, autograd_impl, "Autograd")

        schema = self._opoverload._schema
        if schema.is_mutable:

            def adinplaceorview_impl(*args, **kwargs):
                for arg, val in _library.utils.zip_schema(schema, args, kwargs):
                    if not arg.alias_info:
                        continue
                    if not arg.alias_info.is_write:
                        continue
                    if isinstance(val, Tensor):
                        autograd.graph.increment_version(val)
                    elif isinstance(val, (tuple, list)):
                        for v in val:
                            if isinstance(v, Tensor):
                                autograd.graph.increment_version(v)
                with _C._AutoDispatchBelowADInplaceOrView():
                    return self._opoverload(*args, **kwargs)

            lib.impl(self._name, adinplaceorview_impl, "ADInplaceOrView")

    def __call__(self, *args, **kwargs):
        return self._opoverload(*args, **kwargs)

View File

@@ -1,7 +1,7 @@
import dataclasses
import inspect
import sys
from typing import Any, Callable, Tuple
from typing import Any, Callable, Dict, Iterable, Tuple

import torch
from torch import _C
@@ -151,7 +151,9 @@ def mutates_and_returns_first_arg(op: torch._ops.OpOverload):
    return True


def zip_schema(schema, args, kwargs):
def zip_schema(
    schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterable[Tuple[_C.Argument, Any]]:
    """zips schema.arguments and (args, kwargs) together.

    Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
@@ -171,3 +173,19 @@ def zip_schema(schema, args, kwargs):
            continue
        yield info, args[i]
    return


def can_generate_trivial_fake_impl(op: torch._ops.OpOverload) -> bool:
    assert isinstance(op, torch._ops.OpOverload)
    if is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops)
        return False
    schema = op._schema
    # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
    if not schema.is_mutable:
        return False
    if len(schema.returns) > 0:
        return False
    # If the op returns nothing, then it has a trivial fake impl.
    return True

View File

@@ -1380,7 +1380,7 @@ class FakeTensorMode(TorchDispatchMode):
            # We infer the meta of a custom ops that return None to just
            # return None. custom ops are not allowed to mutate metadata
            # of their inputs, so this is safe.
            if can_generate_trivial_abstract_impl(func):
            if torch._library.utils.can_generate_trivial_fake_impl(func):
                return None

            # no meta kernel registered, fallback to kernel for the device
            if has_symbolic_sizes or not self.can_run_unsafe_fallback(func):
@@ -1669,22 +1669,6 @@ def run_fallback_kernel(
    return pytree.tree_map(map_out, r)


def can_generate_trivial_abstract_impl(op: torch._ops.OpOverload) -> bool:
    assert isinstance(op, torch._ops.OpOverload)
    if torch._library.utils.is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops)
        return False
    schema = op._schema
    # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
    if not schema.is_mutable:
        return False
    if len(schema.returns) > 0:
        return False
    # If the op returns nothing, then it has a trivial abstract impl.
    return True


# Just for use to allow copying a module to fake tensors,
# does not apply elsewhere
class FakeCopyMode(TorchFunctionMode):

View File

@@ -730,6 +730,8 @@ void initDispatchBindings(PyObject* module) {
  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");
  py_context_manager<at::AutoDispatchBelowADInplaceOrView>(
      m, "_AutoDispatchBelowADInplaceOrView");

  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print