Optimize stack frame inspection in torch._custom_op.impl:CustomOp._register_impl (#105940)

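Summary: Calling `inspect.stack()` is surprisingly expensive when the stack is deep, because it builds a record for every frame on the stack even though only one is used. We can instead process just the specific stack frame that's relevant, via `inspect.getframeinfo(sys._getframe(stacklevel))` -- it's much faster.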

Test Plan:
```
import inspect
import sys
import time

def make_deep_stack(fn, n: int = 10):
    """Invoke ``fn`` from beneath ``n`` extra recursive frames.

    Each recursive call adds one ``make_deep_stack`` frame to the call
    stack; once ``n`` is no longer positive, ``fn`` is called with no
    arguments and its result is passed back up unchanged.
    """
    # Base case first: nothing left to recurse through, so run the payload.
    if n <= 0:
        return fn()

    # Deliberately recursive (not a loop): the point is to deepen the stack.
    return make_deep_stack(fn, n - 1)

def full_stack():
    """Return the calling function's name using ``inspect.stack()``.

    ``inspect.stack()`` produces a record for every frame on the stack;
    entry 0 is this function, so entry 1 describes the caller.
    """
    caller_record = inspect.stack()[1]
    return caller_record.function

def via_current_frame():
    """Return the calling function's name by inspecting only one frame.

    ``sys._getframe(1)`` fetches the caller's frame directly, and
    ``inspect.getframeinfo`` is then applied to that single frame.
    """
    caller_frame = sys._getframe(1)
    caller_info = inspect.getframeinfo(caller_frame)
    return caller_info.function

# Time 1000 deep-stack calls that look up the caller via inspect.stack().
began = time.perf_counter()
for _ in range(1000):
    make_deep_stack(full_stack)
elapsed = time.perf_counter() - began
print(f"full_stack took {elapsed}s")

# Time the same workload using the single-frame lookup instead.
began = time.perf_counter()
for _ in range(1000):
    make_deep_stack(via_current_frame)
elapsed = time.perf_counter() - began
print(f"via_current_frame took {elapsed}s")

> full_stack took 31.788201928138733s
> via_current_frame took 2.33455612603575s
```

Differential Revision: D47674015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105940
Approved by: https://github.com/zou3519
This commit is contained in:
Jeffrey Dunn
2023-07-31 15:49:31 +00:00
committed by PyTorch MergeBot
parent c54afea6ee
commit c5b9dc1f40

View File

@ -2,6 +2,7 @@ import contextlib
import dataclasses
import functools
import inspect
import sys
import typing
import weakref
@ -210,7 +211,7 @@ class CustomOp:
f"that already has a {kind} impl registered from Python at "
f"{location}. This is not supported."
)
frame = inspect.stack()[stacklevel]
frame = inspect.getframeinfo(sys._getframe(stacklevel))
location = f"{frame.filename}:{frame.lineno}"
self._impls[kind] = FuncAndLocation(func, location)