[custom_ops] mutated_args -> mutates_args (#123437)

This seemed better, since when you're construction a custom op you need
to provide "the args that the custom op mutates".

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123437
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109, #123110, #123129
This commit is contained in:
rzou
2024-04-05 11:58:24 -07:00
committed by PyTorch MergeBot
parent 81e7a7c955
commit cd6c58baea
4 changed files with 44 additions and 43 deletions

View File

@ -769,13 +769,13 @@ def validate_function_matches_schema(
compare(kwargonly, schema.arguments.flat_kwarg_only)
def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
"""Given a function with type hints, parses a schema.
We make some assumptions to make our lives easier that correspond to how people
write custom ops in real life:
- none of the outputs alias any of the inputs or each other.
- only the args listed in mutated_args are being mutated.
- only the args listed in mutates_args are being mutated.
Callers (e.g. the custom ops API) are responsible for checking these assumptions.
"""
@ -808,17 +808,17 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
f"default values that invokes the custom op."
)
schema_type = SUPPORTED_PARAM_TYPES[param.annotation]
if name in mutated_args:
if name in mutates_args:
if not schema_type.startswith("Tensor"):
error_fn(f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated")
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
seen_args.add(name)
params.append(f"{schema_type} {name}")
mutated_args_not_seen = set(mutated_args) - seen_args
if len(mutated_args_not_seen) > 0:
error_fn(f"{mutated_args_not_seen} in mutated_args were not found in "
mutates_args_not_seen = set(mutates_args) - seen_args
if len(mutates_args_not_seen) > 0:
error_fn(f"{mutates_args_not_seen} in mutates_args were not found in "
f"the custom op's signature. "
f"mutated_args should contain the names of all args that the "
f"mutates_args should contain the names of all args that the "
f"custom op mutates.")
ret = parse_return(sig.return_annotation, error_fn)
return f"({', '.join(params)}) -> {ret}"