mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[custom_ops] mutated_args -> mutates_args (#123437)
This seemed better, since when you're construction a custom op you need to provide "the args that the custom op mutates". Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/123437 Approved by: https://github.com/albanD ghstack dependencies: #123108, #123109, #123110, #123129
This commit is contained in:
@ -769,13 +769,13 @@ def validate_function_matches_schema(
|
||||
compare(kwargonly, schema.arguments.flat_kwarg_only)
|
||||
|
||||
|
||||
def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
|
||||
def infer_schema(prototype_function: typing.Callable, mutates_args=()) -> str:
|
||||
"""Given a function with type hints, parses a schema.
|
||||
|
||||
We make some assumptions to make our lives easier that correspond to how people
|
||||
write custom ops in real life:
|
||||
- none of the outputs alias any of the inputs or each other.
|
||||
- only the args listed in mutated_args are being mutated.
|
||||
- only the args listed in mutates_args are being mutated.
|
||||
|
||||
Callers (e.g. the custom ops API) are responsible for checking these assumptions.
|
||||
"""
|
||||
@ -808,17 +808,17 @@ def infer_schema(prototype_function: typing.Callable, mutated_args=()) -> str:
|
||||
f"default values that invokes the custom op."
|
||||
)
|
||||
schema_type = SUPPORTED_PARAM_TYPES[param.annotation]
|
||||
if name in mutated_args:
|
||||
if name in mutates_args:
|
||||
if not schema_type.startswith("Tensor"):
|
||||
error_fn(f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated")
|
||||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
|
||||
seen_args.add(name)
|
||||
params.append(f"{schema_type} {name}")
|
||||
mutated_args_not_seen = set(mutated_args) - seen_args
|
||||
if len(mutated_args_not_seen) > 0:
|
||||
error_fn(f"{mutated_args_not_seen} in mutated_args were not found in "
|
||||
mutates_args_not_seen = set(mutates_args) - seen_args
|
||||
if len(mutates_args_not_seen) > 0:
|
||||
error_fn(f"{mutates_args_not_seen} in mutates_args were not found in "
|
||||
f"the custom op's signature. "
|
||||
f"mutated_args should contain the names of all args that the "
|
||||
f"mutates_args should contain the names of all args that the "
|
||||
f"custom op mutates.")
|
||||
ret = parse_return(sig.return_annotation, error_fn)
|
||||
return f"({', '.join(params)}) -> {ret}"
|
||||
|
Reference in New Issue
Block a user