Compare commits

...

1 Commit

3 changed files with 55 additions and 1 deletion

View File

@@ -151,6 +151,28 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
    def test_inline_lru_cache_fn_with_default_args(a, b):
        return inline_lru_cache_fn_with_default_args(a, 2, b)

    def test_lru_cache_warning_issued_during_tracing(self):
        import warnings
        from functools import lru_cache

        @lru_cache
        def foo(x):
            return x + 1

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            torch.compile(foo, backend="eager")(torch.randn(4))

        for warning in w:
            warning_message = str(warning.message)
            if (
                "Dynamo detected a call to a `functools.lru_cache` wrapped function"
                in warning_message
            ):
                break
        else:
            self.assertTrue(False, "Expected warning about lru_cache not found")

    @make_test
    def test_add(a, b):
        return a + b
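For context on what the new test asserts, here is a minimal standalone sketch of the same warning-capture pattern, assuming a PyTorch build that includes this change (the `for`/`else` in the test above is just a search over the recorded warnings):

    # Minimal sketch: compile an lru_cache-wrapped function and capture
    # the warning Dynamo emits. Assumes a build containing this change.
    import warnings
    from functools import lru_cache

    import torch


    @lru_cache
    def foo(x):
        return x + 1


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, no dedupe
        torch.compile(foo, backend="eager")(torch.randn(4))

    # One of the recorded warnings should mention functools.lru_cache.
    print(any("functools.lru_cache" in str(w.message) for w in caught))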

View File

@@ -4717,6 +4717,29 @@ class ReproTests(torch._dynamo.test_case.TestCase):
        ):
            f_compiled(a)

    # https://github.com/pytorch/pytorch/issues/146598
    @unittest.expectedFailure
    def test_lru_cache_tracing(self):
        from functools import lru_cache

        counter = 0

        @lru_cache
        def cached_fn(x):
            nonlocal counter
            counter += 1
            return x + 1

        compiled_fn = torch.compile(cached_fn, backend="eager")

        t = torch.randn(2, 2)
        result1 = compiled_fn(t)
        self.assertEqual(counter, 1)

        result2 = compiled_fn(t)
        self.assertEqual(counter, 1)
        self.assertEqual(result1, result2)

    def test_dont_aggressively_write_assert(self):
        record_graph = torch._dynamo.testing.EagerAndRecordGraphs()
@@ -5431,6 +5454,7 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
        mod = Mod()
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(mod(x), opt_mod(x))

    def test_enum(self):
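The `expectedFailure` marker on `test_lru_cache_tracing` encodes the gap reported in issue #146598: in eager Python, `functools.lru_cache` memoizes by argument, so `counter` stays at 1 across repeated identical calls, while Dynamo currently traces past the cache wrapper and re-runs the body. (The test's tensor argument is cacheable because tensors hash by object identity.) A plain-Python sketch of the eager-side guarantee the test compares against:

    # Eager baseline, no torch.compile: lru_cache memoizes by argument,
    # so the wrapped body runs once for repeated identical calls.
    from functools import lru_cache

    calls = 0


    @lru_cache
    def inc(x):
        global calls
        calls += 1
        return x + 1


    inc(3)
    inc(3)
    assert calls == 1        # second call was served from the cache
    print(inc.cache_info())  # hits=1, misses=1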

View File

@@ -29,14 +29,15 @@ import inspect
import itertools
import sys
import types
import warnings
from collections.abc import Sequence
from types import FunctionType
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
from typing_extensions import Never
from unittest.mock import patch
from weakref import WeakKeyDictionary

import torch
from typing_extensions import Never

from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
@@ -445,6 +446,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Handle patch_dynamo_config call
        if self.fn is torch._dynamo.patch_dynamo_config:
            try:
                args_const = [arg.as_python_constant() for arg in args]
@@ -1534,6 +1536,12 @@ class WrapperUserFunctionVariable(VariableTracker):
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if hasattr(self.wrapper_obj, "cache_info"):
            warnings.warn(
                "Dynamo detected a call to a `functools.lru_cache` wrapped function. "
                "Dynamo currently ignores `functools.lru_cache` and directly traces the wrapped function. "
                "`functools.lru_cache` wrapped functions that read outside state may not be traced soundly."
            )
        return variables.UserFunctionVariable(
            polyfills.getattr_and_trace
        ).call_function(
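The detection here is duck-typed: `functools.lru_cache` (and `functools.cache`, which is `lru_cache(maxsize=None)`) attaches a `cache_info` attribute to its wrapper, so `hasattr(self.wrapper_obj, "cache_info")` is enough to recognize it. A quick illustration of why the check works:

    # lru_cache wrappers expose cache_info(); the underlying function
    # (reachable via __wrapped__) does not, so hasattr() distinguishes them.
    from functools import lru_cache


    @lru_cache
    def f(x):
        return x


    print(hasattr(f, "cache_info"))              # True: the cache wrapper
    print(hasattr(f.__wrapped__, "cache_info"))  # False: the plain function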