dynamo tracing perf: cache on import_source: 52.9 -> 52.58 (#143058)

See #143056 for overall docs.

This PR: add cache to `InstructionTranslatorBase.import_source()`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143058
Approved by: https://github.com/jansel
ghstack dependencies: #143066, #143056
This commit is contained in:
Aaron Orenstein
2024-12-12 19:38:17 -08:00
committed by PyTorch MergeBot
parent b472d82c96
commit 6bcda3a21a
2 changed files with 48 additions and 0 deletions

View File

@ -26,6 +26,7 @@ import torch
import torch._logging
from torch._dynamo.exc import TensorifyScalarRestartAnalysis
from torch._guards import tracing, TracingContext
from torch.utils._functools import cache_method
from . import config, exc, logging as torchdynamo_logging, trace_rules, variables
from .bytecode_analysis import (
@ -1198,6 +1199,9 @@ class InstructionTranslatorBase(
unimplemented("Storing handles in globals - NYI")
self.output.side_effects.store_global(variable, name, value)
# Cache note: This cache only exists for the duration of this
# InstructionTranslator - so it should be safe to do.
@cache_method
def import_source(self, module_name):
"""Create an alias to a module for use in guards"""
if "torch_package" in module_name:

44
torch/utils/_functools.py Normal file
View File

@ -0,0 +1,44 @@
import functools
from typing import Callable, TypeVar
from typing_extensions import Concatenate, ParamSpec
_P = ParamSpec("_P")
_T = TypeVar("_T")
_C = TypeVar("_C")
# Sentinel used to indicate that cache lookup failed.
_cache_sentinel = object()
def cache_method(
f: Callable[Concatenate[_C, _P], _T]
) -> Callable[Concatenate[_C, _P], _T]:
"""
Like `@functools.cache` but for methods.
`@functools.cache` (and similarly `@functools.lru_cache`) shouldn't be used
on methods because it caches `self`, keeping it alive
forever. `@cache_method` ignores `self` so won't keep `self` alive (assuming
no cycles with `self` in the parameters).
Footgun warning: This decorator completely ignores self's properties so only
use it when you know that self is frozen or won't change in a meaningful
way (such as the wrapped function being pure).
"""
cache_name = "_cache_method_" + f.__name__
@functools.wraps(f)
def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T:
assert not kwargs
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)
cached_value = cache.get(args, _cache_sentinel)
if cached_value is not _cache_sentinel:
return cached_value
value = f(self, *args, **kwargs)
cache[args] = value
return value
return wrap