mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Optimize mutable torch.library.custom_op overhead (#139513)
We don't need to loop over all the args and kwargs in the ADInplaceOrView key; we only need to bump the version on the args and kwargs that are mutable.

On the benchmark mentioned in https://github.com/pytorch/pytorch/issues/139494, this made the time go from

```
mutate2 = 61.72943878173828
no_mutate2 = 36.89440155029297
mutate = 236.3092498779297
no_mutate = 59.31964874267578
```

to

```
mutate2 = 47.976478576660156
no_mutate2 = 38.37468719482422
mutate = 71.21315002441406
no_mutate = 59.7432975769043
```

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139513
Approved by: https://github.com/bdhirsh
ghstack dependencies: #139509
This commit is contained in:
@ -3,7 +3,7 @@ import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -463,3 +463,15 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
|
||||
if opdef._abstract_fn is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
    """Locate the arguments that ``schema`` marks as written to.

    An argument counts as mutated when it carries alias info whose
    ``is_write`` flag is set.

    Args:
        schema: The function schema whose ``arguments`` are inspected.

    Returns:
        A pair ``(idxs, keys)``. ``idxs`` holds the positions (within
        ``schema.arguments``) of mutated arguments that are not
        keyword-only; ``keys`` holds the names of mutated keyword-only
        arguments. Both lists follow the order of ``schema.arguments``.
    """
    idxs: List[int] = []
    keys: List[str] = []
    for position, arg in enumerate(schema.arguments):
        alias = arg.alias_info
        # Arguments without alias info, or with read-only aliasing, are
        # not mutated and need no bookkeeping.
        if alias is None or not alias.is_write:
            continue
        if arg.kwarg_only:
            # Keyword-only arguments are looked up by name.
            keys.append(arg.name)
        else:
            # Everything else is identified by its position in the schema.
            idxs.append(position)
    return idxs, keys
|
||||
|
Reference in New Issue
Block a user