Add mutated_args field to custom_op (#123129)

If provided, we:
- autogenerate an ADInplaceOrView implementation
- assume that no mutated inputs are returned as outputs. There are
  already aliasing runtime checks that check this.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123129
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109, #123110
This commit is contained in:
rzou
2024-04-05 11:58:23 -07:00
committed by PyTorch MergeBot
parent 9e8d2b6de2
commit 81e7a7c955
7 changed files with 151 additions and 24 deletions

View File

@ -787,6 +787,7 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
)
params = []
seen_args = set()
for idx, (name, param) in enumerate(sig.parameters.items()):
if not supported_param(param):
error_fn("We do not support positional-only args, varargs, or varkwargs.")
@ -811,7 +812,14 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
if not schema_type.startswith("Tensor"):
error_fn(f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated")
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
seen_args.add(name)
params.append(f"{schema_type} {name}")
mutated_args_not_seen = set(mutated_args) - seen_args
if len(mutated_args_not_seen) > 0:
error_fn(f"{mutated_args_not_seen} in mutated_args were not found in "
f"the custom op's signature. "
f"mutated_args should contain the names of all args that the "
f"custom op mutates.")
ret = parse_return(sig.return_annotation, error_fn)
return f"({', '.join(params)}) -> {ret}"