Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add mutated_args field to custom_op (#123129)
If provided, we:
- autogenerate an ADInplaceOrView implementation
- assume that no mutated inputs are returned as outputs. There are already aliasing runtime checks that check this.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123129
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109, #123110
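As context for the commit message above, here is a minimal usage sketch assembled from the docstring example and tests added in this diff (the name mylib::numpy_sin_inplace is illustrative, not part of the change):

    import numpy as np
    import torch
    from torch import Tensor

    # Listing the mutated args makes the op's schema mutable, which triggers the
    # autogenerated ADInplaceOrView kernel that bumps the version counter of
    # each mutated tensor.
    @torch.library.custom_op(
        "mylib::numpy_sin_inplace", mutated_args={"x"}, device_types="cpu"
    )
    def numpy_sin_inplace(x: Tensor) -> None:
        x_np = x.numpy()
        np.sin(x_np, out=x_np)

    x = torch.randn(3)
    expected = x.sin()
    version = x._version
    numpy_sin_inplace(x)
    assert torch.allclose(x, expected)  # x was mutated in place
    assert x._version > version  # bumped by the autogenerated ADInplaceOrView kernel

Note that a mutated input must not also be returned; the existing runtime aliasing checks raise a "may not alias" error in that case (see the new test below).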
@@ -2151,6 +2151,61 @@ class TestCustomOpAPI(TestCase):
         self.assertEqual(z, x + y)
         self.assertTrue(cpu_called)
 
+    def test_mutated_error(self):
+        with self.assertRaisesRegex(
+            ValueError, r".*{'y'} in mutated_args were not found"
+        ):
+
+            @torch.library.custom_op(
+                "_torch_testing::numpy_sin_inplace",
+                mutated_args={"y"},
+                device_types="cpu",
+            )
+            def numpy_sin_inplace(x: Tensor) -> None:
+                x_np = x.numpy()
+                np.sin(x_np, out=x_np)
+
+    def test_mutated(self):
+        @torch.library.custom_op(
+            "_torch_testing::numpy_sin_inplace", mutated_args={"x"}, device_types="cpu"
+        )
+        def numpy_sin_inplace(x: Tensor) -> None:
+            x_np = x.numpy()
+            np.sin(x_np, out=x_np)
+
+        x = torch.randn(3)
+        version = x._version
+        expected = x.sin()
+        numpy_sin_inplace(x)
+        self.assertEqual(x, expected)
+        self.assertGreater(x._version, version)
+
+        @torch.library.custom_op("_torch_testing::f", mutated_args={"y", "z", "w"})
+        def f(
+            x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
+        ) -> None:
+            return
+
+        x = torch.randn(3)
+        y = torch.randn(3)
+        z = [torch.randn(3), torch.randn(3)]
+        w = [torch.randn(3), None, torch.randn(3)]
+        initial_versions = pytree.tree_map_only(
+            torch.Tensor, lambda x: x._version, (x, y, z, w)
+        )
+        f(x, y, z, w)
+        new_versions = pytree.tree_map_only(
+            torch.Tensor, lambda x: x._version, (x, y, z, w)
+        )
+
+        self.assertEqual(initial_versions[0], new_versions[0])
+        initial_versions, _ = pytree.tree_flatten(initial_versions[1:])
+        new_versions, _ = pytree.tree_flatten(new_versions[1:])
+        for prev, after in zip(initial_versions, new_versions):
+            if prev is None and after is None:
+                continue
+            self.assertGreater(after, prev)
+
     @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
     def test_fake(self):
         @torch.library.custom_op("_torch_testing::add", mutated_args=())
@@ -2350,6 +2405,18 @@ Please use `add.register_fake` to add an fake impl.""",
         with self.assertRaisesRegex(RuntimeError, "may not alias"):
             f(x)
 
+        @torch.library.custom_op(
+            "_torch_testing::f", mutated_args={"x"}, device_types="cpu"
+        )
+        def numpy_sin_inplace(x: Tensor) -> Tensor:
+            x_np = x.numpy()
+            np.sin(x_np, out=x_np)
+            return x
+
+        x = torch.randn(3)
+        with self.assertRaisesRegex(RuntimeError, "may not alias"):
+            numpy_sin_inplace(x)
+
 
 class MiniOpTestOther(CustomOpTestCaseBase):
     test_ns = "mini_op_test"
@@ -1535,6 +1535,11 @@ class _AutoDispatchBelowAutograd:
     def __enter__(self): ...
     def __exit__(self, exc_type, exc_value, traceback): ...
 
+class _AutoDispatchBelowADInplaceOrView:
+    def __init__(self): ...
+    def __enter__(self): ...
+    def __exit__(self, exc_type, exc_value, traceback): ...
+
 def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
 def _dispatch_get_registrations_for_dispatch_key(
     dispatch_key: str = "",
@@ -787,6 +787,7 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
         )
 
     params = []
+    seen_args = set()
    for idx, (name, param) in enumerate(sig.parameters.items()):
        if not supported_param(param):
            error_fn("We do not support positional-only args, varargs, or varkwargs.")
@@ -811,7 +812,14 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
             if not schema_type.startswith("Tensor"):
                 error_fn(f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated")
             schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
+            seen_args.add(name)
         params.append(f"{schema_type} {name}")
+    mutated_args_not_seen = set(mutated_args) - seen_args
+    if len(mutated_args_not_seen) > 0:
+        error_fn(f"{mutated_args_not_seen} in mutated_args were not found in "
+                 f"the custom op's signature. "
+                 f"mutated_args should contain the names of all args that the "
+                 f"custom op mutates.")
     ret = parse_return(sig.return_annotation, error_fn)
     return f"({', '.join(params)}) -> {ret}"
 
@@ -1,9 +1,20 @@
 import inspect
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 from torch.utils._exposed_in import exposed_in
 
-from .. import _C, _library, library, Tensor
+from .. import _C, _library, autograd, library, Tensor
 
 
 device_types_t = Optional[Union[str, Sequence[str]]]
@@ -14,7 +25,7 @@ def custom_op(
     name: str,
     /,
     *,
-    mutated_args: Sequence[str],
+    mutated_args: Iterable[str],
     device_types: device_types_t = None,
     qualname: Optional[str] = None,
 ) -> Callable:
@@ -34,7 +45,7 @@ def custom_op(
             e.g. "mylib::my_linear". The name is used as a stable identifier for
             if you wish to serialize the custom op, e.g., via torch.save/torch.export.
             To avoid name collisions, please use your project name as the namespace.
-        mutated_args (Sequence[str]): The names of args that the function mutates.
+        mutated_args (Iterable[str]): The names of args that the function mutates.
             This MUST be accurate, otherwise, the behavior is undefined.
         device_types (None | str | Sequence[str]): The device type(s) the function
             is valid for. If no device type is provided, then the function
@@ -67,9 +78,19 @@ def custom_op(
         >>> x = torch.randn(3)
         >>> y = numpy_sin_cpu(x)
         >>> assert torch.allclose(y, x.sin())
+        >>>
+        >>> # Example of a custom op that mutates an input
+        >>> @custom_op("mylib::numpy_sin_inplace", mutated_args={"x"}, device_types="cpu")
+        >>> def numpy_sin_inplace(x: Tensor) -> None:
+        >>>     x_np = x.numpy()
+        >>>     np.sin(x_np, out=x_np)
+        >>>
+        >>> x = torch.randn(3)
+        >>> expected = x.sin()
+        >>> numpy_sin_inplace(x)
+        >>> assert torch.allclose(x, expected)
 
     """
-    assert len(mutated_args) == 0, "NYI"
 
     def inner(fn):
         import torch
@@ -366,6 +387,8 @@ class CustomOpDef:
 
         def fake_impl(*args, **kwargs):
             if self._abstract_fn is None:
+                if _library.utils.can_generate_trivial_fake_impl(self._opoverload):
+                    return None
                 raise RuntimeError(
                     f"There was no fake impl registered for {self}. "
                     f"This is necessary for torch.compile/export/fx tracing to work. "
@@ -379,6 +402,26 @@ class CustomOpDef:
         autograd_impl = _library.autograd.make_autograd_impl(self)
         lib.impl(self._name, autograd_impl, "Autograd")
 
+        schema = self._opoverload._schema
+        if schema.is_mutable:
+
+            def adinplaceorview_impl(*args, **kwargs):
+                for arg, val in _library.utils.zip_schema(schema, args, kwargs):
+                    if not arg.alias_info:
+                        continue
+                    if not arg.alias_info.is_write:
+                        continue
+                    if isinstance(val, Tensor):
+                        autograd.graph.increment_version(val)
+                    elif isinstance(val, (tuple, list)):
+                        for v in val:
+                            if isinstance(v, Tensor):
+                                autograd.graph.increment_version(v)
+                with _C._AutoDispatchBelowADInplaceOrView():
+                    return self._opoverload(*args, **kwargs)
+
+            lib.impl(self._name, adinplaceorview_impl, "ADInplaceOrView")
+
     def __call__(self, *args, **kwargs):
         return self._opoverload(*args, **kwargs)
 
@@ -1,7 +1,7 @@
 import dataclasses
 import inspect
 import sys
-from typing import Any, Callable, Tuple
+from typing import Any, Callable, Dict, Iterable, Tuple
 
 import torch
 from torch import _C
@@ -151,7 +151,9 @@ def mutates_and_returns_first_arg(op: torch._ops.OpOverload):
     return True
 
 
-def zip_schema(schema, args, kwargs):
+def zip_schema(
+    schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+) -> Iterable[Tuple[_C.Argument, Any]]:
     """zips schema.arguments and (args, kwargs) together.
 
     Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
@@ -171,3 +173,19 @@ def zip_schema(schema, args, kwargs):
             continue
         yield info, args[i]
     return
+
+
+def can_generate_trivial_fake_impl(op: torch._ops.OpOverload) -> bool:
+    assert isinstance(op, torch._ops.OpOverload)
+    if is_builtin(op):
+        # We control the built-ins. These may (in rare cases)
+        # do input metadata mutation (which we have banned on custom ops)
+        return False
+    schema = op._schema
+    # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
+    if not schema.is_mutable:
+        return False
+    if len(schema.returns) > 0:
+        return False
+    # If the op returns nothing, then it has a trivial fake impl.
+    return True
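To make the rule encoded by can_generate_trivial_fake_impl concrete, a small sketch (hypothetical op names, not part of this diff): a custom op whose schema is mutable and that returns nothing gets its fake impl inferred as "return None", while an op that returns tensors still needs an explicit fake impl.

    import numpy as np
    import torch
    from torch import Tensor

    # Mutable schema, no returns -> trivial fake impl is generated; no
    # register_fake call is needed for FakeTensor/torch.compile tracing.
    @torch.library.custom_op("mylib::numpy_fill_", mutated_args={"x"}, device_types="cpu")
    def numpy_fill_(x: Tensor, value: float) -> None:
        x_np = x.numpy()
        x_np.fill(value)

    # Returns a Tensor -> not trivial; tracing it still requires a fake impl,
    # e.g. one registered via numpy_sin.register_fake.
    @torch.library.custom_op("mylib::numpy_sin", mutated_args=(), device_types="cpu")
    def numpy_sin(x: Tensor) -> Tensor:
        return torch.from_numpy(np.sin(x.numpy()))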
@@ -1380,7 +1380,7 @@ class FakeTensorMode(TorchDispatchMode):
             # We infer the meta of a custom ops that return None to just
             # return None. custom ops are not allowed to mutate metadata
             # of their inputs, so this is safe.
-            if can_generate_trivial_abstract_impl(func):
+            if torch._library.utils.can_generate_trivial_fake_impl(func):
                 return None
             # no meta kernel registered, fallback to kernel for the device
             if has_symbolic_sizes or not self.can_run_unsafe_fallback(func):
@@ -1669,22 +1669,6 @@ def run_fallback_kernel(
     return pytree.tree_map(map_out, r)
 
 
-def can_generate_trivial_abstract_impl(op: torch._ops.OpOverload) -> bool:
-    assert isinstance(op, torch._ops.OpOverload)
-    if torch._library.utils.is_builtin(op):
-        # We control the built-ins. These may (in rare cases)
-        # do input metadata mutation (which we have banned on custom ops)
-        return False
-    schema = op._schema
-    # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
-    if not schema.is_mutable:
-        return False
-    if len(schema.returns) > 0:
-        return False
-    # If the op returns nothing, then it has a trivial abstract impl.
-    return True
-
-
 # Just for use to allow copying a module to fake tensors,
 # does not apply elsewhere
 class FakeCopyMode(TorchFunctionMode):
@@ -730,6 +730,8 @@ void initDispatchBindings(PyObject* module) {
 
   py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
       m, "_AutoDispatchBelowAutograd");
+  py_context_manager<at::AutoDispatchBelowADInplaceOrView>(
+      m, "_AutoDispatchBelowADInplaceOrView");
 
   // Prints out the name of every operator that has a kernel registered to the
   // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print