mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add mutated_args field to custom_op (#123129)
If provided, we: - autogenerate an ADInplaceOrView implementation - assume that no mutated inputs are returned as outputs. There are already aliasing runtime checks that check this. Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/123129 Approved by: https://github.com/albanD ghstack dependencies: #123108, #123109, #123110
This commit is contained in:
@ -787,6 +787,7 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
|
||||
)
|
||||
|
||||
params = []
|
||||
seen_args = set()
|
||||
for idx, (name, param) in enumerate(sig.parameters.items()):
|
||||
if not supported_param(param):
|
||||
error_fn("We do not support positional-only args, varargs, or varkwargs.")
|
||||
@ -811,7 +812,14 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
|
||||
if not schema_type.startswith("Tensor"):
|
||||
error_fn(f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated")
|
||||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
|
||||
seen_args.add(name)
|
||||
params.append(f"{schema_type} {name}")
|
||||
mutated_args_not_seen = set(mutated_args) - seen_args
|
||||
if len(mutated_args_not_seen) > 0:
|
||||
error_fn(f"{mutated_args_not_seen} in mutated_args were not found in "
|
||||
f"the custom op's signature. "
|
||||
f"mutated_args should contain the names of all args that the "
|
||||
f"custom op mutates.")
|
||||
ret = parse_return(sig.return_annotation, error_fn)
|
||||
return f"({', '.join(params)}) -> {ret}"
|
||||
|
||||
|
Reference in New Issue
Block a user