Optimize mutable torch.library.custom_op overhead (#139513)

We don't need to loop over all of the args and kwargs in the
ADInplaceOrView key; we only need to bump the version counter on the
args and kwargs that are mutable.
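
A minimal sketch of the idea (the helper name and shape below are hypothetical, not the actual ADInplaceOrView kernel; `torch.autograd.graph.increment_version` is the public API for bumping a version counter, assuming a recent PyTorch): given the precomputed positions and names of the mutable inputs, bump just those counters instead of walking everything.
```python
import torch

# Hypothetical sketch: given `idxs` (positional indices of mutable args) and
# `keys` (names of mutable kwargs), both precomputed once from the op's
# schema, bump only those inputs' version counters.
def bump_mutated_versions(idxs, keys, args, kwargs):
    for i in idxs:
        a = args[i]
        if isinstance(a, torch.Tensor):
            torch.autograd.graph.increment_version(a)
    for k in keys:
        a = kwargs.get(k)
        if isinstance(a, torch.Tensor):
            torch.autograd.graph.increment_version(a)
```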

On the benchmark mentioned in
https://github.com/pytorch/pytorch/issues/139494,
this change took the times from
```
mutate2 = 61.72943878173828
no_mutate2 = 36.89440155029297
mutate = 236.3092498779297
no_mutate = 59.31964874267578
```
to
```
mutate2 = 47.976478576660156
no_mutate2 = 38.37468719482422
mutate = 71.21315002441406
no_mutate = 59.7432975769043
```
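
For context, here is a minimal sketch (hypothetical, not the actual benchmark from the issue) of the kind of mutable custom op whose per-call overhead this change reduces; `mutates_args` is what marks `x` as written in place:
```python
import torch

# Sketch of a mutable custom op (made-up "mylib::inplace_add", not the
# benchmark from #139494). Declaring mutates_args={"x"} tells the dispatcher
# that x is mutated, so its version counter must be bumped on every call.
@torch.library.custom_op("mylib::inplace_add", mutates_args={"x"})
def inplace_add(x: torch.Tensor, y: torch.Tensor) -> None:
    x.add_(y)

x = torch.zeros(3)
before = x._version
inplace_add(x, torch.ones(3))
assert x._version > before  # the in-place write was recorded
```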

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139513
Approved by: https://github.com/bdhirsh
ghstack dependencies: #139509

```diff
@@ -3,7 +3,7 @@ import dataclasses
 import inspect
 import sys
 import warnings
-from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
@@ -463,3 +463,15 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
     if opdef._abstract_fn is not None:
         return True
     return False
+
+
+def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
+    idxs = []
+    keys = []
+    for i, info in enumerate(schema.arguments):
+        if info.alias_info is not None and info.alias_info.is_write:
+            if info.kwarg_only:
+                keys.append(info.name)
+            else:
+                idxs.append(i)
+    return idxs, keys
```
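
To make the new helper concrete, an illustrative call (assuming the helper lands next to `has_fake_kernel`, e.g. importable from `torch._library.utils`; the schema string is made up):
```python
import torch
from torch._library.utils import mutated_args_kwargs  # assumed location

# Made-up schema: `x` is mutable and positional, `out` is mutable and
# kwarg-only (it appears after the `*`).
schema = torch._C.parse_schema(
    "mylib::foo(Tensor(a!) x, Tensor y, *, Tensor(b!) out) -> ()"
)
idxs, keys = mutated_args_kwargs(schema)
print(idxs, keys)  # expected: [0] ['out']
```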